#*------------------------------------------------------------------------------*
#* JAX-FLUIDS - *
#* *
#* A fully-differentiable CFD solver for compressible two-phase flows. *
#* Copyright (C) 2022 Deniz A. Bezgin, Aaron B. Buhendwa, Nikolaus A. Adams *
#* *
#* This program is free software: you can redistribute it and/or modify *
#* it under the terms of the GNU General Public License as published by *
#* the Free Software Foundation, either version 3 of the License, or *
#* (at your option) any later version. *
#* *
#* This program is distributed in the hope that it will be useful, *
#* but WITHOUT ANY WARRANTY; without even the implied warranty of *
#* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
#* GNU General Public License for more details. *
#* *
#* You should have received a copy of the GNU General Public License *
#* along with this program. If not, see <https://www.gnu.org/licenses/>. *
#* *
#*------------------------------------------------------------------------------*
#* *
#* CONTACT *
#* *
#* deniz.bezgin@tum.de // aaron.buhendwa@tum.de // nikolaus.adams@tum.de *
#* *
#*------------------------------------------------------------------------------*
#* *
#* Munich, April 15th, 2022 *
#* *
#*------------------------------------------------------------------------------*
from functools import partial
from typing import List, Tuple, Union, Dict
import types
import jax
import jax.numpy as jnp
from jaxfluids.boundary_condition import BoundaryCondition
from jaxfluids.domain_information import DomainInformation
from jaxfluids.levelset.helper_functions import move_source_to_target_ii, move_source_to_target_ij, move_source_to_target_ijk, move_target_to_source_ii, move_target_to_source_ij, move_target_to_source_ijk
from jaxfluids.levelset.interface_quantity_computer import InterfaceQuantityComputer
from jaxfluids.levelset.interface_flux_computer import InterfaceFluxComputer
from jaxfluids.levelset.levelset_reinitializer import LevelsetReinitializer
from jaxfluids.materials.material_manager import MaterialManager
from jaxfluids.levelset.quantity_extender import QuantityExtender
from jaxfluids.levelset.geometry_calculator import GeometryCalculator
from jaxfluids.stencils.spatial_derivative import SpatialDerivative
from jaxfluids.unit_handler import UnitHandler
from jaxfluids.utilities import get_conservatives_from_primitives, get_primitives_from_conservatives
from jaxfluids.stencils import DICT_FIRST_DERIVATIVE_CENTER, DICT_SECOND_DERIVATIVE_CENTER, DICT_DERIVATIVE_LEVELSET_ADVECTION, DICT_DERIVATIVE_QUANTITY_EXTENDER, DICT_DERIVATIVE_REINITIALIZATION
from jaxfluids.time_integration import DICT_TIME_INTEGRATION
[docs]
class LevelsetHandler():
""" The LevelsetHandler class manages computations to perform two-phase simulations using the levelset method.
The main functionality includes
- Transformation of the conservative states from volume-averages to actual conserved quantities according to the volume fraction
- Weighting of the cell face fluxes according to the apertures
- Computation of the interface fluxes
- Computation of the levelset advection right-hand-side
- Extension of the primitive state from the real fluid cells to the ghost fluid cells
- Mixing of the integrated conservatives
- Computation of geometrical quantities, i.e., volume fraction, apertures and real fluid/cut cell masks
"""
# Machine epsilon of float64. Within this class it is added to denominators
# to avoid division by zero and used as a tiny floor value for cut-off primitives.
eps = jnp.finfo(jnp.float64).eps
def __init__(self, domain_information: DomainInformation, numerical_setup: Dict, material_manager: MaterialManager,
        unit_handler: UnitHandler, solid_interface_velocity: Union[List, types.LambdaType], boundary_condition: BoundaryCondition) -> None:
    """Builds the levelset sub-modules and caches the levelset related configuration.

    :param domain_information: Source of halo sizes, domain slices, cell sizes and active/inactive axes
    :type domain_information: DomainInformation
    :param numerical_setup: Numerical setup; the "levelset" and "active_physics" entries are read here
    :type numerical_setup: Dict
    :param material_manager: Material manager, stored on the handler and forwarded to the interface sub-modules
    :type material_manager: MaterialManager
    :param unit_handler: Unit handler, forwarded to the interface quantity computer
    :type unit_handler: UnitHandler
    :param solid_interface_velocity: Solid interface velocity specification, forwarded to the interface quantity computer
    :type solid_interface_velocity: Union[List, types.LambdaType]
    :param boundary_condition: Boundary condition object, forwarded to the extenders and reinitializers
    :type boundary_condition: BoundaryCondition
    """
    setup_levelset = numerical_setup["levelset"]
    setup_extension = setup_levelset["extension"]
    setup_reinitialization = setup_levelset["reinitialization"]

    # Keyword arguments shared by all stencils / integrators that act on
    # buffers with conservative halos or geometry halos, respectively.
    inactive_axis = domain_information.inactive_axis
    kwargs_conservatives = dict(nh=domain_information.nh_conservatives, inactive_axis=inactive_axis)
    kwargs_geometry = dict(nh=domain_information.nh_geometry, inactive_axis=inactive_axis)

    # GEOMETRY CALCULATOR
    geometry_stencil = setup_levelset["geometry_calculator_stencil"]
    kwargs_geometry_stencil = dict(kwargs_conservatives, offset=setup_levelset["halo_cells"])
    self.geometry_calculator = GeometryCalculator(
        domain_information=domain_information,
        first_derivative_stencil=DICT_FIRST_DERIVATIVE_CENTER[geometry_stencil](**kwargs_geometry_stencil),
        second_derivative_stencil=DICT_SECOND_DERIVATIVE_CENTER[geometry_stencil](**kwargs_geometry_stencil),
        subcell_reconstruction=setup_levelset["subcell_reconstruction"],
    )

    # INTERFACE INTERACTION
    self.interface_quantity_computer = InterfaceQuantityComputer(
        domain_information=domain_information,
        material_manager=material_manager,
        unit_handler=unit_handler,
        solid_interface_velocity=solid_interface_velocity,
        numerical_setup=numerical_setup,
    )
    self.interface_flux_computer = InterfaceFluxComputer(
        domain_information=domain_information,
        material_manager=material_manager,
        numerical_setup=numerical_setup,
    )

    # EXTENDERS - SAME SCHEMES, DIFFERENT HALO SIZES
    extension_integrator = DICT_TIME_INTEGRATION[setup_extension["time_integrator"]]
    extension_stencil = DICT_DERIVATIVE_QUANTITY_EXTENDER[setup_extension["spatial_stencil"]]
    self.extender_primes = QuantityExtender(
        domain_information=domain_information,
        boundary_condition=boundary_condition,
        time_integrator=extension_integrator(**kwargs_conservatives),
        spatial_stencil=extension_stencil(**kwargs_conservatives),
        is_interface=False,
    )
    self.extender_interface = QuantityExtender(
        domain_information=domain_information,
        boundary_condition=boundary_condition,
        time_integrator=extension_integrator(**kwargs_geometry),
        spatial_stencil=extension_stencil(**kwargs_geometry),
        is_interface=True,
    )

    # REINITIALIZERS - ONE FOR THE SIMULATION, ONE FOR THE INITIALIZATION
    self.levelset_reinitializer = LevelsetReinitializer(
        domain_information=domain_information,
        boundary_condition=boundary_condition,
        time_integrator=DICT_TIME_INTEGRATION[setup_reinitialization["time_integrator"]](**kwargs_conservatives),
        derivative_stencil=DICT_DERIVATIVE_REINITIALIZATION[setup_reinitialization["spatial_stencil"]](**kwargs_conservatives),
    )
    self.levelset_reinitializer_init = LevelsetReinitializer(
        domain_information=domain_information,
        boundary_condition=boundary_condition,
        time_integrator=DICT_TIME_INTEGRATION[setup_reinitialization["time_integrator_init"]](**kwargs_conservatives),
        derivative_stencil=DICT_DERIVATIVE_REINITIALIZATION[setup_reinitialization["spatial_stencil_init"]](**kwargs_conservatives),
    )

    self.levelset_advection_stencil : SpatialDerivative = DICT_DERIVATIVE_LEVELSET_ADVECTION[setup_levelset["levelset_advection_stencil"]](**kwargs_conservatives)

    self.material_manager = material_manager

    # DOMAIN INFORMATION PARAMETERS
    self.dim = domain_information.dim
    self.cell_sizes = domain_information.cell_sizes
    self.cell_centers = domain_information.cell_centers
    self.nhx, self.nhy, self.nhz = domain_information.domain_slices_conservatives
    self.nhx_, self.nhy_, self.nhz_ = domain_information.domain_slices_geometry
    self.nh = domain_information.nh_conservatives
    self.nh_ = domain_information.nh_geometry
    self.nhx__, self.nhy__, self.nhz__ = domain_information.domain_slices_conservatives_to_geometry
    self.active_axis_indices = domain_information.active_axis_indices
    self.inactive_axis_indices = domain_information.inactive_axis_indices
    active_cell_sizes = [self.cell_sizes[axis] for axis in self.active_axis_indices]
    self.smallest_cell_size = jnp.min(jnp.array(active_cell_sizes))

    # GENERAL LEVELSET
    self.levelset_type = setup_levelset["interface_interaction"]
    self.volume_fraction_threshold = setup_levelset["volume_fraction_threshold"]
    self.narrow_band_cutoff = setup_levelset["narrow_band_cutoff"]
    self.narrow_band_computations = setup_levelset["narrow_band_computations"]
    # Optional entry; falls back to the number of spatial dimensions.
    self.mixing_targets = setup_levelset.get("mixing_targets", self.dim)

    # EXTENSION PARAMETERS
    self.steps_extension_primes = setup_extension["steps_primes"]
    self.CFL_extension_primes = setup_extension["CFL_primes"]
    self.steps_extension_interface = setup_extension["steps_interface"]
    self.CFL_extension_interface = setup_extension["CFL_interface"]

    # REINITIALIZATION PARAMETERS
    self.steps_reinitialization = setup_reinitialization["steps"]
    self.steps_reinitialization_init = setup_reinitialization["steps_init"]
    self.CFL_reinitialization = setup_reinitialization["CFL"]
    self.CFL_reinitialization_init = setup_reinitialization["CFL_init"]
    self.interval_reinitialization = setup_reinitialization["interval"]
    self.cut_cell_reinitialization = setup_reinitialization["cut_cell"]

    # ACTIVE PHYSICAL FLUXES
    setup_physics = numerical_setup["active_physics"]
    self.is_viscous_flux = setup_physics["is_viscous_flux"]
    self.is_surface_tension = setup_physics["is_surface_tension"]

    # AXIS PAIRS USED FOR THE DIAGONAL (IJ) MIXING - ONLY PAIRS OF ACTIVE AXES
    self.index_pairs_mixing = [
        (axis_i, axis_j)
        for (axis_i, axis_j) in ((0, 1), (0, 2), (1, 2))
        if axis_i in self.active_axis_indices and axis_j in self.active_axis_indices
    ]
[docs]
def compute_cut_cell_mask(self, volume_fraction: jnp.ndarray) -> jnp.ndarray:
    """Flags cut cells, i.e., cells whose volume fraction lies strictly between 0.0 and 1.0.

    :param volume_fraction: Volume fraction buffer
    :type volume_fraction: jnp.ndarray
    :return: Integer mask that is 1 in cut cells and 0 everywhere else
    :rtype: jnp.ndarray
    """
    is_partially_filled = jnp.logical_and(volume_fraction > 0.0, volume_fraction < 1.0)
    return jnp.where(is_partially_filled, 1, 0)
[docs]
@partial(jax.jit, static_argnums=(0))
def compute_masks(self, levelset: jnp.ndarray,
volume_fraction: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Computes the real mask and the cut cell mask
:param levelset: Levelset buffer
:type levelset: jnp.ndarray
:param volume_fraction: Volume fraction buffer
:type volume_fraction: jnp.ndarray
:return: Tuple containing the real mask and the cut cell mask
:rtype: Tuple[jnp.ndarray, jnp.ndarray]
"""
if self.levelset_type == "FLUID-FLUID":
mask_positive = jnp.where( levelset[self.nhx__, self.nhy__, self.nhz__] > 0.0, 1, 0 )
mask_negative = jnp.where( levelset[self.nhx__, self.nhy__, self.nhz__] < 0.0, 1, 0 )
mask_cut_cells = self.compute_cut_cell_mask(volume_fraction)
mask_positive = jnp.maximum(mask_positive, mask_cut_cells)
mask_negative = jnp.maximum(mask_negative, mask_cut_cells)
mask_real = jnp.stack([mask_positive, mask_negative], axis=0)
else:
mask_positive = jnp.where( levelset[self.nhx__, self.nhy__, self.nhz__] > 0.0, 1, 0 )
mask_cut_cells = self.compute_cut_cell_mask(volume_fraction)
mask_real = jnp.maximum(mask_positive, mask_cut_cells)
return mask_real, mask_cut_cells
[docs]
def compute_volume_fraction_and_apertures(self, levelset: jnp.ndarray) -> Tuple[jnp.ndarray, List]:
"""Computes the volume fraction and apertures via linear interface reconstruction
:param levelset: Levelset buffer
:type levelset: jnp.ndarray
:return: Tuple containing the volume fraction buffer and the aperture buffers
:rtype: Tuple[jnp.ndarray, List]
"""
volume_fraction, apertures = self.geometry_calculator.linear_interface_reconstruction(levelset)
return volume_fraction, apertures
[docs]
def extend_primes(self, cons: jnp.ndarray, primes: jnp.ndarray, levelset: jnp.ndarray,
        volume_fraction: jnp.ndarray, current_time: float,
        mask_small_cells: jnp.ndarray = None) -> Tuple[jnp.ndarray, jnp.ndarray, float]:
    """Extends the primitives from the real fluid cells into the ghost fluid cells within the narrow_band_compute.
    Subsequently, the corresponding conservatives are computed from the extended primitives.

    :param cons: Buffer of conservative variables
    :type cons: jnp.ndarray
    :param primes: Buffer of primitive variables
    :type primes: jnp.ndarray
    :param levelset: Levelset buffer
    :type levelset: jnp.ndarray
    :param volume_fraction: Volume fraction buffer
    :type volume_fraction: jnp.ndarray
    :param current_time: Current physical simulation time
    :type current_time: float
    :param mask_small_cells: Mask indicating small negative cells, defaults to None
    :type mask_small_cells: jnp.ndarray, optional
    :return: Tuple of conservative buffer, primitive buffer and maximum extension residual (in this order)
    :rtype: Tuple[jnp.ndarray, jnp.ndarray, float]
    """

    # GEOMETRICAL QUANTITIES - WE EXTEND INTO GHOST CELLS PLUS SMALL CELLS INSIDE THE NARROW BAND
    normal = self.geometry_calculator.compute_normal(levelset)[...,self.nhx_,self.nhy_,self.nhz_]
    normal_extend = jnp.stack([-normal, normal], axis=1) if self.levelset_type == "FLUID-FLUID" else -normal
    mask_real, _ = self.compute_masks(levelset, volume_fraction)
    mask_ghost = 1 - mask_real[...,self.nhx_,self.nhy_,self.nhz_]
    mask_narrow_band = jnp.where(jnp.abs(levelset[self.nhx,self.nhy,self.nhz])/self.smallest_cell_size < self.narrow_band_computations, 1, 0)
    mask_extend = mask_ghost * mask_narrow_band
    # Identity check: comparing an array against None with != is not a reliable None test.
    if mask_small_cells is not None:
        mask_extend = jnp.maximum(mask_extend, mask_small_cells)

    # EXTEND PRIMES
    primes, max_residual = self.extender_primes.extend(primes, normal_extend, mask_extend, self.CFL_extension_primes, self.steps_extension_primes)

    # VELOCITY TREATMENT FOR SOLIDS
    if self.levelset_type in ["FLUID-SOLID-STATIC", "FLUID-SOLID-DYNAMIC"]:
        if self.is_viscous_flux:
            # Flip the sign of the extended velocity in ghost cells
            primes = primes.at[1:4, self.nhx, self.nhy, self.nhz].mul(jnp.where(mask_ghost == 1.0, -1.0, 1.0))
        else:
            # Mirror the extended velocity at the interface normal in ghost cells
            velocities = primes[1:4,...,self.nhx,self.nhy,self.nhz] * mask_ghost
            velocities = velocities - normal * 2 * jnp.einsum('ijkl, ijkl -> jkl', normal, velocities)
            primes = primes.at[1:4, ..., self.nhx, self.nhy, self.nhz].mul(mask_real[...,self.nhx_,self.nhy_,self.nhz_])
            primes = primes.at[1:4, ..., self.nhx, self.nhy, self.nhz].add(velocities)
        if self.levelset_type == "FLUID-SOLID-DYNAMIC":
            interface_velocity = self.compute_solid_interface_velocity(current_time)
            primes = primes.at[1:4, self.nhx, self.nhy, self.nhz].add(mask_ghost * interface_velocity)

    # CUT OFF PRIMES IN GHOST CELLS IN WHICH WE DO NOT EXTEND, I.E. ALL GHOST CELLS OUTSIDE THE NARROW BAND
    mask_cut_off = (1 - mask_narrow_band) * mask_ghost
    primes = primes.at[...,self.nhx,self.nhy,self.nhz].mul(1 - mask_cut_off)
    # Entries 0 and 4 are set to eps instead of zero in the cut-off cells
    primes = primes.at[[0,4],...,self.nhx,self.nhy,self.nhz].add(mask_cut_off * self.eps)

    # UPDATE CONSERVATIVES IN EXTENDED CELLS, I.E. GHOST CELLS PLUS SMALL CELLS INSIDE THE NARROW BAND
    cons_in_extend = get_conservatives_from_primitives(primes[...,self.nhx,self.nhy,self.nhz], self.material_manager) * mask_extend
    cons = cons.at[...,self.nhx,self.nhy,self.nhz].mul(1 - mask_extend)
    cons = cons.at[...,self.nhx,self.nhy,self.nhz].add(cons_in_extend)

    return cons, primes, max_residual
[docs]
def weight_cell_face_flux_xi(self, flux_xi: jnp.ndarray, apertures: jnp.ndarray) -> jnp.ndarray:
"""Weights the cell face fluxes according to the apertures.
:param flux_xi: Cell face flux at xi
:type flux_xi: jnp.ndarray
:param apertures: Aperture buffer
:type apertures: jnp.ndarray
:return: Weighted cell face flux at xi
:rtype: jnp.ndarray
"""
if self.levelset_type == "FLUID-FLUID":
apertures = jnp.stack([apertures, 1.0 - apertures], axis=0)
flux_xi *= apertures[...,self.nhx_,self.nhy_,self.nhz_]
return flux_xi
[docs]
def weight_volume_force(self, volume_force: jnp.ndarray, volume_fraction: jnp.ndarray) -> jnp.ndarray:
"""Weights the volume forces according to the volume fraction.
:param volume_force: Volume force buffer
:type volume_force: jnp.ndarray
:param volume_fraction: Volume fraction buffer
:type volume_fraction: jnp.ndarray
:return: Weighted volume force
:rtype: jnp.ndarray
"""
if self.levelset_type == "FLUID-FLUID":
volume_fraction = jnp.stack([volume_fraction, 1.0 - volume_fraction], axis=0)
volume_force *= volume_fraction[...,self.nhx_,self.nhy_,self.nhz_]
return volume_force
[docs]
@partial(jax.jit, static_argnums=(0))
def compute_solid_interface_velocity(self, current_time: float) -> jnp.ndarray:
"""Computes the interface velocity for FLUID-SOLID-DYNAMIC interface interaction.
:param current_time: Current physical simulation time
:type current_time: float
:return: Interface velocity buffer
:rtype: jnp.ndarray
"""
interface_velocity = self.interface_quantity_computer.compute_solid_interface_velocity(current_time)
return interface_velocity
[docs]
@partial(jax.jit, static_argnums=(0))
def compute_interface_quantities(self, primes: jnp.ndarray, levelset: jnp.ndarray,
volume_fraction: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, float]:
"""Computes interface velocity and pressure for FLUID-FLUID interface interaction and
extends the values into the narrow_band_compute.
:param primes: Buffer of primitive variables
:type primes: jnp.ndarray
:param levelset: Levelset buffer
:type levelset: jnp.ndarray
:param volume_fraction: Volume fractio buffer
:type volume_fraction: jnp.ndarray
:return: Tuple of interface velocity, interface pressure and maximum residual of the extension
:rtype: Tuple[jnp.ndarray, jnp.ndarray, float]
"""
# GEOMTRICAL QUANTITIES
normal = self.geometry_calculator.compute_normal(levelset)
curvature = self.geometry_calculator.compute_curvature(levelset) if self.is_surface_tension else None
mask_cut_cells = self.compute_cut_cell_mask(volume_fraction)
# COMPUTE INTERFACE QUANTITIES
interface_velocity, interface_pressure = self.interface_quantity_computer.solve_interface_interaction(primes, normal, curvature)
interface_quantities = jnp.vstack([jnp.expand_dims(interface_velocity, axis=0), interface_pressure])
# EXTEND INTERFACE QUANTITIES
inverse_cut_cell_mask = 1 - mask_cut_cells[...,self.nhx_,self.nhy_,self.nhz_]
mask_narrow_band = jnp.where(jnp.abs(levelset[self.nhx,self.nhy,self.nhz])/self.smallest_cell_size < self.narrow_band_computations, 1, 0)
mask_extend = inverse_cut_cell_mask * mask_narrow_band
normal_extend = normal[...,self.nhx_,self.nhy_,self.nhz_] * jnp.sign(levelset[self.nhx,self.nhy,self.nhz])
interface_quantities, max_residual = self.extender_interface.extend(interface_quantities, normal_extend, mask_extend, self.CFL_extension_interface, self.steps_extension_interface)
# CUT OFF NARROW BAND
interface_quantities = interface_quantities[...,self.nhx_,self.nhy_,self.nhz_] * mask_narrow_band
return interface_quantities[0], interface_quantities[1:], max_residual
[docs]
def compute_interface_flux_xi(self, primes: jnp.ndarray, levelset: jnp.ndarray, interface_velocity: jnp.ndarray,
interface_pressure: Union[jnp.ndarray, None], volume_fraction: jnp.ndarray, apertures: jnp.ndarray, axis: int) -> jnp.ndarray:
"""Computes the interface flux depending on the present interface interaction type.
:param primes: Buffer of primitive variables
:type primes: jnp.ndarray
:param levelset: Levelset buffer
:type levelset: jnp.ndarray
:param interface_velocity: Interface velocity buffer
:type interface_velocity: jnp.ndarray
:param interface_pressure: Interface pressure buffer
:type interface_pressure: Union[jnp.ndarray, None]
:param volume_fraction: Volume fraction buffer
:type volume_fraction: jnp.ndarray
:param apertures: Aperture buffer
:type apertures: jnp.ndarray
:param axis: Current axis
:type axis: int
:return: Interface fluxes
:rtype: jnp.ndarray
"""
normal = self.geometry_calculator.compute_normal(levelset)
interface_flux_xi = self.interface_flux_computer.compute_interface_flux_xi(primes, interface_velocity, interface_pressure, volume_fraction, apertures, normal, axis)
return interface_flux_xi
[docs]
def compute_levelset_advection_rhs(self, levelset: jnp.ndarray, interface_velocity: jnp.ndarray, axis: int) -> jnp.ndarray:
"""Computes the right-hand-side of the levelset advection equation.
:param levelset: Levelset buffer
:type levelset: jnp.ndarray
:param interface_velocity: Interface velocity buffer
:type interface_velocity: jnp.ndarray
:param axis: Current axis
:type axis: int
:return: right-hand-side contribution for current axis
:rtype: jnp.ndarray
"""
# GEOMETRICAL QUANTITIES
normal = self.geometry_calculator.compute_normal(levelset)
if self.levelset_type == "FLUID-FLUID":
mask_narrow_band = jnp.where(jnp.abs(levelset[self.nhx,self.nhy,self.nhz])/self.smallest_cell_size < self.narrow_band_computations, 1, 0)
else:
mask_narrow_band = jnp.ones_like(levelset[self.nhx,self.nhy,self.nhz])
# DERIVATIVE
derivative_L = self.levelset_advection_stencil.derivative_xi(levelset, 1.0, axis, 0)
derivative_R = self.levelset_advection_stencil.derivative_xi(levelset, 1.0, axis, 1)
# UPWINDING
velocity = interface_velocity * normal[axis,self.nhx_,self.nhy_,self.nhz_] if self.levelset_type == "FLUID-FLUID" else interface_velocity[axis]
mask_L = jnp.where(velocity >= 0.0, 1.0, 0.0)
mask_R = 1.0 - mask_L
# SUM RHS
rhs_contribution = - velocity * (mask_L * derivative_L + mask_R * derivative_R) / self.cell_sizes[axis]
rhs_contribution *= mask_narrow_band
return rhs_contribution
# @partial(jax.jit, static_argnums=(0, 2))
[docs]
def reinitialize(self, levelset: jnp.ndarray, initializer: bool) -> Tuple[jnp.ndarray, float]:
"""Reinitializes the levelset buffer and subsequently applies cut off
to values that lie outside the narrow_band_cutoff. If the initializer flag is True,
cut cells are always reinitialized. This is required to initialize
levelset fields that are provided as no signed distance functions, e.g., an ellipse.
If cut_cell_reinitialization is True (specified in numerical setup), then cut cells are also reinitialized.
The residual is only computed in the narrow_band_computations.
:param levelset: Levelset buffer
:type levelset: jnp.ndarray
:param initializer: Bool indicating if the reinitialization is performed in the
jaxfluids initializer or the jaxfluids simulation manager
:type initializer: bool
:return: Reinitialized levelset buffer and the corresponding maximum residual
:rtype: Tuple[jnp.ndarray, float]
"""
# TODO SPLIT UP INITIAL REINITIALIZATION AND SET CUT OFF
# REINITIALIZE
if initializer:
mask_reinitialize = jnp.ones_like(levelset[self.nhx,self.nhy,self.nhz], dtype=jnp.uint32)
mask_residual = jnp.ones_like(levelset[self.nhx,self.nhy,self.nhz], dtype=jnp.uint32)
levelset, max_residual = self.levelset_reinitializer_init.reinitialize(levelset, mask_reinitialize, mask_residual, self.CFL_reinitialization_init, self.steps_reinitialization_init)
# levelset, levelset_out = self.levelset_reinitializer_init.reinitialize_debug(levelset, mask_reinitialize, mask_residual, self.CFL_reinitialization_init, self.steps_reinitialization_init)
else:
mask_reinitialize = jnp.ones_like(levelset[self.nhx,self.nhy,self.nhz], dtype=jnp.uint32) \
if self.cut_cell_reinitialization else \
jnp.where( jnp.abs(levelset[self.nhx,self.nhy,self.nhz])/self.smallest_cell_size > jnp.sqrt(2.0)/2, 1, 0 )
mask_residual = jnp.where(jnp.abs(levelset[self.nhx,self.nhy,self.nhz]/self.smallest_cell_size) < self.narrow_band_computations, 1, 0)
mask_residual = mask_residual * mask_reinitialize
levelset, max_residual = self.levelset_reinitializer.reinitialize(levelset, mask_reinitialize, mask_residual, self.CFL_reinitialization, self.steps_reinitialization)
# CUT OFF MASKS
mask_cut_off_positive = jnp.where(levelset[self.nhx,self.nhy,self.nhz]/self.smallest_cell_size > self.narrow_band_cutoff, 1, 0)
mask_cut_off_negative = jnp.where(levelset[self.nhx,self.nhy,self.nhz]/self.smallest_cell_size < -self.narrow_band_cutoff, 1, 0)
# SET CUTOFF VALUES
levelset = levelset.at[self.nhx,self.nhy,self.nhz].mul(1 - mask_cut_off_positive)
levelset = levelset.at[self.nhx,self.nhy,self.nhz].mul(1 - mask_cut_off_negative)
levelset = levelset.at[self.nhx,self.nhy,self.nhz].add(mask_cut_off_positive * self.narrow_band_cutoff * self.smallest_cell_size)
levelset = levelset.at[self.nhx,self.nhy,self.nhz].add(-mask_cut_off_negative * self.narrow_band_cutoff * self.smallest_cell_size)
return levelset, max_residual
[docs]
def mixing(self, cons: jnp.ndarray, levelset_new: jnp.ndarray, volume_fraction_new: jnp.ndarray,
        volume_fraction_old: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """ Performs a mixing procedure on the integrated conservatives enabling stable integration of
    small cut cells with the CFL criterion for full cells. For small cells that are in the process of vanishing
    (becoming ghost cells), this procedure may produce negative mass/energy. To prevent numerical instability,
    we track these cells with the mask_small_cells buffer and perform a prime extension into these cells.

    Source cells are small cells (volume fraction below the threshold), newly created cells and
    vanished cells. Each source cell exchanges conservatives with its target cells, which are
    located along the axes (II), the in-plane diagonals (IJ) and the space diagonal (IJK),
    selected by the sign of the interface normal. The IJ and IJK targets are only used if
    the dimension and the mixing_targets setting allow it.

    :param cons: Buffer of integrated conservative variables
    :type cons: jnp.ndarray
    :param levelset_new: Integrated levelset buffer
    :type levelset_new: jnp.ndarray
    :param volume_fraction_new: Integrated volume fraction buffer
    :type volume_fraction_new: jnp.ndarray
    :param volume_fraction_old: Volume fraction buffer of previous time step (RK stage)
    :type volume_fraction_old: jnp.ndarray
    :return: Tuple containing the mixed conservatives and a mask indicating small cells with negative energy/mass.
    :rtype: Tuple[jnp.ndarray, jnp.ndarray]
    """
    use_targets_ij = self.dim > 1 and self.mixing_targets > 1
    use_targets_ijk = self.dim == 3 and self.mixing_targets == 3

    # GEOMETRICAL QUANTITIES
    normal = self.geometry_calculator.compute_normal(levelset_new)
    if self.levelset_type == "FLUID-FLUID":
        normal = jnp.stack([normal, -normal], axis=1)
        volume_fraction_new = jnp.stack([volume_fraction_new, 1.0 - volume_fraction_new], axis=0)
        volume_fraction_old = jnp.stack([volume_fraction_old, 1.0 - volume_fraction_old], axis=0)
    normal_sign = jnp.sign(normal)

    # CREATE SOURCE MASK - SMALL CELLS, NEWLY CREATED, AND VANISHED CELLS
    small_cells = jnp.where((volume_fraction_new < self.volume_fraction_threshold) & (volume_fraction_new > 0.0), 1, 0)
    newly_created_cells = jnp.where((volume_fraction_new > 0.0) & (volume_fraction_old == 0.0), 1, 0)
    vanished_cells = jnp.where((volume_fraction_new == 0.0) & (volume_fraction_old > 0.0), 1, 0)
    mask_source = jnp.maximum(small_cells, jnp.maximum(vanished_cells, newly_created_cells))

    # UNNORMALIZED MIXING WEIGHTS - THE TARGET VOLUME FRACTIONS ARE KEPT FOR THE MIXING FLUXES BELOW
    mixing_weight_sum = 0.0

    # MIXING WEIGHTS II
    weights_ii, volume_fraction_targets_ii = [], []
    for axis in self.active_axis_indices:
        volume_fraction_target_at_source = move_target_to_source_ii(volume_fraction_new, normal_sign, axis)
        mixing_weight_ii = jnp.square(normal[axis]) * volume_fraction_target_at_source
        mixing_weight_sum += mixing_weight_ii
        weights_ii.append(mixing_weight_ii)
        volume_fraction_targets_ii.append(volume_fraction_target_at_source)

    # MIXING WEIGHTS IJ
    weights_ij, volume_fraction_targets_ij = [], []
    if use_targets_ij:
        for (axis_i, axis_j) in self.index_pairs_mixing:
            volume_fraction_target_at_source = move_target_to_source_ij(volume_fraction_new, normal_sign, axis_i, axis_j)
            mixing_weight_ij = jnp.abs(normal[axis_i]*normal[axis_j]) * volume_fraction_target_at_source
            mixing_weight_sum += mixing_weight_ij
            weights_ij.append(mixing_weight_ij)
            volume_fraction_targets_ij.append(volume_fraction_target_at_source)

    # MIXING WEIGHTS IJK
    if use_targets_ijk:
        volume_fraction_target_ijk = move_target_to_source_ijk(volume_fraction_new, normal_sign)
        mixing_weight_ijk = jnp.abs(normal[0]*normal[1]*normal[2])**(2/3) * volume_fraction_target_ijk
        mixing_weight_sum += mixing_weight_ijk

    # NORMALIZATION - EPS GUARDS AGAINST A VANISHING WEIGHT SUM
    normalization = mixing_weight_sum + self.eps

    cons_source = cons[...,self.nhx__,self.nhy__,self.nhz__]

    def compute_source_flux(mixing_weight, volume_fraction_target_at_source, cons_target_at_source):
        # Mixing flux seen by the source cell; it is non-zero only in source cells.
        return mixing_weight/(volume_fraction_new * mixing_weight + volume_fraction_target_at_source + self.eps) \
            * (cons_target_at_source * volume_fraction_new - cons_source * volume_fraction_target_at_source) * mask_source

    mixing_fluxes_source = []
    mixing_fluxes_target = []

    # MIXING CONTRIBUTIONS II
    for axis, mixing_weight_ii, volume_fraction_target_at_source in zip(self.active_axis_indices, weights_ii, volume_fraction_targets_ii):
        cons_target_at_source = move_target_to_source_ii(cons_source, normal_sign, axis)
        M_xi_source = compute_source_flux(mixing_weight_ii/normalization, volume_fraction_target_at_source, cons_target_at_source)
        # The target cell receives the opposite flux (conservation)
        M_xi_target = move_source_to_target_ii(-M_xi_source, normal_sign, axis)
        mixing_fluxes_source.append(M_xi_source)
        mixing_fluxes_target.append(M_xi_target)

    # MIXING CONTRIBUTIONS IJ
    if use_targets_ij:
        for (axis_i, axis_j), mixing_weight_ij, volume_fraction_target_at_source in zip(self.index_pairs_mixing, weights_ij, volume_fraction_targets_ij):
            cons_target_at_source = move_target_to_source_ij(cons_source, normal_sign, axis_i, axis_j)
            M_xi_source = compute_source_flux(mixing_weight_ij/normalization, volume_fraction_target_at_source, cons_target_at_source)
            M_xi_target = move_source_to_target_ij(-M_xi_source, normal_sign, axis_i, axis_j)
            mixing_fluxes_source.append(M_xi_source)
            mixing_fluxes_target.append(M_xi_target)

    # MIXING CONTRIBUTIONS IJK
    if use_targets_ijk:
        # The target conservatives have to be shifted to the source position exactly like
        # the target volume fraction (and like in the II and IJ contributions); the
        # source-to-target shift is only used to move the resulting flux to the target cell.
        cons_target_at_source = move_target_to_source_ijk(cons_source, normal_sign)
        M_xi_source = compute_source_flux(mixing_weight_ijk/normalization, volume_fraction_target_ijk, cons_target_at_source)
        M_xi_target = move_source_to_target_ijk(-M_xi_source, normal_sign)
        mixing_fluxes_source.append(M_xi_source)
        mixing_fluxes_target.append(M_xi_target)

    # APPLY MIXING FLUXES
    cons = cons.at[...,self.nhx,self.nhy,self.nhz].add(sum(mixing_fluxes_source)[...,self.nhx_,self.nhy_,self.nhz_] + sum(mixing_fluxes_target)[...,self.nhx_,self.nhy_,self.nhz_])

    # NON-EMPTY CELLS THAT ENDED UP WITH A NEGATIVE ENTRY 0 OR ENTRY 4 (MASS/ENERGY) AFTER MIXING
    mask_small_cells = jnp.where((cons[0,...,self.nhx,self.nhy,self.nhz] < 0.0) | (cons[4,...,self.nhx,self.nhy,self.nhz] < 0.0), 1, 0) * jnp.where((volume_fraction_new[...,self.nhx_,self.nhy_,self.nhz_] != 0.0), 1, 0)

    return cons, mask_small_cells
[docs]
def compute_primitives_from_conservatives_in_real_fluid(self, cons: jnp.ndarray, primes: jnp.ndarray,
        levelset: jnp.ndarray, volume_fraction: jnp.ndarray, mask_small_cells: jnp.ndarray) -> jnp.ndarray:
    """Computes the primitive variables from the mixed conservatives within the real fluid.
    Cells flagged in mask_small_cells are excluded; in all excluded cells and in the
    ghost cells the incoming primitives are kept unchanged.

    :param cons: Buffer of conservative variables
    :type cons: jnp.ndarray
    :param primes: Buffer of primitive variables
    :type primes: jnp.ndarray
    :param levelset: Levelset buffer
    :type levelset: jnp.ndarray
    :param volume_fraction: Volume fraction buffer
    :type volume_fraction: jnp.ndarray
    :param mask_small_cells: Mask indicating the small cells that are treated by the prime extension instead
    :type mask_small_cells: jnp.ndarray
    :return: Buffer of primitive variables, updated in the real fluid cells
    :rtype: jnp.ndarray
    """
    interior = (Ellipsis, self.nhx, self.nhy, self.nhz)

    # WE COMPUTE PRIMES ONLY IN REAL CELLS IN WHICH WE DO NOT EXTEND, I.E. REAL CELLS MINUS SMALL CELLS
    mask_real_all, _ = self.compute_masks(levelset, volume_fraction)
    mask_update = mask_real_all[...,self.nhx_,self.nhy_,self.nhz_] * (1 - mask_small_cells)
    mask_keep = 1 - mask_update

    primes_in_real = get_primitives_from_conservatives(cons[interior], self.material_manager) * mask_update

    # ZERO THE UPDATED CELLS, THEN INSERT THE FRESHLY COMPUTED PRIMITIVES
    primes = primes.at[interior].mul(mask_keep)
    primes = primes.at[interior].add(primes_in_real)
    return primes