# Source code for jaxfluids.stencils.reconstruction.weno3_nn_opt2

#*------------------------------------------------------------------------------*
#* JAX-FLUIDS -                                                                 *
#*                                                                              *
#* A fully-differentiable CFD solver for compressible two-phase flows.          *
#* Copyright (C) 2022  Deniz A. Bezgin, Aaron B. Buhendwa, Nikolaus A. Adams    *
#*                                                                              *
#* This program is free software: you can redistribute it and/or modify         *
#* it under the terms of the GNU General Public License as published by         *
#* the Free Software Foundation, either version 3 of the License, or            *
#* (at your option) any later version.                                          *
#*                                                                              *
#* This program is distributed in the hope that it will be useful,              *
#* but WITHOUT ANY WARRANTY; without even the implied warranty of               *
#* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the                *
#* GNU General Public License for more details.                                 *
#*                                                                              *
#* You should have received a copy of the GNU General Public License            *
#* along with this program.  If not, see <https://www.gnu.org/licenses/>.       *
#*                                                                              *
#*------------------------------------------------------------------------------*
#*                                                                              *
#* CONTACT                                                                      *
#*                                                                              *
#* deniz.bezgin@tum.de // aaron.buhendwa@tum.de // nikolaus.adams@tum.de        *
#*                                                                              *
#*------------------------------------------------------------------------------*
#*                                                                              *
#* Munich, April 15th, 2022                                                     *
#*                                                                              *
#*------------------------------------------------------------------------------*

from typing import List

import haiku as hk
import jax
import jax.numpy as jnp

from jaxfluids.stencils.spatial_reconstruction import SpatialReconstruction

class WENO3NNOPT2(SpatialReconstruction):
    """WENO3-type reconstruction whose nonlinear weights come from a small MLP.

    Two two-point candidate reconstructions over a three-point stencil
    (coefficients in ``cr_``) are blended with weights predicted by a fixed,
    pre-trained neural network (built in ``_get_nn``). The network output is
    ENO-thresholded by ``_c_eno`` and renormalised before blending.

    ``dr_`` holds values that look like the standard WENO3 linear weights;
    it is not read anywhere in this class (NOTE(review): presumably used by
    the base class or kept for reference -- confirm).
    """

    def __init__(self, nh: int, inactive_axis: List) -> None:
        """Set up coefficient tables, stencil slices and the network.

        Args:
            nh: Forwarded to ``SpatialReconstruction``
                (presumably the number of halo cells -- confirm).
            inactive_axis: Forwarded to ``SpatialReconstruction``.
        """
        super().__init__(nh=nh, inactive_axis=inactive_axis)

        self.dr_ = [
            [1/3, 2/3],
            [2/3, 1/3],
        ]
        # Candidate coefficients, indexed cr_[j][candidate][point]:
        # p_0 = c00 * v0 + c01 * v1,  p_1 = c10 * v1 + c11 * v2.
        self.cr_ = [
            [[-0.5, 1.5], [0.5, 0.5]],
            [[-0.5, 1.5], [0.5, 0.5]],
        ]

        # Threshold subtracted from the network weights before ReLU (ENO cut-off).
        self._c_eno = 2e-4
        self._stencil_size = 4

        # _slices[j][axis] holds the three shifted views (stencil points
        # v0, v1, v2) along the given axis for orientation j. self.n and
        # self.nhx / self.nhy / self.nhz are assumed to be set by the base
        # class -- confirm against SpatialReconstruction.
        self._slices = [
            [
                [
                    jnp.s_[:, self.n-2+j:-self.n-1+j, self.nhy, self.nhz],
                    jnp.s_[:, self.n-1+j:-self.n+j, self.nhy, self.nhz],
                    jnp.s_[:, self.n+j:-self.n+1+j, self.nhy, self.nhz],
                ],
                [
                    jnp.s_[:, self.nhx, self.n-2+j:-self.n-1+j, self.nhz],
                    jnp.s_[:, self.nhx, self.n-1+j:-self.n+j, self.nhz],
                    jnp.s_[:, self.nhx, self.n+j:-self.n+1+j, self.nhz],
                ],
                [
                    jnp.s_[:, self.nhx, self.nhy, self.n-2+j:-self.n-1+j],
                    jnp.s_[:, self.nhx, self.nhy, self.n-1+j:-self.n+j],
                    jnp.s_[:, self.nhx, self.nhy, self.n+j:-self.n+1+j],
                ],
            ]
            for j in range(2)
        ]

        self._get_nn()

    def set_slices_stencil(self) -> None:
        """Switch ``_slices`` to index an explicit 4-point stencil.

        Instead of shifted windows, each slice picks a single integer index
        along the axis position (third-, second- or last-from-end for axis
        0, 1, 2). Orientation j=0 uses points (0, 1, 2); j=1 uses the
        mirrored points (3, 2, 1).
        """
        self._slices = [
            [
                [
                    jnp.s_[..., 0, :, :],
                    jnp.s_[..., 1, :, :],
                    jnp.s_[..., 2, :, :],
                ],
                [
                    jnp.s_[..., :, 0, :],
                    jnp.s_[..., :, 1, :],
                    jnp.s_[..., :, 2, :],
                ],
                [
                    jnp.s_[..., :, :, 0],
                    jnp.s_[..., :, :, 1],
                    jnp.s_[..., :, :, 2],
                ],
            ],
            [
                [
                    jnp.s_[..., 3, :, :],
                    jnp.s_[..., 2, :, :],
                    jnp.s_[..., 1, :, :],
                ],
                [
                    jnp.s_[..., :, 3, :],
                    jnp.s_[..., :, 2, :],
                    jnp.s_[..., :, 1, :],
                ],
                [
                    jnp.s_[..., :, :, 3],
                    jnp.s_[..., :, :, 2],
                    jnp.s_[..., :, :, 1],
                ],
            ],
        ]

    def reconstruct_xi(self, buffer: jnp.ndarray, axis: int, j: int, dx: float = None, **kwargs) -> jnp.ndarray:
        """Reconstruct values along ``axis`` with network-predicted weights.

        Args:
            buffer: Input field; it is sliced with ``self._slices[j][axis]``.
            axis: Spatial axis index (0, 1 or 2) selecting the slice set.
            j: Stencil orientation (0 or 1); selects slices and ``cr_`` row.
            dx: Unused here; accepted for interface compatibility.
            **kwargs: Unused.

        Returns:
            The blended reconstruction ``w1 * p_0 + w0 * p_1``, with the
            shape of a single stencil view of ``buffer``.
        """
        s1_ = self._slices[j][axis]
        v0 = buffer[s1_[0]]
        v1 = buffer[s1_[1]]
        v2 = buffer[s1_[2]]

        # Smoothness features: two one-sided jumps, the wide jump and the
        # second difference (evaluation order kept as in the original).
        dx1 = jnp.abs(v1 - v0)
        dx2 = jnp.abs(v2 - v1)
        dx3 = jnp.abs(v2 - v0)
        dx4 = jnp.abs(v0 - 2*v1 + v2)
        x = jnp.stack([dx1, dx2, dx3, dx4], axis=-1)
        # Normalise by the larger one-sided jump; self.eps (from the base
        # class) keeps the denominator non-zero. Ellipsis indexing selects
        # the feature axis for any input rank.
        x = x / (jnp.maximum(x[..., :1], x[..., 1:2]) + self.eps)

        omega_z_ = self.net.apply(self.params, x)
        # The network ends in a 2-way softmax, so the larger weight is >= 0.5
        # and survives the threshold; the sum below is therefore positive.
        omega_z_ = jax.nn.relu(omega_z_ - self._c_eno)
        omega_z_ = omega_z_ / jnp.sum(omega_z_, axis=-1, keepdims=True)

        p_0 = self.cr_[j][0][0] * v0 + self.cr_[j][0][1] * v1
        p_1 = self.cr_[j][1][0] * v1 + self.cr_[j][1][1] * v2

        # Network output channel 1 weights p_0 and channel 0 weights p_1.
        # Kept as in the original (presumably the training convention).
        cell_state_xi_j = omega_z_[..., 1] * p_0 + omega_z_[..., 0] * p_1
        return cell_state_xi_j

    def _get_nn(self) -> None:
        """Build the fixed, pre-trained MLP and store ``self.params`` / ``self.net``.

        Architecture: 4 -> 16 -> 16 -> 16 -> 2, with swish activations and
        a softmax output. Weight matrices are written as (out, in) rows and
        transposed to Haiku's (in, out) ``hk.Linear`` layout. The
        ``float64`` dtype presumably relies on ``jax_enable_x64``; without it
        JAX stores these as float32 -- confirm the intended precision.
        """
        w = jnp.array([
            [0.40507858789796725, -1.1238033098043017, -0.07371562056446773, 0.6197903065864093],
            [0.7930257586191607, -0.5625348147829258, 0.8465590229867586, 0.06934803250753911],
            [-1.6875939465011658, 0.3292199251637992, 0.6875138682901961, 0.231016713716645],
            [0.10912183489934689, 0.2040961019258102, -0.6724587051169624, 0.1312559439348201],
            [-0.3457073157356567, -1.577544807826582, 0.5509198901354974, 0.8351369118279482],
            [-0.37393270139419627, 0.5598127637781741, 0.27320941345108146, 0.8869608072658418],
            [0.8098786896641792, 0.169795217035407, -0.5372949484432187, 0.15453353199037187],
            [1.0839428877188175, 0.5710929826178207, -0.5953756614768302, 0.5340795193438068],
            [0.5166080626653307, 0.19059321138527613, -0.12020009736086794, -0.2785721649378586],
            [0.6506865528977455, -0.5424805693979798, 0.3941748861527991, 0.36748233206516395],
            [0.524897934490327, 0.11506210069998385, -0.6009923660171071, 0.2287483366027391],
            [0.8388485695741124, -0.6359208318796294, 0.7373071414776674, 0.09497150456206503],
            [-0.42399921079425185, -0.8680738398303438, -0.6417461344232419, -0.7735348108564313],
            [0.2915992280213219, -0.8328684447293676, -0.42670822211750586, 0.815257325426268],
            [1.0245327666871027, 0.2488354688859932, 1.161907091201582, -0.9344205318644166],
            [0.3853893818171449, -0.5444376846503562, -0.39079519399401663, 0.5409307745135028],
        ], dtype=jnp.float64).transpose()

        w1 = jnp.array([
            [-0.6867448131153827, 0.5815513043168663, 0.12472620573808167, -0.09514026577588723,
             -0.8582725029386706, -0.14471040125647194, 0.0693429939306645, -0.14202011173748028,
             0.19108464106535772, -0.16684878976327833, 0.16984569067929595, 0.10047832646320387,
             0.3153504002155991, -0.4409794087401472, 0.7159675351797941, -0.6472042446455757],
            [0.25598434385639846, 0.2924590638828298, 0.15686621245994514, -0.3999096307407229,
             0.5879345035507688, -0.1303807246107144, -0.042125090119133304, -0.5359666252119695,
             0.14567316401367708, -0.307905592962315, -0.5035822807059602, 0.30354018336137517,
             -0.19114699046394001, -0.25560165124070755, 0.4140476810339576, -0.3139059266968657],
            [0.04931941442363547, 0.42081550312855154, 0.19839695357598913, 0.2147384236063682,
             0.4798402263136694, -0.302871328795308, 0.008479507226946244, -0.484439404748177,
             0.05766963499483049, 0.23406566886122618, -0.48310265933937996, 0.21714892360446778,
             0.23686420174208417, -0.04346353062517073, 0.08295981300780145, 0.25467847712980385],
            [-0.10606460702250393, 0.1360114312531342, -0.3466175303554808, 0.04748845846706956,
             -0.30256915374228743, -0.19698755571116694, -0.38835299377340865, -0.6744351255429994,
             0.27810131946803684, 0.3440400972929206, -0.08576255038954103, 0.6253131015099894,
             -0.34306343661238725, -0.49171672201713107, 0.7413160377747086, -0.6018144383925759],
            [0.3362148156723888, -0.17538147729740328, 0.23385762816656955, -0.28637505585429174,
             0.5734142860597128, -0.4654025808963523, -0.2526530345453943, -0.41569870254200975,
             -0.44372457643936186, 0.3326457795637245, -0.03862362829354754, 0.5187234029828613,
             0.3603121260262233, 0.5176454505698286, -0.2046084510357694, 0.2490394321789764],
            [-0.22282319086557204, -0.12331735126883088, -0.04881263948212592, 0.10489217056547881,
             0.13633297223114244, -0.37866838632586763, -0.15072815358243224, -0.4329889556196099,
             0.21277239435469658, -0.36684710608938553, -0.1326218792795665, -0.09610391049541431,
             -0.1313446480188124, 0.029274006381572036, -0.05832588226216447, -0.4481671940670486],
            [0.15901584364081886, -0.18442203612739388, 1.010736491366019, -0.06105386169156317,
             0.10250697349588397, 0.317729974287565, -0.09032523164359285, -0.5597691039280749,
             -0.202389872687217, -0.18017776104880975, -0.05391953505364517, -0.7132008489017413,
             -0.41234048323866823, 0.23626333010541048, -0.7490903262962241, 0.0853534736219293],
            [1.2288842669282238, -0.1781522819219545, -0.39991716688099505, 0.09024699002470082,
             1.088897864322467, 0.18566390313564027, 0.32351788177560165, 0.450773375365827,
             0.21092718656565712, 0.4332396207110078, 0.26812663364167444, 0.34272707794803775,
             0.3368167332104344, 0.9120772290800091, -1.371239864217059, 1.1106947298430088],
            [0.009681468243199404, -1.1980815415893973, 0.3850403437706827, -0.30196145900118365,
             0.5116131127871133, 0.5077391550810565, -0.231073748570583, -0.2115859410886229,
             -0.13138968077202287, -0.5367018129935034, 0.14509991689038945, -1.242566064196453,
             0.030884651763175234, 0.5263408434094645, -0.26382048495066673, -0.3230218882556699],
            [-0.4319078552505818, -0.43406243071567846, 1.09657923818615, -0.23881641721146615,
             0.6346093451443459, 0.08231776039962466, -0.6940823557554685, -0.6433277909529731,
             0.006240600581396445, -0.4223884227319635, -0.7560336622323829, -0.44020645707027395,
             0.16522556464579247, -0.325386071770434, -0.4855309551128407, -0.23886493949033488],
            [0.55552575785334, 0.48738265145824894, 0.5671326425249802, -0.3157169511310578,
             1.072696174813921, 0.38884352753773965, -0.6562727204384043, -0.4433377765799739,
             -0.3532972617324838, -0.1109383528337637, -0.380024702827204, 0.1688757887032817,
             -0.592430855785699, 0.3151394469962289, -0.6334047596253615, 0.3749731598387512],
            [-0.06380307634387775, -0.22004397176318183, 0.861953056352305, -0.022907372199960228,
             0.00807791591262222, 0.03591242802854111, -0.25675960933370107, -0.20427239352420398,
             -0.5293714366738514, -0.4409304869441713, 0.10407305114573313, -0.0497069470103045,
             -0.044246752155309474, -0.13708515184868442, -0.13240119099526648, -0.08252752088703741],
            [-0.17030919188718194, 0.12985357633457414, -0.21298657196607662, -0.08198966672079237,
             -0.12543289567890498, -0.43550004091913314, -0.003027406345337324, -0.259570485598751,
             0.17837655170057992, 0.2451724081305258, -0.2264436614696502, 0.14324174792029915,
             -0.03387317739166384, 0.4566799695801795, -0.5692360457237575, 0.36586577679873133],
            [-0.3360543739769236, 0.07225716315347526, 0.10575298210966604, 0.12658766840872324,
             0.029863516573593308, 0.08117110151219442, -0.0731756396469054, -0.2747403771855847,
             -0.193684962487218, -0.21142953916018375, -0.36670671648689557, 0.10158770025440761,
             -0.19561461589666734, 0.278274790007805, -0.6169938397198013, 0.020944104369647037],
            [-0.09942984441243542, -0.4506051572333583, -0.07446657965169429, 0.3988818586779608,
             0.23585163792911798, -0.08615682323374707, -0.15668663896883833, -0.1651098591709022,
             -0.3356579662527265, -0.07445069202002452, 0.24892441860406148, -0.011615277672059748,
             -0.148646383116846, -0.2240991359503305, 0.0010432203765337207, -0.14205539414864182],
            [0.8006697293108059, 0.00018753968696219847, 0.9955251031521111, -0.41218282788057853,
             1.1972721626730496, -0.4006942812428122, -0.6089590583429121, -0.8631012327324826,
             -0.29134837945901426, 0.3212744332291301, -0.29739041321556603, 0.2924173679812099,
             -0.5789845337104051, 0.3180317008567941, -0.0798161893518069, -0.29623412953423733],
        ], dtype=jnp.float64).transpose()

        w2 = jnp.array([
            [-0.8503863037662435, 0.26973208467202026, 0.07966412814352908, -0.8658865188933688,
             1.1495530217160335, -0.07538366652164888, -0.051422009285308205, 0.9790228404107725,
             -0.45097152653464206, -0.7868869515993664, 0.8391231489143977, 0.21188836856761875,
             0.5632085983527626, -0.2847075319720688, 0.13753507297624043, 0.9798165935788753],
            [-0.45682769318340843, -0.036525317167452914, -0.2318230991265981, -0.6938948736218584,
             0.18639919429407467, -0.08692137976843799, 0.4271557260912461, 0.2338891490372102,
             0.23542278056425953, 0.281810660179581, 0.15098081176465322, 0.5001903365969637,
             0.47276522332312615, -0.21156196754032178, -0.26441256307355193, -0.01316437747527656],
            [-0.47966232225774, -0.33614322192456625, 0.5527762347571665, -0.38867828870805565,
             0.5410980467841571, 0.11540847432348225, -0.11564032945529601, 0.3662821914751811,
             -0.17798823117584803, -0.22999049581084044, 0.44206744951544447, 0.17616899465839322,
             -0.08401430965747712, -0.27572034389696437, -0.29526376969654633, 0.5390407145356501],
            [0.33651303329475424, -0.2625623348024787, -0.05662055240142085, -0.08263908486613132,
             -0.02711758813854524, 0.26313810383087943, 0.12169390408965944, 0.4092023703649926,
             -0.23277833778857812, 0.2786806994258514, 0.09738092831869469, 0.2627158996901284,
             0.25437142889247705, 0.09642521453352922, 0.49010823491717204, -0.006074281465504539],
            [-0.3638451160189469, 0.4831194724176685, 0.3818622476444646, 0.2661715563422579,
             0.3134700933820883, -0.42042791658219997, 0.05477459029755421, -0.18822440287853998,
             -0.36039899564912325, -0.22080260091725798, 0.33737377062555535, -0.4089619699638205,
             -0.001797980195123386, -0.3279778550557059, -0.36707722140288945, 0.4272713756164122],
            [0.22739460374608986, 0.08073984940667196, 0.28294000461116964, -0.15330560228510942,
             -0.025460965375753767, -0.26588529697983276, 0.1740175012698221, 0.030864410340264593,
             -0.075615602281047, 0.14391610838238905, 0.3343806905449929, 0.06242522243019465,
             -0.12714582995069187, 0.35705181738640546, 0.32688987936549463, 0.11307004065183536],
            [-0.10085937651296571, -0.16558459105218393, 0.2799860763749076, -0.23520594567790945,
             -0.3595142398796619, 0.27959380399797606, -0.057934232989528836, -0.24700787600201637,
             -0.20277391917134385, 0.1745148448472533, -0.23144012045352091, -0.07776276627380775,
             -0.13130535608245342, -0.13650675543344434, -0.3334147496132178, 0.045684811163875634],
            [-0.1280107502374298, -0.313264739617544, -0.07086366896849243, -1.4166270333527426,
             0.3629840712639312, -0.12894250828517895, 0.6892314171107886, 0.43220159424221044,
             0.726858895519226, 0.7658025001460358, 0.2556005056207057, 0.5616048455905793,
             -0.23453347256630908, -0.12353533570435427, -0.07283126806225097, 0.354922517887608],
            [-0.05600316563708509, -0.302568376886732, -0.3034615753299236, -1.749603857174122,
             -0.18277014622811583, 0.3243303236732084, 0.987677111484863, -0.34822747616526173,
             0.6746123217747416, 1.4336167682825305, 0.18388256202543724, 1.317414007169706,
             -0.08395969734448974, 0.5641086635579612, 0.6326894524646852, 0.4067340434542997],
            [-0.2553791836590051, -0.035585588247976835, -0.6008985731854996, -1.6076502942481172,
             -0.420309929408639, 0.040154285545409486, 0.7036324980588307, -0.5234350082611118,
             0.9444158898256386, 1.49620817248008, -0.005420579514536696, 0.8226455944507565,
             0.2043063233960032, 0.5516683915340641, 0.10811084759508162, 0.7573676312853309],
            [0.3952441290064368, -0.08197590906541201, 0.4742425308603112, 0.12909677697509667,
             -0.09778749291814226, -0.21428635730430712, 0.2518560771333552, 0.0430880022633686,
             -0.34294736487133415, -0.3883980233811578, 0.28230825119321645, 0.13597232781367566,
             0.3198949234371617, -0.17425052935589203, 0.3001814674220904, 0.11999221570345685],
            [-0.3394988609436981, 0.4450757862566428, 0.3883776323913531, 0.6033192602734934,
             -0.330605787553949, -0.3039758963642952, 0.3339389184955664, -1.185249934742437,
             0.3772774637143086, 0.5262690792199758, 0.14159969353513202, 0.2278647779768446,
             -0.3364371614428228, -0.31885624025612996, 0.40682941997723493, 0.3184524655357961],
            [-0.2004113305938361, -0.5160984754659198, -0.41117325461795473, 0.1187011432405939,
             -0.020815395321570396, 0.12133386301577685, 0.456446878707769, 0.2046776227718532,
             0.18685395539330546, 0.8169487384698053, 0.5734201358781507, -0.28092383162590717,
             -0.10161694125370713, -0.09170941393805067, -0.3438320456550242, -0.24190006641714637],
            [-0.5867052808521479, -0.24905993543305086, -0.007894901023279215, 0.17849035708769956,
             -0.20289093545912645, -0.020175655540982978, 0.5556763101439157, -0.7056814087563544,
             0.03429057057746626, 0.7564588003286788, 0.3498718416352649, 0.3357835156668993,
             -0.3323015730280598, 0.43726171529886276, 0.006942164304266475, 0.31696278631408586],
            [0.0073176892223145535, -0.3132720547759719, 0.4343701381020824, 0.37163779979624517,
             0.1271626570931177, -0.1008318807985456, 0.19247929941650477, 0.224241479817544,
             -0.029629632482364304, 0.2589529207511819, -0.40260892343668725, -0.28601646006073816,
             -0.2530608866364776, -0.3314749286164843, 0.33300952808836914, 0.036790519034176954],
            [0.11734447732537494, 0.12174707127590667, -0.17387651803473045, -0.31779710200006395,
             -0.05273178706791511, -0.07512018911032277, -0.3888080148169517, -0.07322424612717245,
             0.28357185811745633, 0.023591599823268423, -0.15012507910438574, -0.15798134219898938,
             -0.18619351366589146, 0.053468330459870796, 0.20471727323629585, -0.0713039829710401],
        ], dtype=jnp.float64).transpose()

        w3 = jnp.array([
            [0.11691632192709213, -0.5809594599696845, 0.39381126545278994, -0.11458144569683075,
             0.2915800947381822, -0.15832487965860276, 0.1369483690040619, -0.7025255535779401,
             -0.52684924123905, -0.4198672446947172, -0.06123437567379799, -0.7512464259031908,
             -0.7650905649264873, -0.39969769870438354, -0.04659053308762004, -0.2678712762974846],
            [-0.8767801130673146, 0.4601717900455377, -0.05132525856913018, 0.44855110521788133,
             -0.6449745416787377, -0.27889856392603646, -0.410660305614706, 0.5594501700581663,
             0.317268176966285, 0.5867561684781535, -0.678346069837083, 0.24462363106285678,
             0.3916154422007198, 0.6844529377238694, -0.5865824496546097, -0.3857274274616773],
        ], dtype=jnp.float64).transpose()

        b = jnp.array([
            0.255752943202295, 0.19702783673282653, 0.24858735401570572, -0.1252846548840331,
            0.2815230793317552, 0.13979702661104412, 0.053620718304894587, 0.0033978894001451534,
            0.3825015948978667, 0.03821449844202758, 0.09008965756403677, 0.19467940108203366,
            0.31294371207869265, 0.18275503325976908, 0.571428363835027, 0.13578580279418276,
        ], dtype=jnp.float64)

        b1 = jnp.array([
            0.0012693931383079586, -0.01640847107615524, 0.030536316484409697, 0.13395179363202533,
            -0.028180473694846275, 0.03442145821752985, 0.12747046885575897, -0.06768319541892545,
            0.1605257693365629, 0.13533733885832608, 0.06484267197449477, 0.06299802475582032,
            -0.04502310622483803, -0.06214246143576937, -0.09404026146645263, 0.042979795410095216,
        ], dtype=jnp.float64)

        b2 = jnp.array([
            0.12204247764854607, 0.08112763087584626, -0.00402684696986431, 0.037542507461046294,
            0.0026271480327488086, -0.06454833917396394, -0.06081250143514959, 0.11725409888413155,
            0.04075150077076115, 0.033639503573095274, -0.012716427102788417, 0.07539504190485471,
            0.07803680531271559, 0.01777794957869979, -0.049570569756576074, -0.0606558939078927,
        ], dtype=jnp.float64)

        b3 = jnp.array([
            -0.03701019424088373, 0.03701019424088411,
        ], dtype=jnp.float64)

        # Keys follow Haiku's automatic naming of the four hk.Linear modules.
        params = {
            'linear': {'w': w, 'b': b,},
            'linear_1': {'w': w1, 'b': b1,},
            'linear_2': {'w': w2, 'b': b2,},
            'linear_3': {'w': w3, 'b': b3,},
        }
        self.params = hk.data_structures.to_immutable_dict(params)

        # NOTE(review): Haiku's documentation discourages raw JAX transforms
        # (jax.jit) inside hk.transform. It is harmless here only because the
        # parameters are fixed constants; consider jitting ``net.apply``
        # instead.
        @jax.jit
        def net_fn(x):
            mlp = hk.Sequential([
                hk.Linear(16), jax.nn.swish,
                hk.Linear(16), jax.nn.swish,
                hk.Linear(16), jax.nn.swish,
                hk.Linear(2), jax.nn.softmax,
            ])
            return mlp(x)

        self.net = hk.without_apply_rng(hk.transform(net_fn))