# Source code for jaxfluids.levelset.levelset_creator

#*------------------------------------------------------------------------------*
#* JAX-FLUIDS -                                                                 *
#*                                                                              *
#* A fully-differentiable CFD solver for compressible two-phase flows.          *
#* Copyright (C) 2022  Deniz A. Bezgin, Aaron B. Buhendwa, Nikolaus A. Adams    *
#*                                                                              *
#* This program is free software: you can redistribute it and/or modify         *
#* it under the terms of the GNU General Public License as published by         *
#* the Free Software Foundation, either version 3 of the License, or            *
#* (at your option) any later version.                                          *
#*                                                                              *
#* This program is distributed in the hope that it will be useful,              *
#* but WITHOUT ANY WARRANTY; without even the implied warranty of               *
#* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the                *
#* GNU General Public License for more details.                                 *
#*                                                                              *
#* You should have received a copy of the GNU General Public License            *
#* along with this program.  If not, see <https://www.gnu.org/licenses/>.       *
#*                                                                              *
#*------------------------------------------------------------------------------*
#*                                                                              *
#* CONTACT                                                                      *
#*                                                                              *
#* deniz.bezgin@tum.de // aaron.buhendwa@tum.de // nikolaus.adams@tum.de        *
#*                                                                              *
#*------------------------------------------------------------------------------*
#*                                                                              *
#* Munich, April 15th, 2022                                                     *
#*                                                                              *
#*------------------------------------------------------------------------------*

from typing import List, Union, Dict
import types

import jax.numpy as jnp

from jaxfluids.domain_information import DomainInformation
from jaxfluids.unit_handler import UnitHandler

class LevelsetCreator:
    """The LevelsetCreator implements functionality to create initial levelset
    fields. The initial levelset field can be specified in one of two ways:

    1) A callable (typically a lambda) via the case setup file. It receives the
       dimensional cell-center coordinates of the active axes and returns the
       dimensional levelset.
    2) A list of building blocks. A single building block consists of a shape
       (with its parameters) and a callable describing the bounding domain in
       which that shape is applied.
    """

    def __init__(
            self,
            domain_information: "DomainInformation",
            unit_handler: "UnitHandler",
            initial_levelset: Union[str, List],
            narrow_band_cutoff: int
            ) -> None:

        self.unit_handler = unit_handler
        self.initial_levelset = initial_levelset
        self.narrow_band_cutoff = narrow_band_cutoff

        self.cell_centers = domain_information.cell_centers
        self.cell_sizes = domain_information.cell_sizes
        self.nx, self.ny, self.nz = domain_information.number_of_cells
        self.nhx, self.nhy, self.nhz = domain_information.domain_slices_conservatives
        self.nh = domain_information.nh_conservatives
        self.active_axis_indices = domain_information.active_axis_indices
        self.smallest_cell_size = jnp.min(jnp.array([self.cell_sizes[i] for i in self.active_axis_indices]))

        # Maps the "shape" entry of a building block to the method creating it.
        # All square/rectangle variants share one implementation; rounding of
        # the corners is controlled by the optional ``radius`` parameter.
        self.shape_function_dict: Dict[str, types.LambdaType] = {
            "circle": self.get_circle,
            "square": self.get_rectangle,
            "rounded_square": self.get_rectangle,
            "rectangle": self.get_rectangle,
            "rounded_rectangle": self.get_rectangle,
            "sphere": self.get_sphere,
        }

    def _get_mesh_grid(self) -> List[jnp.ndarray]:
        """Returns the dimensional cell-center mesh grid of the active axes.

        The full mesh grid is built once and then indexed per active axis,
        instead of being rebuilt for every axis.

        :return: One coordinate array per active axis
        :rtype: List[jnp.ndarray]
        """
        mesh_grid = jnp.meshgrid(*self.cell_centers, indexing="ij")
        return [self.unit_handler.dimensionalize(mesh_grid[i], "length")
                for i in self.active_axis_indices]

    @staticmethod
    def _circle_levelset(mesh_grid: List[jnp.ndarray], radius: float, position) -> jnp.ndarray:
        """Signed distance to a circle, evaluated on the first two grid axes."""
        return - radius + jnp.sqrt((mesh_grid[0] - position[0])**2 + (mesh_grid[1] - position[1])**2)

    def get_circle(self, radius: float, position: List) -> jnp.ndarray:
        """Creates the levelset field for a circle.

        :param radius: Radius
        :type radius: float
        :param position: Center position
        :type position: List
        :return: Levelset buffer (negative inside, positive outside)
        :rtype: jnp.ndarray
        """
        return self._circle_levelset(self._get_mesh_grid(), radius, position)

    def get_rectangle(self, length: float, position: List, height: float = None,
                      radius: float = None) -> jnp.ndarray:
        """Creates the levelset field for a rectangle. If the radius argument
        is specified, the rectangle corners will be rounded using that radius.
        If the height argument is not specified, a square will be created.

        The field is zero on the rectangle boundary, negative inside and
        positive outside. It is assembled piecewise from the four edges, using
        the two rectangle diagonals to decide which edge a point belongs to.
        It is therefore not an exact signed-distance function in general
        (e.g. outside near the corners).

        :param length: Length
        :type length: float
        :param position: Center position
        :type position: List
        :param height: Height, defaults to None
        :type height: float, optional
        :param radius: Radius of the corners, defaults to None
        :type radius: float, optional
        :return: Levelset buffer
        :rtype: jnp.ndarray
        """
        mesh_grid = self._get_mesh_grid()
        x, y = mesh_grid[0], mesh_grid[1]

        if height is None:
            height = length

        edge_tr = jnp.array([position[0] + length/2.0, position[1] + height/2.0])
        edge_tl = jnp.array([position[0] - length/2.0, position[1] + height/2.0])
        edge_bl = jnp.array([position[0] - length/2.0, position[1] - height/2.0])
        edge_br = jnp.array([position[0] + length/2.0, position[1] - height/2.0])

        # Line 1 is the diagonal bottom-left -> top-right and line 2 is the
        # diagonal top-left -> bottom-right. Together they split the plane
        # into the four regions associated with the four edges.
        line_1_slope = (edge_tr[1] - edge_bl[1])/(edge_tr[0] - edge_bl[0])
        line_1_offset = edge_tr[1] - line_1_slope * edge_tr[0]
        line_2_slope = (edge_br[1] - edge_tl[1])/(edge_br[0] - edge_tl[0])
        line_2_offset = edge_br[1] - line_2_slope * edge_br[0]
        line_1 = line_1_slope * x + line_1_offset
        line_2 = line_2_slope * x + line_2_offset

        top_mask = (y > line_1) & (y > line_2)
        bottom_mask = (y < line_1) & (y < line_2)
        right_mask = (y <= line_1) & (y >= line_2)
        # A point lying exactly on both diagonals (the centre) satisfies the
        # right and the left condition; exclude it here so it is not counted twice.
        left_mask = (y >= line_1) & (y <= line_2) & jnp.invert(right_mask)

        levelset = (y - (position[1] + height/2.0)) * top_mask
        levelset += - (y - (position[1] - height/2.0)) * bottom_mask
        levelset += (x - (position[0] + length/2.0)) * right_mask
        levelset += - (x - (position[0] - length/2.0)) * left_mask

        if radius:
            # (corner, direction towards the rounding-circle center, mask of
            # the corner box in which the circle replaces the edge distances)
            corners = (
                (edge_tr, jnp.array([-1.0, -1.0]), (x > edge_tr[0] - radius) & (y > edge_tr[1] - radius)),
                (edge_tl, jnp.array([ 1.0, -1.0]), (x < edge_tl[0] + radius) & (y > edge_tl[1] - radius)),
                (edge_bl, jnp.array([ 1.0,  1.0]), (x < edge_bl[0] + radius) & (y < edge_bl[1] + radius)),
                (edge_br, jnp.array([-1.0,  1.0]), (x > edge_br[0] - radius) & (y < edge_br[1] + radius)),
            )
            # First clear every corner box, then add the circle distances, so
            # that overlapping corner boxes are handled as before.
            for _, _, corner_mask in corners:
                levelset *= jnp.invert(corner_mask)
            for edge, direction, corner_mask in corners:
                levelset += self._circle_levelset(mesh_grid, radius, edge + direction * radius) * corner_mask

        return levelset

    def get_sphere(self, radius: float, position: List) -> jnp.ndarray:
        """Creates the levelset field for a sphere.

        :param radius: Radius
        :type radius: float
        :param position: Center position (three coordinates)
        :type position: List
        :return: Levelset buffer (negative inside, positive outside)
        :rtype: jnp.ndarray
        """
        mesh_grid = self._get_mesh_grid()
        # Uses three coordinates, so three active axes are required.
        levelset = - radius + jnp.sqrt(sum([(mesh_grid[i] - position[i])**2 for i in range(3)]))
        return levelset

    def create_levelset(self) -> jnp.ndarray:
        """Creates the levelset field either from the user defined callable or
        from the user defined building blocks.

        The buffer is initialized with ``narrow_band_cutoff * smallest_cell_size``.
        Only its domain (non-halo) region is overwritten; the halo cells keep
        that cutoff value. For building blocks, each block replaces the buffer
        values inside its bounding domain, so later blocks take precedence
        over earlier ones where their masks overlap.

        :raises KeyError: If a building block uses an unknown shape
        :return: Levelset buffer (non-dimensional)
        :rtype: jnp.ndarray
        """
        # CREATE BUFFER
        levelset_cutoff = self.narrow_band_cutoff * self.smallest_cell_size
        buffer_shape = tuple(n + 2*self.nh if n > 1 else n for n in (self.nx, self.ny, self.nz))
        levelset_buffer = levelset_cutoff * jnp.ones(buffer_shape)
        interior = (self.nhx, self.nhy, self.nhz)

        # INPUT FOR LAMBDAS
        mesh_grid = self._get_mesh_grid()

        # FROM LAMBDA FUNCTION (any callable, e.g. lambda or functools.partial)
        if callable(self.initial_levelset):
            levelset = self.initial_levelset(*mesh_grid)
            levelset = self.unit_handler.non_dimensionalize(levelset, "length")
            levelset_buffer = levelset_buffer.at[interior].set(levelset)

        # FROM BUILDING BLOCKS
        else:
            for levelset_object in self.initial_levelset:
                shape = levelset_object["shape"]
                if shape not in self.shape_function_dict:
                    raise KeyError(
                        f"Unknown levelset shape '{shape}'. "
                        f"Available shapes: {sorted(self.shape_function_dict)}")
                levelset = self.shape_function_dict[shape](**levelset_object["parameters"])
                levelset = self.unit_handler.non_dimensionalize(levelset, "length")
                bounding_domain = levelset_object["bounding_domain"]
                mask = bounding_domain(*mesh_grid)
                levelset_buffer = levelset_buffer.at[interior].mul(1.0 - mask)
                levelset_buffer = levelset_buffer.at[interior].add(levelset * mask)

        return levelset_buffer