# Source code for jaxfluids.forcing.forcing

#*------------------------------------------------------------------------------*
#* JAX-FLUIDS -                                                                 *
#*                                                                              *
#* A fully-differentiable CFD solver for compressible two-phase flows.          *
#* Copyright (C) 2022  Deniz A. Bezgin, Aaron B. Buhendwa, Nikolaus A. Adams    *
#*                                                                              *
#* This program is free software: you can redistribute it and/or modify         *
#* it under the terms of the GNU General Public License as published by         *
#* the Free Software Foundation, either version 3 of the License, or            *
#* (at your option) any later version.                                          *
#*                                                                              *
#* This program is distributed in the hope that it will be useful,              *
#* but WITHOUT ANY WARRANTY; without even the implied warranty of               *
#* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the                *
#* GNU General Public License for more details.                                 *
#*                                                                              *
#* You should have received a copy of the GNU General Public License            *
#* along with this program.  If not, see <https://www.gnu.org/licenses/>.       *
#*                                                                              *
#*------------------------------------------------------------------------------*
#*                                                                              *
#* CONTACT                                                                      *
#*                                                                              *
#* deniz.bezgin@tum.de // aaron.buhendwa@tum.de // nikolaus.adams@tum.de        *
#*                                                                              *
#*------------------------------------------------------------------------------*
#*                                                                              *
#* Munich, April 15th, 2022                                                     *
#*                                                                              *
#*------------------------------------------------------------------------------*

from functools import partial
from typing import Tuple, Union, Dict
import types

import jax
import jax.numpy as jnp
from jaxfluids import levelset

from jaxfluids.forcing.pid_control import PIDControl
from jaxfluids.domain_information import DomainInformation
from jaxfluids.materials.material_manager import MaterialManager
from jaxfluids.unit_handler import UnitHandler
from jaxfluids.io_utils.logger import Logger
from jaxfluids.levelset.levelset_handler import LevelsetHandler
from jaxfluids.turb.turb_stats_manager import TurbStatsManager

class Forcing:
    """Class that manages the computation of external forcing terms.

    Currently implemented are:
    1) Mass flow rate forcing
    2) Temperature forcing
    3) Homogeneous isotropic turbulence forcing
    """

    def __init__(self, domain_information: DomainInformation, material_manager: MaterialManager,
                 unit_handler: UnitHandler, levelset_handler: Union[LevelsetHandler, None],
                 levelset_type: str, is_mass_flow_forcing: bool, is_temperature_forcing: bool,
                 is_turb_hit_forcing: bool, mass_flow_target: Union[float, types.LambdaType],
                 flow_direction: str, temperature_target: Union[float, types.LambdaType]) -> None:
        """Stores handlers, domain information and forcing targets.

        :param domain_information: Domain information (slices, cell sizes, active/inactive axes).
        :type domain_information: DomainInformation
        :param material_manager: Material manager (provides R, gamma and get_temperature).
        :type material_manager: MaterialManager
        :param unit_handler: Unit handler used to (non-)dimensionalize targets.
        :type unit_handler: UnitHandler
        :param levelset_handler: Level-set handler, used to compute the real-fluid mask.
        :type levelset_handler: Union[LevelsetHandler, None]
        :param levelset_type: Level-set interaction type, None for single-phase runs.
        :type levelset_type: str
        :param is_mass_flow_forcing: Enables mass flow forcing.
        :type is_mass_flow_forcing: bool
        :param is_temperature_forcing: Enables temperature forcing.
        :type is_temperature_forcing: bool
        :param is_turb_hit_forcing: Enables homogeneous isotropic turbulence forcing.
        :type is_turb_hit_forcing: bool
        :param mass_flow_target: Dimensional mass flow target, constant or callable of time.
        :type mass_flow_target: Union[float, types.LambdaType]
        :param flow_direction: Direction of the forced mass flow, one of "x", "y", "z".
        :type flow_direction: str
        :param temperature_target: Dimensional temperature target, constant or callable of (coordinates..., time).
        :type temperature_target: Union[float, types.LambdaType]
        :raises ValueError: If mass flow forcing is enabled and flow_direction is not "x", "y" or "z".
        """
        # BOOLS FORCINGS
        self.is_mass_flow_forcing = is_mass_flow_forcing
        self.is_temperature_forcing = is_temperature_forcing
        self.is_turb_hit_forcing = is_turb_hit_forcing

        # MATERIAL AND UNIT HANDLER
        self.material_manager = material_manager
        self.unit_handler = unit_handler
        self.levelset_handler = levelset_handler

        if is_turb_hit_forcing:
            self.turb_stats_manager = TurbStatsManager(domain_information, material_manager)
            self.k_mag_vec = self.turb_stats_manager.k_mag_vec
            self.k_field = self.turb_stats_manager.k_field
            self.one_k2_field = self.turb_stats_manager.one_k2_field
            self.shell = self.turb_stats_manager.shell

        # LEVELSET INTERFACE INTERACTION TYPE
        self.levelset_type = levelset_type

        # DOMAIN INFORMATION
        self.inactive_axis_indices = [{"x": 0, "y": 1, "z": 2}[axis] for axis in domain_information.inactive_axis]
        self.nhx, self.nhy, self.nhz = domain_information.domain_slices_conservatives
        self.nhx_, self.nhy_, self.nhz_ = domain_information.domain_slices_geometry
        self.nx, self.ny, self.nz = domain_information.number_of_cells
        dx, dy, dz = domain_information.cell_sizes
        self.dim = domain_information.dim
        self.cell_centers = domain_information.cell_centers
        self.active_axis_indices = domain_information.active_axis_indices

        # MASS FLOW FORCING
        self.mass_flow_target = mass_flow_target
        self.PID_mass_flow_forcing = PIDControl(K_P=5e-1, K_I=5, K_D=0, T_N=5, T_V=1)
        # vec: direction of the body force; int_ax: axes summed to get the cross-section flux;
        # index: momentum component in the conservative buffer; dA: cell face area normal to the flow
        if flow_direction == "x":
            self.vec = jnp.array([1.0, 0.0, 0.0])
            self.int_ax = (-1, -2)
            self.index = 1
            self.dA = dy * dz
        elif flow_direction == "y":
            self.vec = jnp.array([0.0, 1.0, 0.0])
            self.int_ax = (-3, -1)
            self.index = 2
            self.dA = dx * dz
        elif flow_direction == "z":
            self.vec = jnp.array([0.0, 0.0, 1.0])
            self.int_ax = (-3, -2)
            self.index = 3
            self.dA = dx * dy
        elif is_mass_flow_forcing:
            # Fail early instead of with an AttributeError inside compute_mass_flow_forcing
            raise ValueError("Unknown flow direction '%s' for mass flow forcing, must be one of 'x', 'y', 'z'." % (flow_direction,))

        # TEMPERATURE FORCING
        self.temperature_target = temperature_target

    def compute_forcings(self, primes: jnp.ndarray, cons: jnp.ndarray, levelset: Union[jnp.ndarray, None],
                         volume_fraction: Union[jnp.ndarray, None], current_time: float, timestep_size: float,
                         PID_e_new: float, PID_e_int: float, logger: Logger,
                         primes_dash: Union[jnp.ndarray, None] = None, **kwargs) -> Dict:
        """Computes forcings for temperature, mass flow and turbulence kinetic energy.

        :param primes: buffer of primitive variables
        :type primes: jnp.ndarray
        :param cons: buffer of conservative variables
        :type cons: jnp.ndarray
        :param levelset: buffer of the level-set field
        :type levelset: Union[jnp.ndarray, None]
        :param volume_fraction: buffer of volume fractions
        :type volume_fraction: Union[jnp.ndarray, None]
        :param current_time: current physical simulation time
        :type current_time: float
        :param timestep_size: current physical time step size
        :type timestep_size: float
        :param PID_e_new: Error of previous timestep for PID controller
        :type PID_e_new: float
        :param PID_e_int: Accumulated error for PID controller
        :type PID_e_int: float
        :param logger: Logger for terminal output
        :type logger: Logger
        :param primes_dash: buffer of primitive variables for next time step without forcing (HIT forcing only)
        :type primes_dash: Union[jnp.ndarray, None]
        :return: Dictionary containing buffers of forcings, keyed by "mass_flow", "temperature" and "turbulence"
            for the forcings that are enabled
        :rtype: Dict
        """
        forcings_dictionary = {}

        if self.is_mass_flow_forcing:
            mass_flow_forcing, mass_flow_current, mass_flow_target, PID_e_new, PID_e_int = self.compute_mass_flow_forcing(
                cons, primes, volume_fraction, current_time, timestep_size, PID_e_new, PID_e_int)
            forcings_dictionary.update({
                "mass_flow": {
                    "force": mass_flow_forcing,
                    "PID_e_new": PID_e_new,
                    "PID_e_int": PID_e_int
                },
            })
            logger.log_start_time_step([
                'PID CONTROL',
                'MASS FLOW TARGET = %4.4e' % (mass_flow_target),
                'MASS FLOW CURRENT = %4.4e' % (mass_flow_current),
            ])

        if self.is_temperature_forcing:
            temperature_forcing, temperature_error = self.compute_temperature_forcing(
                primes, levelset, volume_fraction, current_time, timestep_size)
            forcings_dictionary.update({
                "temperature": {
                    "force": temperature_forcing,
                },
            })
            logger.log_start_time_step([
                'TEMPERATURE CONTROL',
                'TEMPERATURE ERROR = %4.4e' % temperature_error,
            ])

        if self.is_turb_hit_forcing:
            turb_hit_forcing = self.compute_turb_hit_forcing(primes, primes_dash, timestep_size)
            forcings_dictionary.update({
                "turbulence": {
                    "force": turb_hit_forcing
                }
            })

        return forcings_dictionary

    @partial(jax.jit, static_argnums=(0))
    def compute_temperature_forcing(self, primes: jnp.ndarray, levelset: Union[jnp.ndarray, None],
                                    volume_fraction: Union[jnp.ndarray, None], current_time: float,
                                    timestep_size: float) -> Tuple[jnp.ndarray, float]:
        """Computes temperature forcing.

        :param primes: Buffer of primitive variables.
        :type primes: jnp.ndarray
        :param levelset: Buffer of level-set field.
        :type levelset: Union[jnp.ndarray, None]
        :param volume_fraction: Buffer of volume fraction field.
        :type volume_fraction: Union[jnp.ndarray, None]
        :param current_time: Current simulation time.
        :type current_time: float
        :param timestep_size: Current integration time step.
        :type timestep_size: float
        :return: Buffer of the forcing vector (non-zero only in the energy equation) and the mean absolute
            error wrt the temperature target.
        :rtype: Tuple[jnp.ndarray, float]
        """
        # COMPUTE TEMPERATURE
        temperature = self.material_manager.get_temperature(primes[4, ..., self.nhx, self.nhy, self.nhz],
                                                            primes[0, ..., self.nhx, self.nhy, self.nhz])

        # COMPUTE TEMPERATURE TARGET (user targets are dimensional)
        if callable(self.temperature_target):
            # Coordinates are only needed for a spatially varying target; build the mesh grid once
            full_mesh_grid = jnp.meshgrid(*self.cell_centers, indexing="ij")
            mesh_grid = [self.unit_handler.dimensionalize(full_mesh_grid[i], "length") for i in self.active_axis_indices]
            current_time_dim = self.unit_handler.dimensionalize(current_time, "time")
            temperature_target = self.temperature_target(*mesh_grid, current_time_dim)
            # Re-insert singleton dimensions for inactive axes
            for axis in self.inactive_axis_indices:
                temperature_target = jnp.expand_dims(temperature_target, axis)
        else:
            temperature_target = self.temperature_target
        temperature_target = self.unit_handler.non_dimensionalize(temperature_target, "temperature")

        # TEMPERATURE ERROR, restricted to the real fluid in level-set simulations
        temperature_error = temperature_target - temperature
        if self.levelset_type is not None:
            mask_real, _ = self.levelset_handler.compute_masks(levelset, volume_fraction)
            temperature_error = temperature_error * mask_real[..., self.nhx_, self.nhy_, self.nhz_]

        # COMPUTE TEMPERATURE FORCING
        # NOTE(review): uses rho * R * gamma / (gamma - 1), whereas the HIT forcing below uses
        # R / (gamma - 1) for its temperature term; confirm which heat capacity is intended.
        R, gamma = self.material_manager.R, self.material_manager.gamma
        rho = primes[0, ..., self.nhx, self.nhy, self.nhz]
        forcing = rho * R * gamma / (gamma - 1) * temperature_error / timestep_size
        mean_absolute_error = jnp.mean(jnp.abs(temperature_error))

        # Zero forcing in mass and momentum equations, energy forcing last
        forcing = [jnp.zeros_like(forcing) for _ in range(4)] + [forcing]
        return jnp.stack(forcing, axis=0), mean_absolute_error

    @partial(jax.jit, static_argnums=(0))
    def compute_mass_flow_forcing(self, cons: jnp.ndarray, primes: jnp.ndarray,
                                  volume_fraction: Union[jnp.ndarray, None], current_time: float,
                                  timestep_size: float, PID_e_new: float,
                                  PID_e_int: float) -> Tuple[jnp.ndarray, float, float, float, float]:
        """Computes mass flow forcing.

        :param cons: Buffer of the conservative variables.
        :type cons: jnp.ndarray
        :param primes: Buffer of the primitive variables.
        :type primes: jnp.ndarray
        :param volume_fraction: Buffer of the volume fraction in two-phase flows.
        :type volume_fraction: Union[jnp.ndarray, None]
        :param current_time: Current simulation time.
        :type current_time: float
        :param timestep_size: Current time step.
        :type timestep_size: float
        :param PID_e_new: Current PID error
        :type PID_e_new: float
        :param PID_e_int: Current PID integral error
        :type PID_e_int: float
        :return: Buffer of the body force, current mass flow, mass flow target, PID error, PID integral error
        :rtype: Tuple[jnp.ndarray, float, float, float, float]
        """
        # COMPUTE MASS FLOW TARGET (user target is dimensional)
        if callable(self.mass_flow_target):
            mass_flow_target = self.mass_flow_target(self.unit_handler.dimensionalize(current_time, "time"))
        else:
            mass_flow_target = self.mass_flow_target
        mass_flow_target = self.unit_handler.non_dimensionalize(mass_flow_target, "mass_flow")

        # COMPUTE CURRENT MASS FLOW: flux through each cross section, averaged along the flow direction
        momentum = cons[self.index, ..., self.nhx, self.nhy, self.nhz]
        if self.levelset_type is not None:
            momentum = momentum * volume_fraction[..., self.nhx_, self.nhy_, self.nhz_]
        mass_flow_current = jnp.mean(jnp.sum(self.dA * momentum, axis=self.int_ax), axis=-1)
        if self.levelset_type == "FLUID-FLUID":
            # Sum the contributions of both phases
            mass_flow_current = jnp.sum(mass_flow_current)

        # COMPUTE MASS FLOW FORCING
        mass_flow_forcing_scalar, PID_e_new, PID_e_int = self.PID_mass_flow_forcing.compute_output(
            mass_flow_current, mass_flow_target, timestep_size, PID_e_new, PID_e_int)
        mass_flow_forcing = mass_flow_forcing_scalar * self.vec

        density = primes[0:1, ..., self.nhx, self.nhy, self.nhz]
        vels = primes[1:4, ..., self.nhx, self.nhy, self.nhz]
        # Momentum source: constant force vector broadcast over the domain
        body_force_momentum = jnp.einsum("ij..., jk...->ik...", mass_flow_forcing.reshape(3, 1), jnp.ones(density.shape))
        # Energy source: work done by the force, force . velocity
        body_force_energy = jnp.einsum("ij..., jk...->ik...", mass_flow_forcing.reshape(1, 3), vels)
        body_force = jnp.vstack([jnp.zeros(body_force_energy.shape), body_force_momentum, body_force_energy])
        return body_force, mass_flow_current, mass_flow_target, PID_e_new, PID_e_int

    @partial(jax.jit, static_argnums=(0))
    def compute_turb_hit_forcing(self, primes: jnp.ndarray, primes_dash: jnp.ndarray, timestep: float) -> jnp.ndarray:
        """Computes forcing for HIT.

        :param primes: Buffer of primitive variables.
        :type primes: jnp.ndarray
        :param primes_dash: Buffer of intermediate primitive variables which are obtained by integrating primes
            without forcing term.
        :type primes_dash: jnp.ndarray
        :param timestep: Current time step.
        :type timestep: float
        :return: Buffer of the forcing vector.
        :rtype: jnp.ndarray
        """
        primes = primes[:, self.nhx, self.nhy, self.nhz]
        primes_dash = primes_dash[:, self.nhx, self.nhy, self.nhz]

        # TODO MAKE eta_s user-specified parameter
        eta_s = 2

        Tbar = jnp.mean(self.material_manager.get_temperature(p=primes[4], rho=primes[0]))
        # self.temperature_target is stored in dimensional form (it is non-dimensionalized in
        # compute_temperature_forcing), while Tbar is computed from non-dimensional primes.
        temperature_target = self.unit_handler.non_dimensionalize(self.temperature_target, "temperature")

        s_0 = jnp.zeros(primes[0].shape)
        s_1, s_2, s_3 = self.calculate_velocity_forcing_vector(primes[1:4], primes_dash[1:4], eta_s, timestep)
        # NOTE(review): the velocity terms carry a 1/timestep factor, the temperature relaxation term
        # does not; confirm this is intended.
        s_4 = primes[1] * s_1 + primes[2] * s_2 + primes[3] * s_3 \
            + (temperature_target - Tbar) * self.material_manager.R / (self.material_manager.gamma - 1)

        force = [s_0, s_1, s_2, s_3, s_4]
        return primes[0] * jnp.stack(force)

    def calculate_velocity_forcing_vector(self, vels: jnp.ndarray, vels_dash: jnp.ndarray,
                                          eta_s: int, timestep: float) -> jnp.ndarray:
        """Calculates the velocity forcing vector for HIT forcing.

        :param vels: Buffer of velocities, shape (3, nx, ny, nz).
        :type vels: jnp.ndarray
        :param vels_dash: Buffer of intermediate velocities which are obtained by integrating primes without
            forcing term.
        :type vels_dash: jnp.ndarray
        :param eta_s: Cut-off wavenumber up to which forcing is applied.
        :type eta_s: int
        :param timestep: Current time step.
        :type timestep: float
        :return: Buffer of the velocity forcing vector.
        :rtype: jnp.ndarray
        """
        fft_axes = (2, 1, 0)
        vels_hat = jnp.stack([jnp.fft.rfftn(vels[ii], axes=fft_axes) for ii in range(3)])
        vels_dash_hat = jnp.stack([jnp.fft.rfftn(vels_dash[ii], axes=fft_axes) for ii in range(3)])

        ek = self.turb_stats_manager.energy_spectrum_spectral(vels_hat)
        ek_dash = self.turb_stats_manager.energy_spectrum_spectral(vels_dash_hat)

        # Per-shell forcing coefficient, only active up to the cut-off wavenumber eta_s
        Cs_eta = 0.5 / (ek_dash + 1e-10) * (ek_dash - ek) / timestep * (self.k_mag_vec <= eta_s)

        # Project out the compressive part so that the forcing acts on the solenoidal velocity field
        div_u = self.k_field[0] * vels_hat[0] + self.k_field[1] * vels_hat[1] + self.k_field[2] * vels_hat[2]
        Cs = Cs_eta[self.shell]
        s_hat = [
            -Cs * (vels_hat[ii] - self.k_field[ii] * self.one_k2_field * div_u)
            for ii in range(3)
        ]

        # Pass the real-space shape explicitly; without it irfftn assumes an even length along the
        # half-spectrum axis and returns a wrongly sized field for odd resolutions.
        real_shape = tuple(vels.shape[1:][axis] for axis in fft_axes)
        return jnp.stack([jnp.fft.irfftn(s_hat[ii], s=real_shape, axes=fft_axes) for ii in range(3)])