# Source code for jaxfluids.solvers.source_term_solver

#*------------------------------------------------------------------------------*
#* JAX-FLUIDS -                                                                 *
#*                                                                              *
#* A fully-differentiable CFD solver for compressible two-phase flows.          *
#* Copyright (C) 2022  Deniz A. Bezgin, Aaron B. Buhendwa, Nikolaus A. Adams    *
#*                                                                              *
#* This program is free software: you can redistribute it and/or modify         *
#* it under the terms of the GNU General Public License as published by         *
#* the Free Software Foundation, either version 3 of the License, or            *
#* (at your option) any later version.                                          *
#*                                                                              *
#* This program is distributed in the hope that it will be useful,              *
#* but WITHOUT ANY WARRANTY; without even the implied warranty of               *
#* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the                *
#* GNU General Public License for more details.                                 *
#*                                                                              *
#* You should have received a copy of the GNU General Public License            *
#* along with this program.  If not, see <https://www.gnu.org/licenses/>.       *
#*                                                                              *
#*------------------------------------------------------------------------------*
#*                                                                              *
#* CONTACT                                                                      *
#*                                                                              *
#* deniz.bezgin@tum.de // aaron.buhendwa@tum.de // nikolaus.adams@tum.de        *
#*                                                                              *
#*------------------------------------------------------------------------------*
#*                                                                              *
#* Munich, April 15th, 2022                                                     *
#*                                                                              *
#*------------------------------------------------------------------------------*

from functools import partial

import jax
import jax.numpy as jnp

from jaxfluids.domain_information import DomainInformation
from jaxfluids.materials.material_manager import MaterialManager
from jaxfluids.stencils.spatial_derivative import SpatialDerivative
from jaxfluids.stencils.spatial_reconstruction import SpatialReconstruction

class SourceTermSolver:
    """The SourceTermSolver class manages source term contributions
    on the right-hand-side of NSE. Contributions have to be activated
    in the active_physics section in the numerical setup file.

    This includes amongst others:
    1) Gravitational force
    2) Heat flux
    3) Viscous flux
    """

    def __init__(
        self,
        material_manager: MaterialManager,
        gravity: jnp.ndarray,
        domain_information: DomainInformation,
        derivative_stencil_center: SpatialDerivative,
        derivative_stencil_face: SpatialDerivative,
        reconstruct_stencil_duidxi: SpatialReconstruction,
        reconstruct_stencil_ui: SpatialReconstruction,
        levelset_type: str,
    ) -> None:
        """Stores the stencils, material manager and domain data used by the
        source term computations.

        :param material_manager: Provides thermal conductivity, dynamic
            viscosity and bulk viscosity.
        :type material_manager: MaterialManager
        :param gravity: Gravity vector; it is reshaped to (3, 1) and (1, 3)
            in compute_gravity_forces, so it must hold three entries.
        :type gravity: jnp.ndarray
        :param domain_information: Source of the active axes, number of
            cells, cell sizes and the slices selecting the domain cells of
            the conservatives buffer.
        :type domain_information: DomainInformation
        :param derivative_stencil_center: Stencil used for derivatives that
            are subsequently reconstructed at a cell face.
        :type derivative_stencil_center: SpatialDerivative
        :param derivative_stencil_face: Stencil used for derivatives that are
            evaluated directly at a cell face.
        :type derivative_stencil_face: SpatialDerivative
        :param reconstruct_stencil_duidxi: Stencil that reconstructs
            derivatives at a cell face.
        :type reconstruct_stencil_duidxi: SpatialReconstruction
        :param reconstruct_stencil_ui: Stencil that reconstructs temperature
            and velocities at a cell face.
        :type reconstruct_stencil_ui: SpatialReconstruction
        :param levelset_type: Level-set model identifier. For "FLUID-FLUID"
            the flux buffers carry an additional dimension of size 2.
        :type levelset_type: str
        """
        self.gravity = gravity
        self.material_manager = material_manager

        self.derivative_stencil_center = derivative_stencil_center
        self.derivative_stencil_face = derivative_stencil_face
        self.reconstruct_stencil_duidxi = reconstruct_stencil_duidxi
        self.reconstruct_stencil_ui = reconstruct_stencil_ui

        # Translate the active axis names into integer axis indices.
        self.active_axis_indices = [
            {"x": 0, "y": 1, "z": 2}[axis]
            for axis in domain_information.active_axis
        ]

        # Shapes of the zero buffers that stand in for velocity derivatives
        # along inactive axes; one shape per cell face direction (the face
        # direction has one more entry than the number of cells).
        nx, ny, nz = domain_information.number_of_cells
        if levelset_type == "FLUID-FLUID":
            self.shape_fluxes = [
                (3, 2, nx+1, ny, nz),
                (3, 2, nx, ny+1, nz),
                (3, 2, nx, ny, nz+1),
            ]
        else:
            self.shape_fluxes = [
                (3, nx+1, ny, nz),
                (3, nx, ny+1, nz),
                (3, nx, ny, nz+1),
            ]

        self.cell_sizes = domain_information.cell_sizes
        # Index objects applied to the last three axes of the conservatives.
        self.nhx, self.nhy, self.nhz = domain_information.domain_slices_conservatives

    def compute_gravity_forces(self, cons: jnp.ndarray) -> jnp.ndarray:
        """Computes flux due to gravitational force.

        :param cons: Buffer of conservative variables.
        :type cons: jnp.ndarray
        :return: Buffer with gravitational forces.
        :rtype: jnp.ndarray
        """
        density = cons[0:1, ..., self.nhx, self.nhy, self.nhz]
        momentum = cons[1:4, ..., self.nhx, self.nhy, self.nhz]

        # Momentum contribution: gravity_i * density (three components).
        gravity_momentum = jnp.einsum(
            "ij..., jk...->ik...", self.gravity.reshape(3,1), density)
        # Energy contribution: gravity . momentum (one component).
        gravity_energy = jnp.einsum(
            "ij..., jk...->ik...", self.gravity.reshape(1,3), momentum)

        # The leading (mass) entry receives no gravitational contribution.
        gravity_forces = jnp.vstack([
            jnp.zeros(gravity_energy.shape),
            gravity_momentum,
            gravity_energy,
        ])
        return gravity_forces

    def compute_heat_flux_xi(self, temperature: jnp.ndarray, axis: int) -> jnp.ndarray:
        """Computes the heat flux in axis direction.

        :param temperature: Buffer with temperature.
        :type temperature: jnp.ndarray
        :param axis: Spatial direction along which the heat flux is calculated.
        :type axis: int
        :return: Heat flux in axis direction. Four zero blocks followed by
            the heat flux, concatenated along the leading axis.
        :rtype: jnp.ndarray
        """
        temperature_at_xj = self.reconstruct_stencil_ui.reconstruct_xi(temperature, axis)
        thermal_conductivity = self.material_manager.get_thermal_conductivity(temperature_at_xj)
        temperature_grad = self.derivative_stencil_face.derivative_xi(
            temperature, self.cell_sizes[axis], axis)

        flux_xi = -thermal_conductivity*temperature_grad

        # Only the last block is non-zero; one zero buffer is allocated and
        # reused for the other four.
        zero_flux = jnp.zeros(flux_xi.shape)
        heat_flux_xi = jnp.vstack(
            [zero_flux, zero_flux, zero_flux, zero_flux, flux_xi])
        return heat_flux_xi

    def compute_viscous_flux_xi(self, vels: jnp.ndarray, temperature: jnp.ndarray, axis: int) -> jnp.ndarray:
        """Computes viscous flux in one spatial direction

        vel_grad = [
            du/dx du/dy du/dz
            dv/dx dv/dy dv/dz
            dw/dx dw/dy dw/dz
        ]

        :param vels: Buffer of velocities.
        :type vels: jnp.ndarray
        :param temperature: Buffer of temperature.
        :type temperature: jnp.ndarray
        :param axis: Axis along which the viscous flux is computed.
        :type axis: int
        :return: Viscous flux along axis direction.
        :rtype: jnp.ndarray
        """
        temperature_at_cf = self.reconstruct_stencil_ui.reconstruct_xi(temperature, axis)
        dynamic_viscosity = self.material_manager.get_dynamic_viscosity(temperature_at_cf[0])
        bulk_viscosity = self.material_manager.bulk_viscosity

        # Derivatives along inactive axes are replaced by zero buffers of the
        # face-centered shape for this axis.
        vel_grad = jnp.stack([
            self.compute_xi_derivatives_at_xj(vels, i, axis)
            if i in self.active_axis_indices
            else jnp.zeros(self.shape_fluxes[axis])
            for i in range(3)
        ], axis=1)

        tau_j = self.compute_tau(vel_grad, dynamic_viscosity, bulk_viscosity, axis)

        # Energy flux: stress times face velocity, summed over the active
        # axes.
        vel_at_xj = self.reconstruct_stencil_ui.reconstruct_xi(vels, axis)
        vel_tau_at_xj = 0
        for k in self.active_axis_indices:
            vel_tau_at_xj += tau_j[k] * vel_at_xj[k]

        # Layout: zero, tau_j[0], tau_j[1], tau_j[2], energy flux.
        fluxes_xj = jnp.stack([
            jnp.zeros(tau_j[0].shape),
            tau_j[0],
            tau_j[1],
            tau_j[2],
            vel_tau_at_xj,
        ])
        return fluxes_xj

    def compute_tau(self, vel_grad: jnp.ndarray, dynamic_viscosity: jnp.ndarray, bulk_viscosity: jnp.ndarray, axis: int) -> jnp.ndarray:
        """Computes the stress tensor at a cell face in axis direction.
        tau_axis = [tau_axis0, tau_axis1, tau_axis2]

        :param vel_grad: Buffer of velocity gradient. Shape is 3 x 3 (x 2) x Nx x Ny x Nz
        :type vel_grad: jnp.ndarray
        :param dynamic_viscosity: Buffer of dynamic viscosity.
        :type dynamic_viscosity: jnp.ndarray
        :param bulk_viscosity: Buffer of bulk viscosity.
        :type bulk_viscosity: jnp.ndarray
        :param axis: Cell face direction at which viscous stresses are calculated.
        :type axis: int
        :return: Buffer of viscous stresses.
        :rtype: jnp.ndarray
        """
        mu_1 = dynamic_viscosity
        mu_2 = bulk_viscosity - 2.0 / 3.0 * dynamic_viscosity

        axis_is_active = axis in self.active_axis_indices
        # Shared placeholder for components involving an inactive axis. The
        # in-place add below rebinds the list entry (JAX arrays are
        # immutable), so sharing one buffer between entries is safe.
        zero_stress = jnp.zeros(vel_grad.shape[2:])

        # Symmetric part: mu_1 * (d u_axis / d x_i + d u_i / d x_axis).
        tau_list = [
            mu_1 * (vel_grad[axis, i] + vel_grad[i, axis])
            if axis_is_active and i in self.active_axis_indices
            else zero_stress
            for i in range(3)
        ]

        # Diagonal component additionally receives mu_2 times the velocity
        # divergence, summed over the active axes only.
        tau_list[axis] += mu_2 * sum(
            vel_grad[k, k] for k in self.active_axis_indices)
        return jnp.stack(tau_list)

    def compute_xi_derivatives_at_xj(self, prime: jnp.ndarray, axis_i: int, axis_j: int) -> jnp.ndarray:
        """Computes the spatial derivative in axis_i direction at the cell face in axis_j direction.

        :param prime: Buffer of primitive variables.
        :type prime: jnp.ndarray
        :param axis_i: Spatial direction wrt which the derivative is taken.
        :type axis_i: int
        :param axis_j: Spatial direction along which derivative is evaluated.
        :type axis_j: int
        :return: Derivative wrt axis_i direction at cell face in axis_j direction.
        :rtype: jnp.ndarray
        """
        if axis_i == axis_j:
            # Direct derivative evaluation at cell face
            deriv_xi_at_xj = self.derivative_stencil_face.derivative_xi(
                prime, self.cell_sizes[axis_i], axis_i)
        else:
            # 1) Derivative at cell center 2) Interpolate to cell face
            deriv_xi_at_c = self.derivative_stencil_center.derivative_xi(
                prime, self.cell_sizes[axis_i], axis_i)
            deriv_xi_at_xj = self.reconstruct_stencil_duidxi.reconstruct_xi(
                deriv_xi_at_c, axis_j)
        return deriv_xi_at_xj