# Source code for jaxfluids.time_integration.time_integrator

#*------------------------------------------------------------------------------*
#* JAX-FLUIDS -                                                                 *
#*                                                                              *
#* A fully-differentiable CFD solver for compressible two-phase flows.          *
#* Copyright (C) 2022  Deniz A. Bezgin, Aaron B. Buhendwa, Nikolaus A. Adams    *
#*                                                                              *
#* This program is free software: you can redistribute it and/or modify         *
#* it under the terms of the GNU General Public License as published by         *
#* the Free Software Foundation, either version 3 of the License, or            *
#* (at your option) any later version.                                          *
#*                                                                              *
#* This program is distributed in the hope that it will be useful,              *
#* but WITHOUT ANY WARRANTY; without even the implied warranty of               *
#* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the                *
#* GNU General Public License for more details.                                 *
#*                                                                              *
#* You should have received a copy of the GNU General Public License            *
#* along with this program.  If not, see <https://www.gnu.org/licenses/>.       *
#*                                                                              *
#*------------------------------------------------------------------------------*
#*                                                                              *
#* CONTACT                                                                      *
#*                                                                              *
#* deniz.bezgin@tum.de // aaron.buhendwa@tum.de // nikolaus.adams@tum.de        *
#*                                                                              *
#*------------------------------------------------------------------------------*
#*                                                                              *
#* Munich, April 15th, 2022                                                     *
#*                                                                              *
#*------------------------------------------------------------------------------*

from abc import ABC, abstractmethod
from functools import partial
from typing import List

import jax
import jax.numpy as jnp

class TimeIntegrator(ABC):
    """Abstract base class for explicit time integration schemes.

    All time integration schemes are derived from TimeIntegrator. Child
    classes implement :meth:`integrate`; multi-stage Runge-Kutta schemes
    additionally implement :meth:`prepare_buffer_for_integration`.
    """

    def __init__(self, nh: int, inactive_axis: List) -> None:
        """Builds the index slices that select the interior (non-halo) cells.

        :param nh: Number of halo cells stripped from each end of an active axis.
        :type nh: int
        :param inactive_axis: Axis names ("x", "y", "z") that are inactive.
            Inactive axes are sliced in full, i.e. no halo cells are stripped.
        :type inactive_axis: List
        """
        # Per-scheme metadata; the base class leaves these unset/empty.
        # NOTE(review): presumably filled in by concrete schemes (number of RK
        # stages and per-stage timestep factors) — confirm in child classes.
        self.no_stages = None
        self.timestep_multiplier = ()
        self.timestep_increment_factor = ()

        def interior(axis: str) -> slice:
            # ``nh:-nh`` strips nh cells from both ends of the axis. With
            # nh == 0 it would evaluate to ``0:0`` (an empty slice), so use the
            # full axis instead: without halo cells, every cell is interior.
            if axis in inactive_axis or nh == 0:
                return jnp.s_[:]
            return jnp.s_[nh:-nh]

        # Interior slices along x, y and z, used to index the last three
        # axes of the conservative buffer.
        self.nhx = interior("x")
        self.nhy = interior("y")
        self.nhz = interior("z")

    def integrate_conservatives(self, cons: jnp.ndarray, rhs: jnp.ndarray,
                                timestep: float) -> jnp.ndarray:
        """Integrates the conservative variables with an explicit Euler update.

        Computes ``cons[..., interior] += timestep * rhs``. Halo cells are left
        untouched. ``rhs`` must therefore be broadcast-compatible with the
        interior region of ``cons``.

        :param cons: conservative variables buffer before integration
        :type cons: jnp.ndarray
        :param rhs: right-hand side buffer (interior cells only)
        :type rhs: jnp.ndarray
        :param timestep: timestep adjusted according to sub-stage in Runge-Kutta
        :type timestep: float
        :return: conservative variables buffer after integration (a new array;
            ``.at[].add`` does not modify ``cons`` in place)
        :rtype: jnp.ndarray
        """
        interior = (..., self.nhx, self.nhy, self.nhz)
        return cons.at[interior].add(timestep * rhs)

    @abstractmethod
    def integrate(self, cons: jnp.ndarray, rhs: jnp.ndarray, timestep: float,
                  stage: int) -> jnp.ndarray:
        """Wrapper function around integrate_conservatives.

        Adjusts the timestep according to the current RK stage and calls
        integrate_conservatives. Implementation in child class.

        :param cons: conservative variables buffer before integration
        :type cons: jnp.ndarray
        :param rhs: right-hand side buffer
        :type rhs: jnp.ndarray
        :param timestep: timestep to be integrated
        :type timestep: float
        :param stage: current stage of the RK time integrator
        :type stage: int
        :return: conservative variables buffer after integration
        :rtype: jnp.ndarray
        """

    def prepare_buffer_for_integration(self, cons: jnp.ndarray,
                                       init: jnp.ndarray,
                                       stage: int) -> jnp.ndarray:
        """In multi-stage Runge-Kutta methods, prepares the buffer for integration.

        Implementation in child class. The base implementation is a no-op that
        returns None. It is deliberately not abstract, so single-stage schemes
        do not have to override it.

        :param cons: Buffer of conservative variables.
        :type cons: jnp.ndarray
        :param init: Initial conservative buffer.
        :type init: jnp.ndarray
        :param stage: Current stage of the RK time integrator.
        :type stage: int
        :return: Sum of initial buffer and current buffer.
        :rtype: jnp.ndarray
        """