#*------------------------------------------------------------------------------*
#* JAX-FLUIDS - *
#* *
#* A fully-differentiable CFD solver for compressible two-phase flows. *
#* Copyright (C) 2022 Deniz A. Bezgin, Aaron B. Buhendwa, Nikolaus A. Adams *
#* *
#* This program is free software: you can redistribute it and/or modify *
#* it under the terms of the GNU General Public License as published by *
#* the Free Software Foundation, either version 3 of the License, or *
#* (at your option) any later version. *
#* *
#* This program is distributed in the hope that it will be useful, *
#* but WITHOUT ANY WARRANTY; without even the implied warranty of *
#* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
#* GNU General Public License for more details. *
#* *
#* You should have received a copy of the GNU General Public License *
#* along with this program. If not, see <https://www.gnu.org/licenses/>. *
#* *
#*------------------------------------------------------------------------------*
#* *
#* CONTACT *
#* *
#* deniz.bezgin@tum.de // aaron.buhendwa@tum.de // nikolaus.adams@tum.de *
#* *
#*------------------------------------------------------------------------------*
#* *
#* Munich, April 15th, 2022 *
#* *
#*------------------------------------------------------------------------------*
import types
from typing import Callable, Union, Dict, List, Tuple
import jax.numpy as jnp
import numpy as np
from jaxfluids.domain_information import DomainInformation
from jaxfluids.materials.material_manager import MaterialManager
from jaxfluids.unit_handler import UnitHandler
from jaxfluids.utilities import get_conservatives_from_primitives
[docs]
class BoundaryCondition:
""" The BoundaryCondition class implements functionality to enforce user-
specified boundary conditions. Boundary conditions are enforced on the
primitive variables and the level-set field (for two-phase simulations
only).
Boundary conditions for the primitive variables:
1) Periodic
2) Symmetric
3) No-slip walls
4) Dirichlet
5) Neumann
Boundary conditions for the level-set field:
1) Periodic
2) Symmetry
3) Zero-Gradient
"""
def __init__(self, domain_information: DomainInformation, material_manager: MaterialManager, unit_handler: UnitHandler,
boundary_types: Dict, wall_velocity_functions: Dict, dirichlet_functions: Dict, neumann_functions: Dict, levelset_type: str) -> None:
self.material_manager = material_manager
self.unit_handler = unit_handler
self.levelset_type = levelset_type
if self.levelset_type != None:
self.boundary_types_primes = boundary_types["primes"]
self.boundary_types_levelset = boundary_types["levelset"]
else:
self.boundary_types_primes = boundary_types
self.wall_velocity_functions = wall_velocity_functions
self.dirichlet_functions = dirichlet_functions
self.neumann_functions = neumann_functions
self.dim = domain_information.dim
self.number_of_cells = np.array(domain_information.number_of_cells)
self.nh = domain_information.nh_conservatives
self.cell_sizes = { "east" : domain_information.cell_sizes[0], "west" : domain_information.cell_sizes[0],
"north" : domain_information.cell_sizes[1], "south" : domain_information.cell_sizes[1],
"top" : domain_information.cell_sizes[2], "bottom" : domain_information.cell_sizes[2] }
self.coordinates_plane = {
"east" : { "y": domain_information.cell_centers[1], "z": domain_information.cell_centers[2]},
"west" : { "y": domain_information.cell_centers[1], "z": domain_information.cell_centers[2]},
"north" : { "x": domain_information.cell_centers[0], "z": domain_information.cell_centers[2]},
"south" : { "x": domain_information.cell_centers[0], "z": domain_information.cell_centers[2]},
"top" : { "x": domain_information.cell_centers[0], "y": domain_information.cell_centers[1]},
"bottom": { "x": domain_information.cell_centers[0], "y": domain_information.cell_centers[1]}
}
self.inactive_axis = domain_information.inactive_axis
self.active_axis = domain_information.active_axis
self.location_to_axis = { "east" : "x", "west" : "x", "north" : "y", "south" : "y", "top" : "z", "bottom": "z" }
self.spatial_axis_to_index = {"x": 0, "y": 1, "z": 2}
self.spatial_axis_to_index_for_slices = {"x": -3, "y": -2, "z": -1}
# SLICE OBJECTS
nh = domain_information.nh_conservatives
nhx, nhy, nhz = domain_information.domain_slices_conservatives
self.slices_fill = {
"east" : jnp.s_[..., -nh:, nhy, nhz],
"west" : jnp.s_[..., :nh, nhy, nhz],
"north" : jnp.s_[..., nhx, -nh:, nhz],
"south" : jnp.s_[..., nhx, :nh, nhz],
"top" : jnp.s_[..., nhx, nhy, -nh:],
"bottom": jnp.s_[..., nhx, nhy, :nh],
}
self.slices_retrieve = {
"periodic" : {
"east" : jnp.s_[..., nh:2*nh, nhy, nhz],
"west" : jnp.s_[..., -2*nh:-nh, nhy, nhz],
"north" : jnp.s_[..., nhx, nh:2*nh, nhz],
"south" : jnp.s_[..., nhx, -2*nh:-nh, nhz],
"top" : jnp.s_[..., nhx, nhy, nh:2*nh],
"bottom" : jnp.s_[..., nhx, nhy, -2*nh:-nh],
},
"symmetry" : {
"east" : jnp.s_[..., -nh-1:-2*nh-1:-1, nhy, nhz],
"west" : jnp.s_[..., 2*nh-1:nh-1:-1, nhy, nhz],
"north" : jnp.s_[..., nhx, -nh-1:-2*nh-1:-1, nhz],
"south" : jnp.s_[..., nhx, 2*nh-1:nh-1:-1, nhz],
"top" : jnp.s_[..., nhx, nhy, -nh-1:-2*nh-1:-1],
"bottom" : jnp.s_[..., nhx, nhy, 2*nh-1:nh-1:-1],
},
"neumann" : {
"east" : jnp.s_[..., -nh-1:-nh, nhy, nhz],
"west" : jnp.s_[..., nh:nh+1, nhy, nhz],
"north" : jnp.s_[..., nhx, -nh-1:-nh, nhz],
"south" : jnp.s_[..., nhx, nh:nh+1, nhz],
"top" : jnp.s_[..., nhx, nhy, -nh-1:-nh],
"bottom" : jnp.s_[..., nhx, nhy, nh:nh+1],
},
}
# MEMBER FOR SYMMETRY
self.symmetry_indices = { "x": ([0,2,3,4], 1), "y": ([0,1,3,4], 2), "z": ([0,1,2,4], 3) }
# MEMBER FOR NEUMANN
self.upwind_difference_sign = { "east" : -1, "west" : 1, "north" : -1, "south" : 1, "top" : -1, "bottom": 1 }
# BOUNDARY TYPES IN CORNERS
self.corners = self.assign_corners()
# CORNER FILL SLICES
self.corner_slices_fill = {
"west_south" : jnp.s_[..., :nh, :nh, nhz],
"west_north" : jnp.s_[..., :nh, -nh:, nhz],
"east_south" : jnp.s_[..., -nh:, :nh, nhz],
"east_north" : jnp.s_[..., -nh:, -nh:, nhz],
"bottom_south" : jnp.s_[..., nhx, :nh, :nh],
"bottom_north" : jnp.s_[..., nhx, -nh:, :nh],
"top_south" : jnp.s_[..., nhx, :nh, -nh:],
"top_north" : jnp.s_[..., nhx, -nh:, -nh:],
"bottom_east" : jnp.s_[..., -nh:, nhy, :nh],
"bottom_west" : jnp.s_[..., :nh, nhy, :nh],
"top_east" : jnp.s_[..., -nh:, nhy, -nh:],
"top_west" : jnp.s_[..., :nh, nhy, -nh:]
}
# CORNER RETRIEVE SLICES - NUMBERING CLOCKWISE FROM FILL BLOCK ALONG CORRESPONDING AXIS
self.corner_slices_retrieve = {
# WEST EAST NORTH SOUTH RETRIEVES
"west_south_0" : jnp.s_[..., :nh, nh:2*nh, nhz],
"west_south_1" : jnp.s_[..., nh:2*nh, nh:2*nh, nhz],
"west_south_2" : jnp.s_[..., nh:2*nh, :nh, nhz],
"west_north_0" : jnp.s_[..., nh:2*nh, -nh:, nhz],
"west_north_1" : jnp.s_[..., nh:2*nh, -2*nh:-nh, nhz],
"west_north_2" : jnp.s_[..., :nh, -2*nh:-nh, nhz],
"east_south_0" : jnp.s_[..., -2*nh:-nh, :nh, nhz],
"east_south_1" : jnp.s_[..., -2*nh:-nh, nh:2*nh, nhz],
"east_south_2" : jnp.s_[..., -nh:, nh:2*nh, nhz],
"east_north_0" : jnp.s_[..., -nh:, -2*nh:-nh, nhz],
"east_north_1" : jnp.s_[..., -2*nh:-nh, -2*nh:-nh, nhz],
"east_north_2" : jnp.s_[..., -2*nh:-nh, -nh:, nhz],
# BOTTOM TOP NORTH SOUTH RETRIEVES
"bottom_south_0" : jnp.s_[..., nhx, nh:2*nh, :nh],
"bottom_south_1" : jnp.s_[..., nhx, nh:2*nh, nh:2*nh],
"bottom_south_2" : jnp.s_[..., nhx, :nh, nh:2*nh],
"bottom_north_0" : jnp.s_[..., nhx, -nh:, nh:2*nh],
"bottom_north_1" : jnp.s_[..., nhx, -2*nh:-nh, nh:2*nh],
"bottom_north_2" : jnp.s_[..., nhx, -2*nh:-nh, :nh],
"top_south_0" : jnp.s_[..., nhx, :nh, -2*nh:-nh],
"top_south_1" : jnp.s_[..., nhx, nh:2*nh, -2*nh:-nh],
"top_south_2" : jnp.s_[..., nhx, nh:2*nh, -nh:],
"top_north_0" : jnp.s_[..., nhx, -2*nh:-nh, -nh:],
"top_north_1" : jnp.s_[..., nhx, -2*nh:-nh, -2*nh:-nh],
"top_north_2" : jnp.s_[..., nhx, -nh:, -2*nh:-nh],
# BOTTOM TOP WEST EAST RETRIEVES
"bottom_west_0" : jnp.s_[..., :nh, nhy, nh:2*nh],
"bottom_west_1" : jnp.s_[..., nh:2*nh, nhy, nh:2*nh],
"bottom_west_2" : jnp.s_[..., nh:2*nh, nhy, :nh],
"bottom_east_0" : jnp.s_[..., -2*nh:-nh, nhy, :nh],
"bottom_east_1" : jnp.s_[..., -2*nh:-nh, nhy, nh:2*nh],
"bottom_east_2" : jnp.s_[..., -nh:, nhy, nh:2*nh],
"top_east_0" : jnp.s_[..., -nh:, nhy, -2*nh:-nh],
"top_east_1" : jnp.s_[..., -2*nh:-nh, nhy, -2*nh:-nh],
"top_east_2" : jnp.s_[..., -2*nh:-nh, nhy, -nh:],
"top_west_0" : jnp.s_[..., nh:2*nh, nhy, -nh:],
"top_west_1" : jnp.s_[..., nh:2*nh, nhy, -2*nh:-nh],
"top_west_2" : jnp.s_[..., :nh, nhy, -2*nh:-nh]
}
# CORNER COMBINATIONS
self.corner_combinations = {
# WEST EAST NORTH SOUTH COMBS
"west_south": {
"periodic_periodic": ( "east_north_1", [ 1, 1, 1] ),
"symmetry_symmetry": ( "west_south_1", [-1,-1, 1] ),
"symmetry_any": ( "west_south_2", [-1, 1, 1] ),
"any_symmetry": ( "west_south_0", [ 1,-1, 1] ),
"periodic_any": ( "east_south_0", [ 1, 1, 1] ),
"any_periodic": ( "west_north_2", [ 1, 1, 1] ),
"any_any": [("west_south_0", [ 1, 1, 1]), ("west_south_2", [ 1, 1, 1])],
},
"west_north": {
"periodic_periodic": ( "east_south_1", [ 1, 1, 1] ),
"symmetry_symmetry": ( "west_north_1", [-1,-1, 1] ),
"symmetry_any": ( "west_north_0", [-1, 1, 1] ),
"any_symmetry": ( "west_north_2", [ 1,-1, 1] ),
"periodic_any": ( "east_north_2", [ 1, 1, 1] ),
"any_periodic": ( "west_south_0", [ 1, 1, 1] ),
"any_any": [("west_north_0", [ 1, 1, 1]), ("west_north_2", [ 1, 1, 1])],
},
"east_north": {
"periodic_periodic": ( "west_south_1", [ 1, 1, 1] ),
"symmetry_symmetry": ( "east_north_1", [-1,-1, 1] ),
"symmetry_any": ( "east_north_2", [-1, 1, 1] ),
"any_symmetry": ( "east_north_0", [ 1,-1, 1] ),
"periodic_any": ( "west_north_0", [ 1, 1, 1] ),
"any_periodic": ( "east_south_2", [ 1, 1, 1] ),
"any_any": [("east_north_0", [ 1, 1, 1]), ("east_north_2", [ 1, 1, 1])],
},
"east_south": {
"periodic_periodic": ( "west_north_1", [ 1, 1, 1] ),
"symmetry_symmetry": ( "east_south_1", [-1,-1, 1] ),
"symmetry_any": ( "east_south_0", [-1, 1, 1] ),
"any_symmetry": ( "east_south_2", [ 1,-1, 1] ),
"periodic_any": ( "west_south_2", [ 1, 1, 1] ),
"any_periodic": ( "east_north_0", [ 1, 1, 1] ),
"any_any": [("east_south_0", [ 1, 1, 1]), ("east_south_2", [ 1, 1, 1])],
},
# BOTTOM TOP NORTH SOUTH COMBS
"bottom_south": {
"periodic_periodic": ( "top_north_1", [ 1, 1, 1] ),
"symmetry_symmetry": ( "bottom_south_1", [ 1,-1,-1] ),
"symmetry_any": ( "bottom_south_2", [ 1, 1,-1] ),
"any_symmetry": ( "bottom_south_0", [ 1,-1, 1] ),
"periodic_any": ( "top_south_0", [ 1, 1, 1] ),
"any_periodic": ( "bottom_north_2", [ 1, 1, 1] ),
"any_any": [ ("bottom_south_0", [ 1, 1, 1]), ("bottom_south_2", [ 1, 1, 1]) ],
},
"bottom_north": {
"periodic_periodic": ( "top_south_1", [ 1, 1, 1] ),
"symmetry_symmetry": ( "bottom_north_1", [ 1,-1,-1] ),
"symmetry_any": ( "bottom_north_0", [ 1, 1,-1] ),
"any_symmetry": ( "bottom_north_2", [ 1,-1, 1] ),
"periodic_any": ( "top_north_2", [ 1, 1, 1] ),
"any_periodic": ( "bottom_south_0", [ 1, 1, 1] ),
"any_any": [ ("bottom_north_0", [ 1, 1, 1]), ("bottom_north_2", [ 1, 1, 1]) ],
},
"top_north": {
"periodic_periodic": ( "bottom_south_1", [ 1, 1, 1] ),
"symmetry_symmetry": ( "top_north_1", [ 1,-1,-1] ),
"symmetry_any": ( "top_north_2", [ 1, 1,-1] ),
"any_symmetry": ( "top_north_0", [ 1,-1, 1] ),
"periodic_any": ( "bottom_north_0", [ 1, 1, 1] ),
"any_periodic": ( "top_south_2", [ 1, 1, 1] ),
"any_any": [ ("top_north_0", [ 1, 1, 1]), ("top_north_2", [ 1, 1, 1]) ],
},
"top_south": {
"periodic_periodic": ( "bottom_north_1", [ 1, 1, 1] ),
"symmetry_symmetry": ( "top_south_1", [ 1,-1,-1] ),
"symmetry_any": ( "top_south_0", [ 1, 1,-1] ),
"any_symmetry": ( "top_south_2", [ 1,-1, 1] ),
"periodic_any": ( "bottom_south_2", [ 1, 1, 1] ),
"any_periodic": ( "top_north_0", [ 1, 1, 1] ),
"any_any": [ ("top_south_0", [ 1, 1, 1]), ("top_south_2", [ 1, 1, 1]) ],
},
# BOTTOM TOP WEST EAST COMBS
"bottom_west": {
"periodic_periodic": ( "top_east_1", [ 1, 1, 1] ),
"symmetry_symmetry": ( "bottom_west_1", [-1, 1,-1] ),
"symmetry_any": ( "bottom_west_0", [ 1, 1,-1] ),
"any_symmetry": ( "bottom_west_2", [-1, 1, 1] ),
"periodic_any": ( "top_west_2", [ 1, 1, 1] ),
"any_periodic": ( "bottom_east_0", [ 1, 1, 1] ),
"any_any": [ ("bottom_west_0", [ 1, 1, 1]), ("bottom_west_2", [ 1, 1, 1]) ],
},
"bottom_east": {
"periodic_periodic": ( "top_west_1", [ 1, 1, 1] ),
"symmetry_symmetry": ( "bottom_east_1", [-1, 1,-1] ),
"symmetry_any": ( "bottom_east_2", [ 1, 1,-1] ),
"any_symmetry": ( "bottom_east_0", [-1, 1, 1] ),
"periodic_any": ( "top_east_0", [ 1, 1, 1] ),
"any_periodic": ( "bottom_west_2", [ 1, 1, 1] ),
"any_any": [ ("bottom_east_0", [ 1, 1, 1]), ("bottom_east_2", [ 1, 1, 1]) ],
},
"top_east": {
"periodic_periodic": ( "bottom_west_1", [ 1, 1, 1] ),
"symmetry_symmetry": ( "top_east_1", [-1, 1,-1] ),
"symmetry_any": ( "top_east_0", [ 1, 1,-1] ),
"any_symmetry": ( "top_east_2", [-1, 1, 1] ),
"periodic_any": ( "bottom_east_2", [ 1, 1, 1] ),
"any_periodic": ( "top_west_0", [ 1, 1, 1] ),
"any_any": [ ("top_east_0", [ 1, 1, 1]), ("top_east_2", [ 1, 1, 1]) ],
},
"top_west": {
"periodic_periodic": ( "bottom_east_1", [ 1, 1, 1] ),
"symmetry_symmetry": ( "top_west_1", [-1, 1,-1] ),
"symmetry_any": ( "top_west_2", [ 1, 1,-1] ),
"any_symmetry": ( "top_west_0", [-1, 1, 1] ),
"periodic_any": ( "bottom_west_0", [ 1, 1, 1] ),
"any_periodic": ( "top_east_2", [ 1, 1, 1] ),
"any_any": [ ("top_west_0", [ 1, 1, 1]), ("top_west_2", [ 1, 1, 1]) ],
}
}
[docs]
def assign_corners(self) -> Dict:
"""Identifies the boundary type pairs at the corners (2D) / edges (3D) of
computational domain. This is necessary to fill the halo cells that are located at
the diagonal extension of the domain.
:return: Dictionary containing the boundary type pairs at each boundary corner/edge location
:rtype: Dict
"""
locations = ["west_south", "west_north", "east_north", "east_south", "bottom_south", "bottom_north", "top_south", "top_north", "bottom_east", "bottom_west", "top_east", "top_west"]
indices = {
"west_south" : {"west" : 0, "south" : 0},
"west_north" : {"west" : -1, "north" : 0},
"east_south" : {"east" : 0, "south" : -1},
"east_north" : {"east" : -1, "north" : -1},
"bottom_south" : {"bottom" : 0, "south" : 0},
"bottom_north" : {"bottom" : -1, "north" : 0},
"top_south" : {"top" : 0, "south" : -1},
"top_north" : {"top" : -1, "north" : -1},
"bottom_east" : {"bottom" : -1, "east" : 0},
"bottom_west" : {"bottom" : 0, "west" : 0},
"top_east" : {"top" : -1, "east" : -1},
"top_west" : {"top" : 0, "west" : -1},
}
boundary_types = {}
boundary_types["primes"] = self.boundary_types_primes
boundary_types["levelset"] = self.boundary_types_levelset if self.levelset_type != None else None
corners = {"primes": {}, "levelset": {}} if self.levelset_type != None else {"primes": {}}
for key in corners:
for location in locations:
b_type = []
for loc in location.split("_"):
boundary_type = boundary_types[key][loc]
if type(boundary_type) == list:
b_type.append(boundary_type[0][indices[location][loc]])
else:
b_type.append(boundary_type)
b_type = "_".join(b_type)
corners[key][location] = b_type
for corner, combinations in corners[key].items():
boundary1, boundary2 = combinations.split("_")
if np.array([bound in ["dirichlet", "neumann", "wall"] for bound in [boundary1, boundary2]]).all():
boundary1, boundary2 = "any", "any"
elif boundary1 in ["dirichlet", "neumann", "wall", "symmetry"] and boundary2 == "periodic":
boundary1 = "any"
elif boundary1 in ["dirichlet", "neumann", "wall", "periodic"] and boundary2 == "symmetry":
boundary1 = "any"
elif boundary2 in ["dirichlet", "neumann", "wall", "symmetry"] and boundary1 == "periodic":
boundary2 = "any"
elif boundary2 in ["dirichlet", "neumann", "wall", "periodic"] and boundary1 == "symmetry":
boundary2 = "any"
corners[key][corner] = "_".join([boundary1, boundary2])
return corners
[docs]
def fill_boundary_primes(self, cons: jnp.ndarray, primes: jnp.ndarray,
current_time: float) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Fills the halo cells of the primitive and conservative variable buffers.
:param cons: Buffer of conservative variables
:type cons: jnp.ndarray
:param primes: Buffer of primitive variables
:type primes: jnp.ndarray
:param current_time: Current physical simulation time
:type current_time: float
:return: Primitive and conservative variable buffer with filled halo cells
:rtype: Tuple[jnp.ndarray, jnp.ndarray]
"""
# FILL BOUNDARIES
for boundary_location, boundary_types in self.boundary_types_primes.items():
if type(boundary_types) == list:
b_types = boundary_types[0]
b_ranges = boundary_types[1]
else:
b_types = [boundary_types]
b_ranges = [(0.0, 1.0)]
wall_counter = 0
dirichlet_counter = 0
neumann_counter = 0
for b_type, b_range in zip(b_types, b_ranges):
if b_type == "symmetry":
cons, primes = self.symmetry(cons, primes, boundary_location, b_range)
if b_type == "periodic":
cons, primes = self.periodic(cons, primes, boundary_location, b_range)
if b_type == "wall":
if type(self.wall_velocity_functions[boundary_location]) == list:
function = self.wall_velocity_functions[boundary_location][wall_counter]
else:
function = self.wall_velocity_functions[boundary_location]
cons, primes = self.wall(cons, primes, boundary_location, function, current_time, b_range)
wall_counter += 1
if b_type == "dirichlet":
if type(self.dirichlet_functions[boundary_location]) == list:
function = self.dirichlet_functions[boundary_location][dirichlet_counter]
else:
function = self.dirichlet_functions[boundary_location]
cons, primes = self.dirichlet(cons, primes, boundary_location, function, current_time, b_range)
dirichlet_counter += 1
if b_type == "neumann":
if type(self.neumann_functions[boundary_location]) == list:
function = self.neumann_functions[boundary_location][neumann_counter]
else:
function = self.neumann_functions[boundary_location]
cons, primes = self.neumann(cons, primes, boundary_location, function, current_time, b_range)
neumann_counter += 1
if b_type == "inactive":
continue
# FILL DOMAIN CORNERS
velocities = self.fill_corners_primes(primes[1:4])
primes = primes.at[1:4].set(velocities)
cons = get_conservatives_from_primitives(primes, self.material_manager)
return cons, primes
[docs]
def fill_boundary_levelset(self, levelset: jnp.ndarray) -> jnp.ndarray:
"""Fills the levelset buffer halo cells.
:param levelset: Levelset buffer
:type levelset: jnp.ndarray
:return: Levelset buffer with filled halo cells
:rtype: jnp.ndarray
"""
for boundary_location, boundary_type in self.boundary_types_levelset.items():
if boundary_type == "inactive":
continue
elif boundary_type in ["symmetry", "periodic", "neumann"]:
slices_retrieve = self.slices_retrieve[boundary_type][boundary_location]
slices_fill = self.slices_fill[boundary_location]
levelset = levelset.at[slices_fill].set(levelset[slices_retrieve])
levelset = self.fill_corners_levelset(levelset)
return levelset
[docs]
def fill_corners_levelset(self, levelset: jnp.ndarray) -> jnp.ndarray:
"""Fills the levelset buffer halo cells that are located at the diagional extension of the domain.
:param levelset: Levelset buffer
:type levelset: jnp.ndarray
:return: Levelset buffer with filled halo cells at the corners
:rtype: jnp.ndarray
"""
for location_fill, combinations in self.corners["levelset"].items():
if "inactive" in combinations.split("_"):
continue
if combinations == "any_any":
block1, block2 = self.corner_combinations[location_fill][combinations]
location_retrieve1, flip1 = block1
location_retrieve2, flip2 = block2
slice_fill = self.corner_slices_fill[location_fill]
slice_retrieve1 = self.corner_slices_retrieve[location_retrieve1]
slice_retrieve2 = self.corner_slices_retrieve[location_retrieve2]
halo = 0.5 * (levelset[slice_retrieve1][::flip1[0], ::flip1[1], ::flip1[2]] + levelset[slice_retrieve2][::flip2[0], ::flip2[1], ::flip2[2]])
levelset = levelset.at[slice_fill].set(halo)
else:
location_retrieve, flip = self.corner_combinations[location_fill][combinations]
slice_fill = self.corner_slices_fill[location_fill]
slice_retrieve = self.corner_slices_retrieve[location_retrieve]
levelset = levelset.at[slice_fill].set(levelset[slice_retrieve][..., ::flip[0], ::flip[1], ::flip[2]])
return levelset
[docs]
def fill_corners_primes(self, primes: jnp.ndarray) -> jnp.ndarray:
"""Fills the prime buffer halo cells that are located at the diagonal extension of the domain.
:param primes: Buffer of the primitive variables
:type primes: jnp.ndarray
:return: Primitive variables buffer with filled halo cells at the corners
:rtype: jnp.ndarray
"""
for location_fill, combinations in self.corners["primes"].items():
if "inactive" in combinations.split("_"):
continue
# FOR ANY - ANY, WE FILL THE CORNER USING THE AVERAGE OF THE ADJESCENT SQUARES
if combinations == "any_any":
block1, block2 = self.corner_combinations[location_fill][combinations]
location_retrieve1, flip1 = block1
location_retrieve2, flip2 = block2
slice_fill = self.corner_slices_fill[location_fill]
slice_retrieve1 = self.corner_slices_retrieve[location_retrieve1]
slice_retrieve2 = self.corner_slices_retrieve[location_retrieve2]
halo = 0.5 * (primes[slice_retrieve1][..., ::flip1[0], ::flip1[1], ::flip1[2]] + primes[slice_retrieve2][..., ::flip2[0], ::flip2[1], ::flip2[2]])
primes = primes.at[slice_fill].set(halo)
else:
# FOR SYMMETRY COMBINATIONS, THE SIGN OF CORRESPONDING VELOCITY MUST BE CHANGED
if "symmetry" in combinations.split("_"):
location_retrieve, flip = self.corner_combinations[location_fill][combinations]
slice_fill = self.corner_slices_fill[location_fill]
slice_retrieve = self.corner_slices_retrieve[location_retrieve]
indices_flip = list(np.where(np.array(flip) == -1)[0])
indices_noflip = [i for i in range(3) if i not in indices_flip]
fill_flip = (indices_flip,) + slice_fill
fill_noflip = (indices_noflip,) + slice_fill
retrieve_flip = (indices_flip,) + slice_retrieve
retrieve_noflip = (indices_noflip,) + slice_retrieve
primes = primes.at[fill_flip].set(-primes[retrieve_flip][..., ::flip[0], ::flip[1], ::flip[2]])
primes = primes.at[fill_noflip].set(primes[retrieve_noflip][..., ::flip[0], ::flip[1], ::flip[2]])
else:
location_retrieve, flip = self.corner_combinations[location_fill][combinations]
slice_fill = self.corner_slices_fill[location_fill]
slice_retrieve = self.corner_slices_retrieve[location_retrieve]
primes = primes.at[slice_fill].set(primes[slice_retrieve][..., ::flip[0], ::flip[1], ::flip[2]])
return primes
[docs]
def wall(self, cons: jnp.ndarray, primes: jnp.ndarray, location: str,
functions: Dict, current_time: float, b_range: List) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Fills the halo cells of the primitive and conservative variable buffer at the specified
location according to the no-slip wall boundary condition.
:param cons: Conservative variable buffer
:type cons: jnp.ndarray
:param primes: Primitive variable buffer
:type primes: jnp.ndarray
:param location: Boundary location
:type location: str
:param functions: Wall velocity functions
:type functions: Dict
:param current_time: Current physical simulation time
:type current_time: float
:param b_range: List containing the spatial range of the boundary at the specified location
:type b_range: List
:return: Primitive and conservative variable buffers with filled halos at specified location
:rtype: Tuple[jnp.ndarray, jnp.ndarray]
"""
if self.dim == 2:
slices_retrieve = self.get_slices_retrieve(location, "symmetry", b_range)
slices_fill = self.get_slices_fill(location, b_range)
else:
slices_retrieve = self.slices_retrieve["symmetry"][location]
slices_fill = self.slices_fill[location]
wall_velocity = {}
for velocity in ["u", "v", "w"]:
if velocity in functions.keys():
if type(functions[velocity]) == types.LambdaType:
wall_velocity[velocity] = functions[velocity](self.unit_handler.dimensionalize(current_time, "time"))
else:
wall_velocity[velocity] = functions[velocity]
wall_velocity[velocity] = self.unit_handler.non_dimensionalize(wall_velocity[velocity], "velocity")
else:
wall_velocity[velocity] = 0.0
u_halo = 2 * wall_velocity["u"] - primes[(jnp.s_[1:2],) + slices_retrieve]
v_halo = 2 * wall_velocity["v"] - primes[(jnp.s_[2:3],) + slices_retrieve]
w_halo = 2 * wall_velocity["w"] - primes[(jnp.s_[3:4],) + slices_retrieve]
halos_prime = jnp.vstack([ primes[(jnp.s_[0:1],) + slices_retrieve],
u_halo,
v_halo,
w_halo,
primes[(jnp.s_[4:5],) + slices_retrieve] ])
halos_cons = get_conservatives_from_primitives(halos_prime, self.material_manager)
cons = cons.at[slices_fill].set(halos_cons)
primes = primes.at[slices_fill].set(halos_prime)
return cons, primes
[docs]
def dirichlet(self, cons: jnp.ndarray, primes: jnp.ndarray, location: str,
functions: Union[Callable, float], current_time: float, b_range: List) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Fills the halo cells of the primitive and conservative variable buffer at the specified location
according to the dirichlet boundary condition.
:param cons: Conservative variable buffer
:type cons: jnp.ndarray
:param primes: Primitive variable buffer
:type primes: jnp.ndarray
:param location: Boundary location
:type location: str
:param functions: Dirichlet functions
:type functions: Union[Callable, float]
:param current_time: Current physical simulation time
:type current_time: float
:param b_range: List containing the spatial range of the boundary at the specified location
:type b_range: List
:return: Primitive and conservative variable buffers with filled halos at specified location
:rtype: Tuple[jnp.ndarray, jnp.ndarray]
"""
# GET SLICE OBJECTS
if self.dim == 2:
slices_fill = self.get_slices_fill(location, b_range)
else:
slices_fill = self.slices_fill[location]
# COMPUTE PRESENT COORDINATES
coordinates = [self.coordinates_plane[location].get(axis)[int(b_range[0]*self.number_of_cells[self.spatial_axis_to_index[axis]]):int(b_range[1]*self.number_of_cells[self.spatial_axis_to_index[axis]])] for axis in self.active_axis if self.coordinates_plane[location].get(axis) != None]
coordinates_name = [axis for axis in self.active_axis if self.coordinates_plane[location].get(axis) != None]
axis_to_expand = [axis for axis in ["x", "y", "z"] if axis not in coordinates_name]
# DIMENSIONALIZE FOR LAMBDA FUNCTION
mesh_grid = jnp.meshgrid(*[self.unit_handler.dimensionalize(coord, "length") for coord in coordinates], indexing="ij")
current_time = self.unit_handler.dimensionalize(current_time, "time")
# EVALUATE LAMBDAS
halos_prime_list = []
for prime_state in functions:
func = functions[prime_state]
if type(func) in [float, np.float64, np.float32]:
halos = func*jnp.ones(mesh_grid[0].shape) if self.dim != 1 else func
elif type(func) == types.LambdaType:
halos = func(*mesh_grid, current_time)
else:
assert False, "Dirichlet boundary values must be lambda function or python/numpy float"
for ax in axis_to_expand:
halos = jnp.expand_dims(halos, self.spatial_axis_to_index[ax])
halos = self.unit_handler.non_dimensionalize(halos, prime_state)
halos_prime_list.append(halos)
# STACK
halos_prime = jnp.stack(halos_prime_list, axis=0)
if self.levelset_type == "FLUID-FLUID":
halos_prime = jnp.stack([halos_prime, halos_prime], axis=1)
# COMPUTE CONSERVATIVES
halos_cons = get_conservatives_from_primitives(halos_prime, self.material_manager)
# FILL
primes = primes.at[slices_fill].set(halos_prime)
cons = cons.at[slices_fill].set(halos_cons)
return cons, primes
[docs]
def neumann(self, cons: jnp.ndarray, primes: jnp.ndarray, location: str,
functions: Union[Callable, float], current_time: float, b_range: List) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Fills the halo cells of the primitive and conservative variable buffer at the specified
location according to the neumann boundary condition.
:param cons: Conservative variable buffer
:type cons: jnp.ndarray
:param primes: Primitive variable buffer
:type primes: jnp.ndarray
:param location: Boundary location
:type location: str
:param functions: Neumann functions
:type functions: Union[Callable, float]
:param current_time: Current physical simulation time
:type current_time: float
:param b_range: List containing the spatial range of the boundary at the specified location
:type b_range: List
:return: Primitive and conservative variable buffers with filled halos at specified location
:rtype: Tuple[jnp.ndarray, jnp.ndarray]
"""
# GET SLICE OBJECTS
if self.dim == 2:
slices_retrieve = self.get_slices_retrieve(location, "neumann", b_range)
slices_fill = self.get_slices_fill(location, b_range)
else:
slices_retrieve = self.slices_retrieve["neumann"][location]
slices_fill = self.slices_fill[location]
# COMPUTE PRESENT COORDINATES
coordinates = [self.coordinates_plane[location].get(axis)[int(b_range[0]*self.number_of_cells[self.spatial_axis_to_index[axis]]):int(b_range[1]*self.number_of_cells[self.spatial_axis_to_index[axis]])] for axis in self.active_axis if self.coordinates_plane[location].get(axis) != None]
coordinates_name = [axis for axis in self.active_axis if self.coordinates_plane[location].get(axis) != None]
axis_to_expand = [axis for axis in ["x", "y", "z"] if axis not in coordinates_name]
# DIMENSIONALIZE FOR LAMBDA FUNCTION
mesh_grid = jnp.meshgrid(*[self.unit_handler.dimensionalize(coord, "length") for coord in coordinates], indexing="ij")
current_time = self.unit_handler.dimensionalize(current_time, "time")
# EVALUATE LAMBDAS
halos_prime_list = []
for i, prime_state in enumerate(functions):
func = functions[prime_state]
if type(func) in [float, np.float64, np.float32]:
neumann_value = func*jnp.ones(mesh_grid[0].shape) if self.dim != 1 else func
elif type(func) == types.LambdaType:
neumann_value = func(*mesh_grid, current_time)
else:
assert False, "Neumann boundary values must be lambda function or python/numpy float"
for axis in axis_to_expand:
neumann_value = jnp.expand_dims(neumann_value, self.spatial_axis_to_index[axis])
neumann_value = self.unit_handler.non_dimensionalize(neumann_value, prime_state)
neumann_value = self.unit_handler.dimensionalize(neumann_value, "length")
halos = primes[(i,) + slices_retrieve] + self.upwind_difference_sign[location] * neumann_value * self.cell_sizes[location]
halos_prime_list.append(halos)
# STACK
halos_prime = jnp.stack(halos_prime_list, axis=0)
# COMPUTE CONSERVATIVES
halos_cons = get_conservatives_from_primitives(halos_prime, self.material_manager)
# FILL
primes = primes.at[slices_fill].set(halos_prime)
cons = cons.at[slices_fill].set(halos_cons)
return cons, primes
[docs]
def symmetry(self, cons: jnp.ndarray, primes: jnp.ndarray,
location: str, b_range: List) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Fills the halo cells of the primitive and conservative variable buffer
at the specified location according to the symmetric boundary condition.
:param cons: Conservative variable buffer
:type cons: jnp.ndarray
:param primes: Primitive variable buffer
:type primes: jnp.ndarray
:param location: Boundary location
:type location: str
:param b_range: List containing the spatial range of the boundary at the specified location
:type b_range: List
:return: Primitive and conservative variable buffers with filled halos at specified location
:rtype: Tuple[jnp.ndarray, jnp.ndarray]
"""
if self.dim == 2:
slices_retrieve = self.get_slices_retrieve(location, "symmetry", b_range)
slices_fill = self.get_slices_fill(location, b_range)
else:
slices_retrieve = self.slices_retrieve["symmetry"][location]
slices_fill = self.slices_fill[location]
axis = self.location_to_axis[location]
cons = cons.at[(self.symmetry_indices[axis][0], ) + slices_fill].set(cons[(self.symmetry_indices[axis][0], ) + slices_retrieve])
cons = cons.at[(self.symmetry_indices[axis][1], ) + slices_fill].set(-cons[(self.symmetry_indices[axis][1], ) + slices_retrieve])
primes = primes.at[(self.symmetry_indices[axis][0], ) + slices_fill].set(primes[(self.symmetry_indices[axis][0], ) + slices_retrieve])
primes = primes.at[(self.symmetry_indices[axis][1], ) + slices_fill].set(-primes[(self.symmetry_indices[axis][1], ) + slices_retrieve])
return cons, primes
[docs]
def periodic(self, cons: jnp.ndarray, primes: jnp.ndarray, location: str, b_range: List) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Fills the halos of the conservative and primitive variable buffer at
the specified location according to the periodic boundary condition.
:param cons: Conservative variable buffer
:type cons: jnp.ndarray
:param primes: Primitive variable buffer
:type primes: jnp.ndarray
:param location: Boundary location
:type location: str
:param b_range: List containing the spatial range of the boundary at the specified location
:type b_range: List
:return: Primitive and conservative variable buffers with filled halos at specified location
:rtype: Tuple[jnp.ndarray, jnp.ndarray]
"""
if self.dim == 2:
slices_retrieve = self.get_slices_retrieve(location, "periodic", b_range)
slices_fill = self.get_slices_fill(location, b_range)
else:
slices_retrieve = self.slices_retrieve["periodic"][location]
slices_fill = self.slices_fill[location]
cons = cons.at[slices_fill].set(cons[slices_retrieve])
primes = primes.at[slices_fill].set(primes[slices_retrieve])
return cons, primes
[docs]
def get_slices_fill(self, location: str, b_range: List) -> Tuple:
"""Computes the slice objects to fill the halos depending on the range at the specified boundary location.
:param location: Boundary location
:type location: str
:param b_range: List containing the spatial range of the boundary at the specified location
:type b_range: List
:return: Slice objects
:rtype: Tuple
"""
# 2D only
axis = [axis for axis in self.active_axis if self.coordinates_plane[location].get(axis) != None][0]
a = self.nh + int(b_range[0] * self.number_of_cells[self.spatial_axis_to_index[axis]])
b = self.nh + int(b_range[1] * self.number_of_cells[self.spatial_axis_to_index[axis]])
slices = list(self.slices_fill[location])
slices[self.spatial_axis_to_index_for_slices[axis]] = jnp.s_[a:b]
return tuple(slices)
[docs]
def get_slices_retrieve(self, location: str, boundary_type: str, b_range: List) -> Tuple:
"""Computes the slice objects to retrieve the values depending on the range at the specified boundary location.
:param location: Boundary location
:type location: str
:param boundary_type: Boundary location
:type boundary_type: str
:param b_range: List containing the spatial range of the boundary at the specified location
:type b_range: List
:return: Slice objects
:rtype: Tuple
"""
# 2D only
axis = [axis for axis in self.active_axis if self.coordinates_plane[location].get(axis) != None][0]
a = self.nh + int(b_range[0] * self.number_of_cells[self.spatial_axis_to_index[axis]])
b = self.nh + int(b_range[1] * self.number_of_cells[self.spatial_axis_to_index[axis]])
slices = list(self.slices_retrieve[boundary_type][location])
slices[self.spatial_axis_to_index_for_slices[axis]] = jnp.s_[a:b]
return tuple(slices)
[docs]
def symmetry_levelset(self, levelset: jnp.ndarray, location: str):
slices_retrieve = self.slices_retrieve["symmetry"][location]
slices_fill = self.slices_fill[location]
levelset = levelset.at[slices_fill].set(levelset[slices_retrieve])
return levelset
[docs]
def periodic_levelset(self, levelset: jnp.ndarray, location: str):
slices_retrieve = self.slices_retrieve["periodic"][location]
slices_fill = self.slices_fill[location]
levelset = levelset.at[slices_fill].set(levelset[slices_retrieve])
return levelset