#*------------------------------------------------------------------------------*
#* JAX-FLUIDS - *
#* *
#* A fully-differentiable CFD solver for compressible two-phase flows. *
#* Copyright (C) 2022 Deniz A. Bezgin, Aaron B. Buhendwa, Nikolaus A. Adams *
#* *
#* This program is free software: you can redistribute it and/or modify *
#* it under the terms of the GNU General Public License as published by *
#* the Free Software Foundation, either version 3 of the License, or *
#* (at your option) any later version. *
#* *
#* This program is distributed in the hope that it will be useful, *
#* but WITHOUT ANY WARRANTY; without even the implied warranty of *
#* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
#* GNU General Public License for more details. *
#* *
#* You should have received a copy of the GNU General Public License *
#* along with this program. If not, see <https://www.gnu.org/licenses/>. *
#* *
#*------------------------------------------------------------------------------*
#* *
#* CONTACT *
#* *
#* deniz.bezgin@tum.de // aaron.buhendwa@tum.de // nikolaus.adams@tum.de *
#* *
#*------------------------------------------------------------------------------*
#* *
#* Munich, April 15th, 2022 *
#* *
#*------------------------------------------------------------------------------*
import os
from typing import List, Dict, Tuple
import h5py
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.animation import FuncAnimation
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
[docs]
def load_data(path: str, quantities: List, start: int = 0, end: int = None,
        N: int = 1) -> Tuple[List, List, List, Dict]:
    """Loads .h5 result files generated by a jaxfluids simulation.

    Files whose name ends with "h5" are collected from the folder (files
    containing "nan" in their name are skipped), sorted by the time encoded
    in the file name and sub-sampled with the slice start:end:N. The mesh
    and the single-phase/two-phase layout are read from the first selected file.

    :param path: Path to the folder containing the .h5 files for the time snapshots
    :type path: str
    :param quantities: Quantities to load
    :type quantities: List
    :param start: Starting time snapshot to load, defaults to 0
    :type start: int, optional
    :param end: Ending time snapshot to load, defaults to None
    :type end: int, optional
    :param N: Interval of time snapshots to load, defaults to 1
    :type N: int, optional
    :raises AssertionError: If a requested quantity is not available
    :raises FileNotFoundError: If no time snapshot is left to load
    :return: Cell centers coordinates, cell sizes, times and dictionary of buffers
    :rtype: Tuple
    """

    # SANITY CHECK
    available_quantities = {
        "primes" : ["density", "velocity", "velocityX", "velocityY", "velocityZ", "pressure", "temperature"],
        "cons": ["mass", "momentum", "momentumX", "momentumY", "momentumZ", "energy"],
        "levelset": ["levelset", "volume_fraction", "mask_real", "normal", "interface_pressure", "interface_velocity"],
        "real_fluid": [ "real_density", "real_velocity", "real_velocityX", "real_velocityY", "real_velocityZ", "real_pressure", "real_temperature",
                        "real_mass", "real_momentum", "real_momentumX", "real_momentumY", "real_momentumZ", "real_energy" ],
        "miscellaneous": ["mach_number", "schlieren", "absolute_velocity", "vorticity", "absolute_vorticity"],
        # Registered so that the mass flow forcing branches below are reachable
        "forcings": ["mass_flow_forcing"],
    }

    # Resolve the group of every requested quantity once; it is reused
    # for the buffer creation and for the buffer filling
    quantity_types = {}
    for quant in quantities:
        quantity_type = None
        for key in available_quantities:
            if quant in available_quantities[key]:
                quantity_type = key
                break
        assert quantity_type is not None, "%s not available." % quant
        quantity_types[quant] = quantity_type

    # SEARCH DIRECTORY FOR FILES
    files = []
    times = []
    for file in os.listdir(path):
        if not file.endswith("h5") or "nan" in file:
            continue
        files.append(file)
        # The time is everything after the first five characters of the file stem
        times.append(float(os.path.splitext(os.path.basename(file))[0][5:]))

    indices = np.argsort(np.array(times))
    times = np.array(times)[indices][start:end:N]
    files = np.array(files)[indices][start:end:N]
    no_times = len(times)
    if no_times == 0:
        raise FileNotFoundError(
            "No .h5 time snapshots to load in %s (start=%s, end=%s, N=%s)." % (path, start, end, N))

    with h5py.File(os.path.join(path, files[0]), "r") as h5file:

        # DOMAIN INFORMATION
        x = h5file["mesh/gridX"][:]
        y = h5file["mesh/gridY"][:]
        z = h5file["mesh/gridZ"][:]
        dx = h5file["mesh/cellsizeX"][()]
        dy = h5file["mesh/cellsizeY"][()]
        dz = h5file["mesh/cellsizeZ"][()]
        cell_centers = [x, y, z]
        cell_sizes = [dx, dy, dz]
        nx = len(x)
        ny = len(y)
        nz = len(z)

        # CHECK IF PRESENT H5FILES CONTAIN FLUID-FLUID RESULTS
        # Two-phase datasets carry a per-fluid index suffix ("..._0", "..._1")
        fluid_fluid = False
        if "real_fluid" in h5file.keys():
            fluid_fluid = True
        elif "primes" in h5file.keys():
            first_key = list(h5file["primes"].keys())[0]
            if first_key[-1] == "0":
                fluid_fluid = True
        elif "cons" in h5file.keys():
            first_key = list(h5file["cons"].keys())[0]
            if first_key[-1] == "0":
                fluid_fluid = True

    # CREATE BUFFERS
    grid_shape = (nx, ny, nz)
    data_dictionary = {}
    for quant in quantities:
        quantity_type = quantity_types[quant]

        # PRIMITIVES AND CONSERVATIVES
        if quantity_type in ["primes", "cons"]:
            component_shape = (2,) if fluid_fluid else ()
            if quant in ["velocity", "momentum"]:
                component_shape += (3,)

        # LEVELSET AND REAL FLUID
        elif quantity_type in ["levelset", "real_fluid"]:
            if quant in ["normal", "real_velocity", "real_momentum"]:
                component_shape = (3,)
            elif quant == "interface_pressure":
                component_shape = (2,)
            else:
                component_shape = ()

        # MISCELLANEOUS
        elif quantity_type == "miscellaneous":
            component_shape = (3,) if quant == "vorticity" else ()

        # MASS FLOW FORCING
        elif quantity_type == "forcings":
            data_dictionary[quant] = np.zeros(no_times)
            continue

        # Only reachable if the sanity check was stripped (python -O)
        else:
            continue

        data_dictionary[quant] = np.zeros((no_times,) + component_shape + grid_shape)

    # FILL BUFFERS
    for i, file in enumerate(files):
        print("Loading time snapshot %.4e" % times[i])
        with h5py.File(os.path.join(path, file), "r") as h5file:
            for quant in quantities:
                quantity_type = quantity_types[quant]

                # LEVELSET
                if quantity_type == "levelset":
                    data_dictionary[quant][i] = h5file["levelset/" + quant][:].T

                # PRIMITIVES AND CONSERVATIVES
                elif quantity_type in ["primes", "cons"]:
                    if fluid_fluid:
                        for j in range(2):
                            quant_name_h5 = quant + "_%d" % j
                            data_dictionary[quant][i,j] = h5file[quantity_type + "/" + quant_name_h5][:].T
                    else:
                        data_dictionary[quant][i] = h5file[quantity_type + "/" + quant][:].T

                # REAL FLUID
                elif quantity_type == "real_fluid":
                    # "real_density" is stored as "real_fluid/density"
                    quant_name_h5 = quant.split("_")[-1]
                    data_dictionary[quant][i] = h5file["real_fluid/" + quant_name_h5][:].T

                # MISCELLANEOUS
                elif quantity_type == "miscellaneous":
                    data_dictionary[quant][i] = h5file["miscellaneous/" + quant][:].T

                # MASS FLOW FORCING
                elif quantity_type == "forcings":
                    data_dictionary[quant][i] = h5file["mass_flow_forcing/scalar_value"][()]

    return cell_centers, cell_sizes, times, data_dictionary
[docs]
def create_contourplot(data_dictionary: Dict, cell_centers: List, times: np.ndarray,
        levelset: np.ndarray = None, nrows_ncols: Tuple = None, plane: str = "xy",
        plane_value: float = 0.0, static_time: float = None, save_fig: str = None, save_anim: str = None,
        interval: int = 50, dpi: int = 100) -> None:
    """Creates a subplot of 2D pcolormesh plots and subsequently creates an animation.
    If the levelset argument is provided, the interface
    will be visualized as an isoline. If the static_time argument is provided, no animation is created,
    but a plot at this time will be shown.

    The buffers are indexed as [time, x, y, z]; the slice is taken at the
    cell center closest to plane_value along the axis normal to the plane.

    :param data_dictionary: Data buffers
    :type data_dictionary: Dict
    :param cell_centers: Cell center coordinates
    :type cell_centers: List
    :param times: Times
    :type times: np.ndarray
    :param levelset: Levelset buffer, defaults to None
    :type levelset: np.ndarray, optional
    :param nrows_ncols: Shape of the subplot, defaults to None
    :type nrows_ncols: Tuple, optional
    :param plane: Plane to slice, defaults to "xy"
    :type plane: str, optional
    :param plane_value: Value of remaining axis to slice, defaults to 0.0
    :type plane_value: float, optional
    :param static_time: Static time to plot, defaults to None
    :type static_time: float, optional
    :param save_fig: Path to save the static time figure, defaults to None
    :type save_fig: str, optional
    :param save_anim: Path to save the animation, defaults to None
    :type save_anim: str, optional
    :param interval: Interval in ms for the animation, defaults to 50
    :type interval: int, optional
    :param dpi: Resolution in dpi, defaults to 100
    :type dpi: int, optional
    :raises AssertionError: If plane is not one of "xy", "xz" or "yz"
    """

    # CELL CENTERS
    x, y, z = cell_centers

    # ROWS AND COLUMNS
    nrows_ncols = (1, len(data_dictionary.keys())) if nrows_ncols is None else nrows_ncols

    # SLICE PLANE
    # plane -> (first in-plane coordinate, second in-plane coordinate,
    #           normal coordinate, buffer axis of the normal coordinate)
    plane_layouts = {
        "xy": (x, y, z, 3),
        "xz": (x, z, y, 2),
        "yz": (y, z, x, 1),
    }
    if plane not in plane_layouts:
        # Raised explicitly so the check is not stripped by python -O
        raise AssertionError("plane does not exist")
    coord_1, coord_2, coord_normal, axis_normal = plane_layouts[plane]

    X1, X2 = np.meshgrid(coord_1, coord_2, indexing="ij")
    index = np.argmin(np.abs(coord_normal - plane_value))
    slicer = [slice(None)] * 4
    slicer[axis_normal] = index
    slicer = tuple(slicer)

    data = [ data_dictionary[quant][slicer] for quant in data_dictionary.keys() ]
    # Identity check: comparing an array with None would act element-wise
    levelset = levelset[slicer] if levelset is not None else None

    # TITLES
    titles = data_dictionary.keys()

    # MINMAX VALUES
    # The eps offset keeps the color range non-degenerate for constant fields
    minmax = [ [np.min(d), np.max(d) + np.finfo(float).eps] for d in data ]

    # STATIC PLOT
    if static_time is not None:
        index_time = np.argmin(np.abs(times - static_time))
    else:
        index_time = 0

    # CREATE SUBPLOTS
    # NOTE(review): create_figure_2D is expected to be defined elsewhere in this module
    fig, axes, quadmesh_list, pciset_list = create_figure_2D(data, levelset,
        nrows_ncols, titles, X1, X2, minmax, index_time)

    # SAVE STATIC PLOT
    if save_fig is not None:
        fig.savefig("%s.pdf" % save_fig, bbox_inches="tight")

    # A static plot was requested: show it and skip the animation
    if static_time is not None:
        plt.show()
        return

    # CREATE ANIMATION
    # The reference must stay alive until plt.show() returns
    ani = FuncAnimation(fig, update_2D, frames=data[0].shape[0], fargs=(X1, X2,
        data, levelset, axes, quadmesh_list, pciset_list), interval=interval, blit=True, repeat=True)
    if save_anim is not None:
        if save_anim.endswith((".avi", ".mp4")):
            ani.save(save_anim, dpi=dpi)
        else:
            ani.save("%s.gif" % save_anim, dpi=dpi)
    plt.show()
[docs]
def create_lineplot(data_dictionary: Dict, cell_centers: List, times: np.ndarray, nrows_ncols: Tuple,
        axis: str="x", values: List=[0.0,0.0], static_time: float = None, save_fig: str = None,
        save_anim: str = None, interval: int = 50, dpi: int = 100) -> None:
    """Creates a subplot of 1D line plots and subsequently creates an animation.
    If the static_time argument is provided, no animation is created,
    but a plot at this time will be shown.

    The buffers are indexed as [time, x, y, z]; the line is taken at the
    cell centers closest to the two entries of values along the remaining axes.

    :param data_dictionary: Data buffers
    :type data_dictionary: Dict
    :param cell_centers: Cell center coordinates
    :type cell_centers: List
    :param times: Time buffer
    :type times: np.ndarray
    :param nrows_ncols: Shape of subplots
    :type nrows_ncols: Tuple
    :param axis: Axis to slice, defaults to "x"
    :type axis: str, optional
    :param values: Values of remaining axis to slice, defaults to [0.0,0.0]
    :type values: List, optional
    :param static_time: Static time, defaults to None
    :type static_time: float, optional
    :param save_fig: Path to save static plot, defaults to None
    :type save_fig: str, optional
    :param save_anim: Path to save animation, defaults to None
    :type save_anim: str, optional
    :param interval: Interval in ms for animation, defaults to 50
    :type interval: int, optional
    :param dpi: Resolution in dpi, defaults to 100
    :type dpi: int, optional
    :raises AssertionError: If axis is not one of "x", "y" or "z"
    """

    # SLICE AXIS
    x, y, z = cell_centers
    if axis == "x":
        coord = x
        index1 = np.argmin(np.abs(y - values[0]))
        index2 = np.argmin(np.abs(z - values[1]))
        slicer = np.s_[:,:,index1,index2]
    elif axis == "y":
        coord = y
        index1 = np.argmin(np.abs(x - values[0]))
        index2 = np.argmin(np.abs(z - values[1]))
        slicer = np.s_[:,index1,:,index2]
    elif axis == "z":
        coord = z
        index1 = np.argmin(np.abs(x - values[0]))
        index2 = np.argmin(np.abs(y - values[1]))
        slicer = np.s_[:,index1,index2,:]
    else:
        # Raised explicitly so the check is not stripped by python -O
        raise AssertionError("axis does not exist")
    data = [ data_dictionary[quant][slicer] for quant in data_dictionary.keys() ]

    # STATIC TIME INDEX
    if static_time is not None:
        index_time = np.argmin(np.abs(times - static_time))
    else:
        index_time = 0

    # CREATE SUBPLOT
    fig, axes = plt.subplots(nrows_ncols[0], nrows_ncols[1])
    # plt.subplots returns a bare Axes object for a 1x1 layout
    if isinstance(axes, np.ndarray):
        axes = axes.ravel()
    else:
        axes = [axes]
    fig.set_size_inches([20,8])

    # PLOT DATA
    lines = []
    for i, (quant, ax) in enumerate(zip(data_dictionary.keys(), axes)):
        lines.append( ax.plot(coord, data[i][index_time], "o", label=quant)[0] )
        # Fixed limits over all snapshots so the axis does not jump during the animation
        ax.set_ylim([np.min(data[i]), np.max(data[i])])
        ax.set_title(quant)

    # SAVE STATIC PLOT
    if save_fig is not None:
        fig.savefig("%s.pdf" % save_fig, bbox_inches="tight")

    # A static plot was requested: show it and skip the animation
    if static_time is not None:
        plt.show()
        return

    # CREATE/SAVE ANIMATION
    # The reference must stay alive until plt.show() returns
    ani = FuncAnimation(fig, update_1D, frames=len(times), fargs=(data, lines), interval=interval, blit=False, repeat=True)
    if save_anim is not None:
        if save_anim.endswith((".avi", ".mp4")):
            ani.save(save_anim, dpi=dpi)
        else:
            ani.save("%s.gif" % save_anim, dpi=dpi)
    plt.show()
[docs]
def update_2D(i: int, X: np.ndarray, Y: np.ndarray, data: List, levelset: np.ndarray,
        axes: np.ndarray, quadmesh_list: List, pciset_list:List) -> List:
    """Update function for FuncAnimation for the create_contourplot() function.

    For every subplot the pcolormesh data is replaced by snapshot i. If a
    levelset buffer is given, the previous interface isoline is removed
    and redrawn for snapshot i.

    :param i: Time snapshot index
    :type i: int
    :param X: X meshgrid
    :type X: np.ndarray
    :param Y: Y Meshgrid
    :type Y: np.ndarray
    :param data: Data buffers
    :type data: List
    :param levelset: Levelset buffer
    :type levelset: np.ndarray
    :param axes: Axes
    :type axes: np.ndarray
    :param quadmesh_list: Quadmesh objects returned by pcolormesh()
    :type quadmesh_list: List
    :param pciset_list: One-element lists holding the object returned by contour();
        the entry is replaced in place
    :type pciset_list: List
    :return: List of collections
    :rtype: List
    """
    list_of_collections = []
    for z, ax, quadmesh, pciset in zip(data, axes, quadmesh_list, pciset_list):
        if levelset is not None and pciset[0] is not None:
            # Older Matplotlib exposes the isolines via ContourSet.collections;
            # newer releases dropped that attribute and the ContourSet itself
            # is the (single) artist
            old_collections = getattr(pciset[0], "collections", None)
            if old_collections is None:
                pciset[0].remove()
            else:
                for tp in old_collections:
                    tp.remove()
            pciset[0] = ax.contour(X, Y, np.squeeze(levelset[i,:,:]), levels=[0.0], colors="black", linewidths=4)
            new_collections = getattr(pciset[0], "collections", None)
            if new_collections is None:
                list_of_collections.append(pciset[0])
            else:
                list_of_collections += list(new_collections)
        quadmesh.set_array(z[i,:,:].flatten())
        list_of_collections.append(quadmesh)
    return list_of_collections
[docs]
def update_1D(i: int, data: List, lines: List) -> List:
    """Update function for FuncAnimation for the create_lineplot() function.

    Sets the y-data of every line to snapshot i of its data buffer.

    :param i: Time snapshot index
    :type i: int
    :param data: Data buffers
    :type data: List
    :param lines: Line objects
    :type lines: List
    :return: List of the updated line objects
    :rtype: List
    """
    pairs = list(zip(data, lines))
    for buffer, artist in pairs:
        artist.set_ydata(buffer[i,:])
    return [artist for _, artist in pairs]
[docs]
def generate_png(X: np.ndarray, Y: np.ndarray, data_dictionary: Dict[str, np.ndarray], save_path: str,
        levelset: np.ndarray = None, nrows_ncols: Tuple = None, dpi=200) -> None:
    """Generates a subplot of pcolormesh and saves it as png. If the levelset argument is provided,
    the interface is plotted as an isoline.

    One image "image_%04d" is written per time snapshot. The buffers are
    indexed as [time, :, :, 0], i.e. the first index of the last axis is plotted.
    The color range of every quantity is fixed over all snapshots.

    :param X: X meshgrid
    :type X: np.ndarray
    :param Y: Y meshgrid
    :type Y: np.ndarray
    :param data_dictionary: Data buffer
    :type data_dictionary: Dict[str, np.ndarray]
    :param save_path: Save path
    :type save_path: str
    :param levelset: Levelset buffer, defaults to None
    :type levelset: np.ndarray, optional
    :param nrows_ncols: Shape of the subplot, defaults to None
    :type nrows_ncols: Tuple, optional
    :param dpi: Resolution in dpi, defaults to 200
    :type dpi: int, optional
    :raises AssertionError: If nrows_ncols does not match the number of quantities
    """

    quantities = list(data_dictionary.keys())
    no_timesnapshots = len(data_dictionary[quantities[0]])

    # PLOT OPTIONS PER QUANTITY
    kwargs = []
    for quant in quantities:
        vmin = np.min(data_dictionary[quant])
        vmax = np.max(data_dictionary[quant])
        if quant == "schlieren":
            # NOTE(review): LogNorm needs strictly positive limits - confirm the schlieren buffer never contains zeros
            kwargs.append( {"cmap": "binary", "norm": matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax)} )
        else:
            kwargs.append( {"cmap": "seismic", "vmin": vmin, "vmax": vmax} )

    fontsize = 20

    nrows_ncols = (1,len(quantities)) if nrows_ncols is None else nrows_ncols
    assert np.prod(nrows_ncols) == len(quantities), "Amount of requested axes and available quantities do not match."
    nrows, ncols = nrows_ncols[0], nrows_ncols[1]

    for i in range(no_timesnapshots):

        filename = "image_%04d" % i

        # squeeze=False always yields a 2D array of axes, also for a single subplot
        fig, axes = plt.subplots(nrows, ncols, squeeze=False)
        # set_size_inches expects (width, height): columns set the width, rows the height
        fig.set_size_inches(ncols*10, nrows*10, True)
        fig.tight_layout()
        axes = axes.flatten()

        for ax, quantity, kw in zip(axes, quantities, kwargs):
            ax.pcolormesh(X, Y, data_dictionary[quantity][i,:,:,0], **kw, shading="auto")
            if levelset is not None:
                ax.contour(X, Y, levelset[i,:,:,0], levels=[0.0], colors="black", linewidths=3)

            # AXIS TICKS
            ax.set_title(quantity.upper(), fontsize=fontsize, pad=10)
            ax.tick_params(labelsize=fontsize, left=False, bottom=False, labelleft=False, labelbottom=False)
            ax.set_aspect("equal")

        print("Saving %s" % filename)
        fig.savefig(os.path.join(save_path, filename), dpi=dpi, bbox_inches="tight")

        # Release the figure so memory does not grow with the number of snapshots
        plt.close(fig)
        del fig, axes
[docs]
def create_xdmf_from_h5(path: str) -> None:
    """Creates .xdmf files from .h5 files for Paraview visualization.

    For every .h5 file in the folder (files containing "nan" in their name
    are skipped) an .xdmf file with the same stem is written next to it.
    The .xdmf file references the face coordinates mesh/gridFX, mesh/gridFY,
    mesh/gridFZ and all datasets of the groups cons, primes, real_fluid,
    miscellaneous and levelset that are present in the .h5 file.

    :param path: Path to .h5 files.
    :type path: str
    :raises FileNotFoundError: If the folder contains no .h5 files
    """

    # SEARCH DIRECTORY FOR FILES
    files = []
    times = []
    for file in os.listdir(path):
        if not file.endswith("h5") or "nan" in file:
            continue
        files.append(file)
        # The time is everything after the first five characters of the file stem
        times.append(float(os.path.splitext(os.path.basename(file))[0][5:]))

    if not files:
        raise FileNotFoundError("No .h5 files found in %s." % path)

    # Process the snapshots in chronological order
    indices = np.argsort(np.array(times))
    files = np.array(files)[indices]

    # DOMAIN INFORMATION
    # The mesh is read once from the first snapshot
    with h5py.File(os.path.join(path, files[0]), "r") as h5file:
        number_of_cells = [
            len(h5file["mesh/gridX"][:]),
            len(h5file["mesh/gridY"][:]),
            len(h5file["mesh/gridZ"][:]),
        ]
    # There is one more cell face than cells along every axis
    number_of_faces = [cells + 1 for cells in number_of_cells]

    for file in files:

        # Reference the file that actually exists on disk instead of
        # rebuilding its name from the parsed time
        h5file_name = str(file)
        xdmffile_path = os.path.splitext(h5file_name)[0] + ".xdmf"
        print("Writing file %s" % xdmffile_path)
        xdmffile_path = os.path.join(path, xdmffile_path)

        # XDMF QUANTITIES
        xdmf_quants = []
        with h5py.File(os.path.join(path, h5file_name), "r") as h5file:
            for key in ["cons", "primes", "real_fluid", "miscellaneous", "levelset"]:
                if key in h5file.keys():
                    for quantity in h5file[key]:
                        xdmf_quants.append(get_xdmf(key, quantity, h5file_name, *number_of_cells))

        # XDMF START
        xdmf_str_start ='''<?xml version="1.0" ?>
<!DOCTYPE Xdmf SYSTEM "Xdmf.dtd" []>
<Xdmf Version="3.0">
<Domain>
<Grid Name="Cube">
<Geometry Type="VXVYVZ">
<DataItem Format="HDF" NumberType="Float" Precision="%i" Dimensions="%i">%s:mesh/gridFX</DataItem>
<DataItem Format="HDF" NumberType="Float" Precision="%i" Dimensions="%i">%s:mesh/gridFY</DataItem>
<DataItem Format="HDF" NumberType="Float" Precision="%i" Dimensions="%i">%s:mesh/gridFZ</DataItem>
</Geometry>
<Topology Dimensions="%i %i %i" Type="3DRectMesh"/>''' %(
            8, number_of_faces[0], h5file_name,
            8, number_of_faces[1], h5file_name,
            8, number_of_faces[2], h5file_name,
            number_of_faces[0], number_of_faces[1], number_of_faces[2])

        # XDMF END
        xdmf_str_end = '''</Grid>
</Domain>
</Xdmf>'''

        # JOIN FINAL XDMF STR AND WRITE TO FILE
        xdmf_str = "\n".join([xdmf_str_start] + xdmf_quants + [xdmf_str_end])
        with open(xdmffile_path, "w") as xdmf_file:
            xdmf_file.write(xdmf_str)
[docs]
def get_xdmf(group: str, quant: str, h5file_name: str, Nx: int, Ny: int, Nz: int) -> str:
    """Creates the xdmf attribute entry for a single dataset of an .h5 file.

    Velocity, momentum, vorticity and normal are declared as cell-centered
    vectors with dimensions (Nz, Ny, Nx, 3), everything else as cell-centered
    scalars with dimensions (Nz, Ny, Nx). Datasets with a fluid index
    suffix, e.g. "velocity_0", are classified by their name without the suffix.

    :param group: Name of the .h5 group that contains the dataset
    :type group: str
    :param quant: Name of the dataset
    :type quant: str
    :param h5file_name: Name of the .h5 file that is referenced
    :type h5file_name: str
    :param Nx: Number of cells in x direction
    :type Nx: int
    :param Ny: Number of cells in y direction
    :type Ny: int
    :param Nz: Number of cells in z direction
    :type Nz: int
    :return: Xdmf attribute entry
    :rtype: str
    """
    # Two-phase results store one dataset per fluid ("velocity_0", "velocity_1");
    # strip that index so they are recognized as vectors as well
    base_name, _, suffix = quant.rpartition("_")
    if not (base_name and suffix.isdigit()):
        base_name = quant

    if base_name in ["velocity", "momentum", "vorticity", "normal"]:
        xdmf ='''<Attribute Name="%s" AttributeType="Vector" Center="Cell">
<DataItem Format="HDF" NumberType="Float" Precision="%i" Dimensions="%i %i %i %i">%s:%s/%s</DataItem>
</Attribute>''' %(quant, 8 , Nz, Ny, Nx, 3, h5file_name, group, quant)
    else:
        xdmf ='''<Attribute Name="%s" AttributeType="Scalar" Center="Cell">
<DataItem Format="HDF" NumberType="Float" Precision="%i" Dimensions="%i %i %i">%s:%s/%s</DataItem>
</Attribute>''' %(quant, 8 , Nz, Ny, Nx, h5file_name, group, quant)
    return xdmf