# Source code for jaxfluids.io_utils.output_writer

#*------------------------------------------------------------------------------*
#* JAX-FLUIDS -                                                                 *
#*                                                                              *
#* A fully-differentiable CFD solver for compressible two-phase flows.          *
#* Copyright (C) 2022  Deniz A. Bezgin, Aaron B. Buhendwa, Nikolaus A. Adams    *
#*                                                                              *
#* This program is free software: you can redistribute it and/or modify         *
#* it under the terms of the GNU General Public License as published by         *
#* the Free Software Foundation, either version 3 of the License, or            *
#* (at your option) any later version.                                          *
#*                                                                              *
#* This program is distributed in the hope that it will be useful,              *
#* but WITHOUT ANY WARRANTY; without even the implied warranty of               *
#* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the                *
#* GNU General Public License for more details.                                 *
#*                                                                              *
#* You should have received a copy of the GNU General Public License            *
#* along with this program.  If not, see <https://www.gnu.org/licenses/>.       *
#*                                                                              *
#*------------------------------------------------------------------------------*
#*                                                                              *
#* CONTACT                                                                      *
#*                                                                              *
#* deniz.bezgin@tum.de // aaron.buhendwa@tum.de // nikolaus.adams@tum.de        *
#*                                                                              *
#*------------------------------------------------------------------------------*
#*                                                                              *
#* Munich, April 15th, 2022                                                     *
#*                                                                              *
#*------------------------------------------------------------------------------*

import os
import json
from typing import Dict, Tuple, Union
from functools import partial

import h5py
import jax
import jax.numpy as jnp

from jaxfluids.domain_information import DomainInformation
from jaxfluids.levelset.levelset_handler import LevelsetHandler
from jaxfluids.materials.material_manager import MaterialManager
from jaxfluids.unit_handler import UnitHandler
from jaxfluids.input_reader import InputReader
from jaxfluids.stencils.spatial_derivative import SpatialDerivative

class OutputWriter:
    """Output writer for JAX-FLUIDS.

    The OutputWriter class can write h5 and xdmf files. h5 and xdmf files can
    be visualized in ParaView. Xdmf output is activated via the is_xdmf flag
    under the output keyword in the numerical setup.

    If the xdmf option is activated, a single xdmf file is written for each h5
    file. Additionally, at the end of the simulation, a time series xdmf file
    is written which summarizes the whole simulation. This enables loading the
    entire time series into ParaView.
    """

    def __init__(
            self,
            input_reader: InputReader,
            unit_handler: UnitHandler,
            domain_information: DomainInformation,
            material_manager: MaterialManager,
            levelset_handler: LevelsetHandler,
            derivative_stencil_conservatives: SpatialDerivative,
            derivative_stencil_geometry: Union[SpatialDerivative, None]
            ) -> None:
        """Stores references to the simulation setup and creates the output folder names.

        :param input_reader: Reader holding case and numerical setup.
        :param unit_handler: Converts between dimensional and non-dimensional units.
        :param domain_information: Mesh and domain slicing information.
        :param material_manager: Provides thermodynamic relations (temperature, speed of sound).
        :param levelset_handler: Levelset utilities, only used for levelset simulations.
        :param derivative_stencil_conservatives: Stencil for derivatives on conservative-sized buffers.
        :param derivative_stencil_geometry: Stencil for derivatives on geometry-sized buffers, or None.
        """
        # GENERAL
        self.case_name = input_reader.case_name
        self.next_timestamp = unit_handler.non_dimensionalize(input_reader.save_dt, "time")
        self.save_dt = unit_handler.non_dimensionalize(input_reader.save_dt, "time")
        self.save_path = input_reader.save_path

        # JSON
        self.case_setup = input_reader.case_setup
        self.numerical_setup = input_reader.numerical_setup

        # NUMERICAL SETUP
        self.is_xdmf = self.numerical_setup["output"]["is_xdmf"]
        self.is_double = self.numerical_setup["output"]["is_double_precision_output"]
        self.output_quantities = self.numerical_setup["output"]["quantities"]
        self.is_mass_flow_forcing = self.numerical_setup["active_forcings"]["is_mass_flow_forcing"]
        self.levelset_type = input_reader.levelset_type
        self.derivative_stencil_conservatives = derivative_stencil_conservatives
        self.derivative_stencil_geometry = derivative_stencil_geometry

        # MATERIAL, UNITHANDLER
        self.material_manager = material_manager
        self.unit_handler = unit_handler
        self.levelset_handler = levelset_handler

        # DOMAIN INFORMATION
        self.number_of_cells = domain_information.number_of_cells
        self.cell_centers = domain_information.cell_centers
        self.cell_faces = domain_information.cell_faces
        self.cell_sizes = domain_information.cell_sizes
        self.nhx, self.nhy, self.nhz = domain_information.domain_slices_conservatives
        self.nhx_, self.nhy_, self.nhz_ = domain_information.domain_slices_geometry
        self.nhx__, self.nhy__, self.nhz__ = domain_information.domain_slices_conservatives_to_geometry
        self.active_axis_indices = domain_information.active_axis_indices

        # QUANTITIES - INDEX OF EACH OUTPUT QUANTITY IN THE MATERIAL FIELD BUFFERS
        # (temperature is appended to the primitives in write_h5file)
        self.quantity_index = {
            "primes": {"density": 0, "velocityX": 1, "velocityY": 2, "velocityZ": 3,
                       "velocity": jnp.s_[1:4], "pressure": 4, "temperature": 5},
            "cons": {"mass": 0, "momentumX": 1, "momentumY": 2, "momentumZ": 3,
                     "momentum": jnp.s_[1:4], "energy": 4}
        }

        # TIMES (NON-DIMENSIONAL) AND GRID XML SNIPPETS FOR THE TIME SERIES XDMF FILE
        self.output_timeseries = []
        self.xdmf_timeseries = []

        self.save_path_case, self.save_path_domain = self.get_folder_name()

    def create_folder(self) -> None:
        """Sets up a folder for the simulation.

        Dumps the numerical setup and case setup into the simulation folder
        and creates an output folder within the simulation folder into which
        simulation output is saved.

        simulation_folder
        ---- Numerical setup
        ---- Case setup
        ---- domain
        """
        os.mkdir(self.save_path_case)
        os.mkdir(self.save_path_domain)
        # ensure_ascii=False may emit non-ASCII characters, so the file
        # encoding must not depend on the platform locale.
        with open(os.path.join(self.save_path_case, self.case_name + ".json"), "w", encoding="utf-8") as json_file:
            json.dump(self.case_setup, json_file, ensure_ascii=False, indent=4)
        with open(os.path.join(self.save_path_case, "numerical_setup.json"), "w", encoding="utf-8") as json_file:
            json.dump(self.numerical_setup, json_file, ensure_ascii=False, indent=4)

    def get_folder_name(self) -> Tuple[str, str]:
        """Returns a name for the simulation folder based on the case name.

        If a folder with the case name already exists in the save path, the
        suffixes -1, -2, ... are tried until an unused name is found. The save
        path itself (including missing parent directories) is created if it
        does not exist yet.

        :return: Path to the simulation folder and path to domain folder within simulation folder.
        :rtype: Tuple[str, str]
        """
        os.makedirs(self.save_path, exist_ok=True)

        case_name_folder = self.case_name
        i = 1
        while os.path.exists(os.path.join(self.save_path, case_name_folder)):
            case_name_folder = self.case_name + "-%d" % i
            i += 1

        save_path_case = os.path.join(self.save_path, case_name_folder)
        save_path_domain = os.path.join(save_path_case, "domain")
        return save_path_case, save_path_domain

    def write_output(self, buffer_dictionary: Dict[str, Dict[str, Union[jnp.ndarray, float]]],
                     force_output: bool, simulation_finish: bool = False) -> None:
        """Writes h5 and (optional) xdmf output.

        Output is written when forced or when the current time has reached the
        next output time stamp (up to machine epsilon). In the latter case the
        next time stamp is advanced by save_dt.

        :param buffer_dictionary: Dictionary with flow field buffers
        :type buffer_dictionary: Dict[str, Dict[str, Union[jnp.ndarray, float]]]
        :param force_output: Flag which forces an output.
        :type force_output: bool
        :param simulation_finish: Flag that indicates the simulation finish -> then timeseries xdmf is written, defaults to False
        :type simulation_finish: bool, optional
        """
        current_time = buffer_dictionary["time_control"]["current_time"]

        if force_output:
            self.write_h5file(buffer_dictionary)
            if self.is_xdmf:
                self.write_xdmffile(current_time)
        else:
            diff = current_time - self.next_timestamp
            if diff >= -jnp.finfo(jnp.float64).eps:
                self.write_h5file(buffer_dictionary)
                if self.is_xdmf:
                    self.write_xdmffile(current_time)
                self.next_timestamp += self.save_dt

        if simulation_finish and self.is_xdmf:
            self.write_timeseries_xdmffile()

    def write_h5file(self, buffer_dictionary: Dict[str, Dict[str, jnp.ndarray]]) -> None:
        """Writes the h5 file for the current time step.

        The file is named after the dimensional simulation time and contains
        the mesh, the current time and every output quantity requested in the
        numerical setup. Material field buffers are cut to the domain (halo
        cells removed), dimensionalized and transposed before writing.

        :param buffer_dictionary: Dictionary with flow field buffers
        :type buffer_dictionary: Dict[str, Dict[str, jnp.ndarray]]
        """
        current_time = buffer_dictionary["time_control"]["current_time"]
        filename = "data_%.6f.h5" % self.unit_handler.dimensionalize(current_time, "time")
        dtype = "f8" if self.is_double else "f4"

        with h5py.File(os.path.join(self.save_path_domain, filename), "w") as h5file:

            # MESH DATA - ALWAYS DOUBLE PRECISION
            h5file.create_group(name="mesh")
            for prefix, mesh_arrays in (("grid", self.cell_centers),
                                        ("gridF", self.cell_faces),
                                        ("cellsize", self.cell_sizes)):
                for axis, axis_name in enumerate("XYZ"):
                    h5file.create_dataset(
                        name="mesh/%s%s" % (prefix, axis_name),
                        data=self.unit_handler.dimensionalize(mesh_arrays[axis], "length"),
                        dtype="f8")

            # CURRENT TIME
            h5file.create_dataset(name="time", data=self.unit_handler.dimensionalize(current_time, "time"), dtype="f8")

            # COMPUTE TEMPERATURE AND APPEND IT AS AN ADDITIONAL PRIMITIVE
            primes = buffer_dictionary["material_fields"]["primes"]
            temperature = jnp.expand_dims(self.material_manager.get_temperature(primes[4], primes[0]), axis=0)
            material_fields = {
                "primes": jnp.vstack([primes, temperature]),
                "cons": buffer_dictionary["material_fields"]["cons"]
            }

            # CONSERVATIVES AND PRIMITIVES
            for key in ["primes", "cons"]:
                if key not in self.output_quantities:
                    continue
                h5file.create_group(name=key)
                for quantity in self.output_quantities[key]:
                    index = self.quantity_index[key][quantity]
                    if self.levelset_type == "FLUID-FLUID":
                        # ONE DATASET PER PHASE, NAMED <quantity>_<phase index>
                        for i in range(2):
                            quantity_name = "%s_%d" % (quantity, i)
                            buffer = self.unit_handler.dimensionalize(
                                material_fields[key][index, i, self.nhx, self.nhy, self.nhz], quantity)
                            h5file.create_dataset(name="/".join([key, quantity_name]), data=buffer.T, dtype=dtype)
                    else:
                        buffer = self.unit_handler.dimensionalize(
                            material_fields[key][index, self.nhx, self.nhy, self.nhz], quantity)
                        h5file.create_dataset(name="/".join([key, quantity]), data=buffer.T, dtype=dtype)

            if self.levelset_type is not None:

                # LEVELSET QUANTITIES
                if "levelset" in self.output_quantities:
                    h5file.create_group(name="levelset")
                    levelset_quantities = {}
                    levelset = buffer_dictionary["levelset_quantities"]["levelset"]
                    volume_fraction = buffer_dictionary["levelset_quantities"]["volume_fraction"]
                    normal = self.levelset_handler.geometry_calculator.compute_normal(levelset)
                    mask_real, _ = self.levelset_handler.compute_masks(levelset, volume_fraction)
                    if self.levelset_type == "FLUID-FLUID":
                        interface_velocity, interface_pressure, _ = self.levelset_handler.compute_interface_quantities(
                            material_fields["primes"], levelset, volume_fraction)
                        levelset_quantities["interface_velocity"] = self.unit_handler.dimensionalize(interface_velocity, "velocity")
                        levelset_quantities["interface_pressure"] = self.unit_handler.dimensionalize(interface_pressure, "pressure")
                        mask_real = mask_real[0]
                    elif self.levelset_type == "FLUID-SOLID-DYNAMIC":
                        interface_velocity = self.levelset_handler.compute_solid_interface_velocity(current_time)
                        levelset_quantities["interface_velocity"] = self.unit_handler.dimensionalize(interface_velocity, "velocity")
                    levelset_quantities.update({
                        "levelset": self.unit_handler.dimensionalize(levelset[self.nhx, self.nhy, self.nhz], "length"),
                        "volume_fraction": volume_fraction[..., self.nhx_, self.nhy_, self.nhz_],
                        "mask_real": mask_real[..., self.nhx_, self.nhy_, self.nhz_],
                        "normal": normal[..., self.nhx_, self.nhy_, self.nhz_]
                    })
                    for quantity in self.output_quantities["levelset"]:
                        h5file.create_dataset(name="levelset/" + quantity, data=levelset_quantities[quantity].T, dtype=dtype)

                # CONSERVATIVES AND PRIMITIVES FOR REAL FLUID
                if "real_fluid" in self.output_quantities:
                    h5file.create_group(name="real_fluid")
                    volume_fraction_domain = buffer_dictionary["levelset_quantities"]["volume_fraction"][self.nhx_, self.nhy_, self.nhz_]
                    for key in ["cons", "primes"]:
                        real_buffer = self.compute_real_buffer(
                            material_fields[key][..., self.nhx, self.nhy, self.nhz], volume_fraction_domain)
                        for quantity in self.output_quantities["real_fluid"]:
                            if quantity not in self.quantity_index[key]:
                                continue
                            real_state = self.unit_handler.dimensionalize(real_buffer[self.quantity_index[key][quantity]], quantity)
                            h5file.create_dataset(name="real_fluid/" + quantity, data=real_state.T, dtype=dtype)

            # MISCELLANEOUS - ALWAYS COMPUTED FOR REAL FLUID
            # NOTE(review): these fields are written without dimensionalization.
            if "miscellaneous" in self.output_quantities:
                h5file.create_group(name="miscellaneous")
                misc_volume_fraction = buffer_dictionary["levelset_quantities"]["volume_fraction"] \
                    if self.levelset_type is not None else None
                for quantity in self.output_quantities["miscellaneous"]:
                    computed_quantity = self.compute_miscellaneous(material_fields["primes"], quantity, misc_volume_fraction)
                    h5file.create_dataset(name="miscellaneous/" + quantity, data=computed_quantity.T, dtype=dtype)

            # MASS FLOW FORCING
            if self.is_mass_flow_forcing:
                h5file.create_group(name="mass_flow_forcing")
                for name in ["scalar_value", "PID_e_int", "PID_e_new"]:
                    h5file.create_dataset(name="mass_flow_forcing/" + name,
                                          data=buffer_dictionary["mass_flow_forcing"][name], dtype=dtype)

    def write_xdmffile(self, current_time: float) -> None:
        """Writes an xdmf file for the current time step.

        The xdmf file corresponds to an h5 file which holds the data buffers.
        The time stored in the xdmf file is the dimensional time, consistent
        with the h5 "time" dataset and the file name.

        :param current_time: Current (non-dimensional) simulation time.
        :type current_time: float
        """
        physical_time = self.unit_handler.dimensionalize(current_time, "time")
        filename = "data_%.6f" % physical_time
        h5file_name = filename + ".h5"
        xdmffile_path = os.path.join(self.save_path_domain, filename + ".xdmf")

        # XDMF START
        xdmf_preamble = '''<?xml version="1.0" ?>
<!DOCTYPE Xdmf SYSTEM "Xdmf.dtd" []>
<Xdmf Version="3.0">
<Domain>
<Grid Name="TimeStep" GridType="Collection" CollectionType="Temporal">'''

        # The mesh is always stored in double precision (see write_h5file).
        # XDMF dimensions are ordered slowest to fastest varying (Z Y X),
        # matching the transposed attribute data in get_xdmf.
        xdmf_str_start = '''<Grid Name="SpatialData_%e" GridType="Uniform">
<Time TimeType="Single" Value="%e" />
<Geometry Type="VXVYVZ">
<DataItem Format="HDF" NumberType="Float" Precision="8" Dimensions="%i">%s:mesh/gridFX</DataItem>
<DataItem Format="HDF" NumberType="Float" Precision="8" Dimensions="%i">%s:mesh/gridFY</DataItem>
<DataItem Format="HDF" NumberType="Float" Precision="8" Dimensions="%i">%s:mesh/gridFZ</DataItem>
</Geometry>
<Topology Dimensions="%i %i %i" Type="3DRectMesh"/>''' % (
            physical_time, physical_time,
            len(self.cell_faces[0]), h5file_name,
            len(self.cell_faces[1]), h5file_name,
            len(self.cell_faces[2]), h5file_name,
            len(self.cell_faces[2]), len(self.cell_faces[1]), len(self.cell_faces[0]))

        # XDMF QUANTITIES
        xdmf_quants = []

        # CONSERVATIVES AND PRIMITIVES - ONE ENTRY PER PHASE FOR TWO-PHASE FLOWS
        is_two_phase = self.levelset_type == "FLUID-FLUID"
        for key in ["cons", "primes"]:
            if key not in self.output_quantities:
                continue
            for quantity in self.output_quantities[key]:
                for i in range(2 if is_two_phase else 1):
                    quantity_name = "%s_%d" % (quantity, i) if is_two_phase else quantity
                    xdmf_quants.append(self.get_xdmf(key, quantity_name, h5file_name, *self.number_of_cells))

        # REAL FLUID AND MISCELLANEOUS
        for key in ["real_fluid", "miscellaneous"]:
            if key not in self.output_quantities:
                continue
            for quantity in self.output_quantities[key]:
                xdmf_quants.append(self.get_xdmf(key, quantity, h5file_name, *self.number_of_cells))

        xdmf_str_end = '''</Grid>'''

        # XDMF END
        xdmf_postamble = '''</Grid>
</Domain>
</Xdmf>'''

        # APPEND XDMF SPATIAL TO TIMESERIES (ONCE PER TIME)
        if current_time not in self.output_timeseries:
            self.output_timeseries.append(current_time)
            self.xdmf_timeseries.append("\n".join([xdmf_str_start] + xdmf_quants + [xdmf_str_end]))

        # JOIN FINAL XDMF STR AND WRITE TO FILE
        xdmf_str = "\n".join([xdmf_preamble, xdmf_str_start] + xdmf_quants + [xdmf_str_end, xdmf_postamble])
        with open(xdmffile_path, "w") as xdmf_file:
            xdmf_file.write(xdmf_str)

    def write_timeseries_xdmffile(self) -> None:
        """Write xdmffile for the complete time series so that visualization
        tools like ParaView can load the complete time series at once.

        This is done once at the end of a simulation when every output time
        stamp is known.
        """
        xdmffile_path = os.path.join(self.save_path_domain, "data_time_series.xdmf")

        # XDMF START
        xdmf_str_start = '''<?xml version="1.0" ?>
<!DOCTYPE Xdmf SYSTEM "Xdmf.dtd" []>
<Xdmf Version="3.0">
<Domain>
<Grid Name="TimeSeries" GridType="Collection" CollectionType="Temporal">'''

        # XDMF END
        xdmf_str_end = '''</Grid>
</Domain>
</Xdmf>'''

        # JOIN FINAL XDMF STR AND WRITE TO FILE
        xdmf_str = "\n".join([xdmf_str_start] + self.xdmf_timeseries + [xdmf_str_end])
        with open(xdmffile_path, "w") as xdmf_file:
            xdmf_file.write(xdmf_str)

    def get_xdmf(self, group: str, quant: str, h5file_name: str, Nx: int, Ny: int, Nz: int) -> str:
        """Returns the string for the xdmf file for the given output quantity.

        Velocity, momentum and vorticity are declared as 3-component vectors,
        including their phase-suffixed variants (e.g. "velocity_0") written
        for two-phase simulations. All other quantities are scalars.

        :param group: Group name in h5 file under which the quantity is stored.
        :type group: str
        :param quant: Name of the output quantity.
        :type quant: str
        :param h5file_name: Name of the corresponding h5 file.
        :type h5file_name: str
        :param Nx: Resolution in x direction.
        :type Nx: int
        :param Ny: Resolution in y direction.
        :type Ny: int
        :param Nz: Resolution in z direction.
        :type Nz: int
        :return: Xdmf string for the specified quantity.
        :rtype: str
        """
        # Strip a trailing numeric phase suffix ("velocity_1" -> "velocity").
        base_name, _, suffix = quant.rpartition("_")
        if not (base_name and suffix.isdigit()):
            base_name = quant
        precision = 8 if self.is_double else 4

        if base_name in ["velocity", "momentum", "vorticity"]:
            xdmf = '''<Attribute Name="%s" AttributeType="Vector" Center="Cell">
<DataItem Format="HDF" NumberType="Float" Precision="%i" Dimensions="%i %i %i %i">%s:%s/%s</DataItem>
</Attribute>''' % (quant, precision, Nz, Ny, Nx, 3, h5file_name, group, quant)
        else:
            xdmf = '''<Attribute Name="%s" AttributeType="Scalar" Center="Cell">
<DataItem Format="HDF" NumberType="Float" Precision="%i" Dimensions="%i %i %i">%s:%s/%s</DataItem>
</Attribute>''' % (quant, precision, Nz, Ny, Nx, h5file_name, group, quant)
        return xdmf

    @partial(jax.jit, static_argnums=(0,))
    def compute_real_buffer(self, buffer: jnp.ndarray, volume_fraction: jnp.ndarray) -> jnp.ndarray:
        """For two-phase simulations, merges the two separate phase buffers
        into a single real buffer.

        Calculation is done as an arithmetic average weighted by the volume
        fraction: phase 0 is weighted with alpha, phase 1 with 1 - alpha.

        :param buffer: Data buffer with the phase axis in the fourth-to-last position.
        :type buffer: jnp.ndarray
        :param volume_fraction: Buffer of the volume fraction.
        :type volume_fraction: jnp.ndarray
        :return: Combined data buffer of the 'real' fluid.
        :rtype: jnp.ndarray
        """
        volume_fraction = jnp.stack([volume_fraction, 1.0 - volume_fraction], axis=0)
        conservatives_real = buffer[..., 0, :, :, :] * volume_fraction[0] + buffer[..., 1, :, :, :] * volume_fraction[1]
        return conservatives_real

    @partial(jax.jit, static_argnums=(0, 2))
    def compute_miscellaneous(self, primes: jnp.ndarray, quantity: str,
                              volume_fraction: Union[jnp.ndarray, None]) -> jnp.ndarray:
        """Compute miscellaneous output fields for h5 output.

        :param primes: Buffer of primitive variables.
        :type primes: jnp.ndarray
        :param quantity: String identifier of the quantity to be computed. One of
            "schlieren", "vorticity", "absolute_vorticity", "absolute_velocity", "mach_number".
        :type quantity: str
        :param volume_fraction: Buffer of the volume fraction field, only for two-phase simulations. Otherwise None.
        :type volume_fraction: Union[jnp.ndarray, None]
        :raises ValueError: If quantity is not a known miscellaneous quantity.
        :return: Computed physical output quantity.
        :rtype: jnp.ndarray
        """
        if self.levelset_type == "FLUID-FLUID":
            primes_real = self.compute_real_buffer(primes[..., self.nhx__, self.nhy__, self.nhz__], volume_fraction)
        else:
            primes_real = primes

        if quantity == "schlieren":
            computed_quantity = self.compute_schlieren(primes_real[0:1])
        elif quantity == "vorticity":
            computed_quantity = self.compute_vorticity(primes_real[1:4])
        elif quantity == "absolute_vorticity":
            computed_quantity = self.compute_absolute_vorticity(primes_real[1:4])
        elif quantity == "absolute_velocity":
            computed_quantity = self.compute_absolute_velocity(primes_real[1:4])
        elif quantity == "mach_number":
            # The Mach number is computed per phase and averaged afterwards,
            # hence the phase-wise primitives are passed.
            computed_quantity = self.compute_mach_number(primes, volume_fraction)
        else:
            raise ValueError("Unknown miscellaneous output quantity: %s" % quantity)
        return computed_quantity

    @partial(jax.jit, static_argnums=(0,))
    def compute_absolute_velocity(self, velocity: jnp.ndarray) -> jnp.ndarray:
        """Computes the absolute velocity field for h5 output.

        :param velocity: Buffer of velocities (components along axis 0).
        :type velocity: jnp.ndarray
        :return: Buffer of absolute velocity, cut to the domain.
        :rtype: jnp.ndarray
        """
        absolute_velocity = jnp.sqrt(jnp.sum(jnp.square(velocity), axis=0))
        if self.levelset_type == "FLUID-FLUID":
            absolute_velocity = absolute_velocity[..., self.nhx_, self.nhy_, self.nhz_]
        else:
            absolute_velocity = absolute_velocity[..., self.nhx, self.nhy, self.nhz]
        return absolute_velocity

    @partial(jax.jit, static_argnums=(0,))
    def compute_mach_number(self, primes: jnp.ndarray, volume_fraction: Union[jnp.ndarray, None]) -> jnp.ndarray:
        """Computes the Mach number field for h5 output.

        For two-phase simulations the Mach number of each phase is merged via
        compute_real_buffer.

        :param primes: Buffer of primitive variables.
        :type primes: jnp.ndarray
        :param volume_fraction: Buffer of volume fraction.
        :type volume_fraction: Union[jnp.ndarray, None]
        :return: Buffer of Mach number, cut to the domain.
        :rtype: jnp.ndarray
        """
        absolute_velocity = jnp.sqrt(jnp.sum(jnp.square(primes[1:4]), axis=0))
        speed_of_sound = self.material_manager.get_speed_of_sound(primes[4], primes[0])
        mach_number = absolute_velocity / speed_of_sound
        if self.levelset_type == "FLUID-FLUID":
            mach_number = self.compute_real_buffer(
                mach_number[..., self.nhx__, self.nhy__, self.nhz__], volume_fraction)[..., self.nhx_, self.nhy_, self.nhz_]
        else:
            mach_number = mach_number[..., self.nhx, self.nhy, self.nhz]
        return mach_number

    @partial(jax.jit, static_argnums=(0,))
    def compute_schlieren(self, density: jnp.ndarray) -> jnp.ndarray:
        """Computes numerical schlieren field (magnitude of the density
        gradient) for h5 output. Inactive axes contribute zero.

        :param density: Buffer of density.
        :type density: jnp.ndarray
        :return: Buffer of schlieren.
        :rtype: jnp.ndarray
        """
        schlieren = []
        for i in range(3):
            if self.levelset_type == "FLUID-FLUID":
                schlieren.append(
                    self.derivative_stencil_geometry.derivative_xi(density, self.cell_sizes[i], i)
                    if i in self.active_axis_indices
                    else jnp.zeros(density[..., self.nhx_, self.nhy_, self.nhz_].shape))
            else:
                schlieren.append(
                    self.derivative_stencil_conservatives.derivative_xi(density, self.cell_sizes[i], i)
                    if i in self.active_axis_indices
                    else jnp.zeros(density[..., self.nhx, self.nhy, self.nhz].shape))
        schlieren = jnp.linalg.norm(jnp.vstack(schlieren), axis=0, ord=2)
        return schlieren

    @partial(jax.jit, static_argnums=(0,))
    def compute_vorticity(self, velocity: jnp.ndarray) -> jnp.ndarray:
        """Computes vorticity field for h5 output.

        :param velocity: Buffer of velocities (components along axis 0).
        :type velocity: jnp.ndarray
        :return: Buffer of vorticity, components stacked along axis 0.
        :rtype: jnp.ndarray
        """
        if self.levelset_type == "FLUID-FLUID":
            velocity_grad = jnp.stack([
                self.derivative_stencil_geometry.derivative_xi(velocity, self.cell_sizes[k], k)
                if k in self.active_axis_indices
                else jnp.zeros(velocity[..., self.nhx_, self.nhy_, self.nhz_].shape)
                for k in range(3)], axis=1)
        else:
            velocity_grad = jnp.stack([
                self.derivative_stencil_conservatives.derivative_xi(velocity, self.cell_sizes[k], k)
                if k in self.active_axis_indices
                else jnp.zeros(velocity[..., self.nhx, self.nhy, self.nhz].shape)
                for k in range(3)], axis=1)

        # velocity_grad[i, j] = d u_i / d x_j
        du_dy, du_dz = velocity_grad[0, 1], velocity_grad[0, 2]
        dv_dx, dv_dz = velocity_grad[1, 0], velocity_grad[1, 2]
        dw_dx, dw_dy = velocity_grad[2, 0], velocity_grad[2, 1]

        vorticity = jnp.stack([
            dw_dy - dv_dz,
            du_dz - dw_dx,
            dv_dx - du_dy
        ], axis=0)
        return vorticity

    @partial(jax.jit, static_argnums=(0,))
    def compute_absolute_vorticity(self, velocity: jnp.ndarray) -> jnp.ndarray:
        """Computes absolute vorticity field for h5 output.

        :param velocity: Buffer of velocities.
        :type velocity: jnp.ndarray
        :return: Buffer of absolute vorticity.
        :rtype: jnp.ndarray
        """
        absolute_vorticity = jnp.linalg.norm(self.compute_vorticity(velocity), axis=0, ord=2)
        return absolute_vorticity