jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +116 -189
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +176 -513
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +22 -21
  19. jinns/loss/_loss_utils.py +98 -173
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -76
  22. jinns/nn/__init__.py +22 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +434 -0
  25. jinns/nn/_mlp.py +217 -0
  26. jinns/nn/_pinn.py +204 -0
  27. jinns/nn/_ppinn.py +239 -0
  28. jinns/{utils → nn}/_save_load.py +39 -53
  29. jinns/nn/_spinn.py +123 -0
  30. jinns/nn/_spinn_mlp.py +202 -0
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +38 -37
  37. jinns/solver/_rar.py +82 -65
  38. jinns/solver/_solve.py +111 -71
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -5
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns/utils/_hyperpinn.py +0 -420
  51. jinns/utils/_pinn.py +0 -324
  52. jinns/utils/_ppinn.py +0 -227
  53. jinns/utils/_spinn.py +0 -249
  54. jinns-1.2.0.dist-info/RECORD +0 -41
  55. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  56. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  57. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
@@ -6,15 +6,15 @@ from __future__ import (
6
6
  annotations,
7
7
  ) # https://docs.python.org/3/library/typing.html#constant
8
8
 
9
- from typing import TYPE_CHECKING, Dict
9
+ from typing import TYPE_CHECKING
10
10
  from jaxtyping import Float
11
11
  import jax
12
12
  from jax import grad
13
13
  import jax.numpy as jnp
14
14
  import equinox as eqx
15
15
 
16
- from jinns.utils._pinn import PINN
17
- from jinns.utils._spinn import SPINN
16
+ from jinns.nn._pinn import PINN
17
+ from jinns.nn._spinn_mlp import SPINN
18
18
 
19
19
  from jinns.utils._utils import get_grid
20
20
  from jinns.loss._DynamicLossAbstract import ODE, PDEStatio, PDENonStatio
@@ -29,10 +29,11 @@ from jinns.loss._operators import (
29
29
  _u_dot_nabla_times_u_fwd,
30
30
  )
31
31
 
32
- from jaxtyping import Array, Float
32
+ from jaxtyping import Array
33
33
 
34
34
  if TYPE_CHECKING:
35
- from jinns.parameters import Params, ParamsDict
35
+ from jinns.parameters import Params
36
+ from jinns.nn._abstract_pinn import AbstractPINN
36
37
 
37
38
 
38
39
  class FisherKPP(PDENonStatio):
@@ -54,10 +55,10 @@ class FisherKPP(PDENonStatio):
54
55
 
55
56
  def equation(
56
57
  self,
57
- t_x: Float[Array, "1+dim"],
58
- u: eqx.Module,
59
- params: Params,
60
- ) -> Float[Array, "1"]:
58
+ t_x: Float[Array, " 1+dim"],
59
+ u: AbstractPINN,
60
+ params: Params[Array],
61
+ ) -> Float[Array, " 1"]:
61
62
  r"""
62
63
  Evaluate the dynamic loss at $(t, x)$.
63
64
 
@@ -74,13 +75,14 @@ class FisherKPP(PDENonStatio):
74
75
  dictionaries: `eq_params` and `nn_params`, respectively the
75
76
  differential equation parameters and the neural network parameter
76
77
  """
78
+ assert u.eq_type != "ODE", "Cannot compute the loss for ODE PINNs"
77
79
  if isinstance(u, PINN):
78
80
  # Note that the last dim of u is nec. 1
79
81
  u_ = lambda t_x: u(t_x, params)[0]
80
82
 
81
83
  du_dt = grad(u_)(t_x)[0]
82
84
 
83
- lap = laplacian_rev(t_x, u, params, eq_type=u.eq_type)[..., None]
85
+ lap = laplacian_rev(t_x, u, params)[..., None]
84
86
 
85
87
  return du_dt + self.Tmax * (
86
88
  -params.eq_params["D"] * lap
@@ -96,7 +98,7 @@ class FisherKPP(PDENonStatio):
96
98
  (t_x,),
97
99
  (v0,),
98
100
  )
99
- lap = laplacian_fwd(t_x, u, params, eq_type=u.eq_type)
101
+ lap = laplacian_fwd(t_x, u, params)
100
102
 
101
103
  return du_dt + self.Tmax * (
102
104
  -params.eq_params["D"] * lap
@@ -108,14 +110,14 @@ class FisherKPP(PDENonStatio):
108
110
  class GeneralizedLotkaVolterra(ODE):
109
111
  r"""
110
112
  Return a dynamic loss from an equation of a Generalized Lotka Volterra
111
- system. Say we implement the equation for population $i$
113
+ system. Say we implement the system of equations, for several populations $i$
112
114
 
113
115
  $$
114
- \frac{\partial}{\partial t}u_i(t) = r_iu_i(t) - \sum_{j\neq i}\alpha_{ij}u_j(t)
115
- -\alpha_{i,i}u_i(t) + c_iu_i(t) + \sum_{j \neq i} c_ju_j(t)
116
+ \frac{\partial}{\partial t}u_i(t) = \alpha_iu_i(t)
117
+ -\sum_j\gamma_{j,i}u_j(t) - \beta_i\sum_{i'}u_{i'}(t), i\in\{1, 2, 3\}
116
118
  $$
117
- with $r_i$ the growth rate parameter, $c_i$ the carrying
118
- capacities and $\alpha_{ij}$ the interaction terms.
119
+
120
+ where $\alpha$ are the growth rates, $\gamma$ are the interaction terms and $\beta$ are the capacity terms.
119
121
 
120
122
  Parameters
121
123
  ----------
@@ -142,16 +144,12 @@ class GeneralizedLotkaVolterra(ODE):
142
144
  heterogeneity for no parameters.
143
145
  """
144
146
 
145
- # they should be static because they are list of strings
146
- key_main: list[str] = eqx.field(static=True)
147
- keys_other: list[str] = eqx.field(static=True)
148
-
149
147
  def equation(
150
148
  self,
151
- t: Float[Array, "1"],
152
- u_dict: Dict[str, eqx.Module],
153
- params_dict: ParamsDict,
154
- ) -> Float[Array, "1"]:
149
+ t: Float[Array, " 1"],
150
+ u: AbstractPINN,
151
+ params: Params[Array],
152
+ ) -> Float[Array, " 1"]:
155
153
  """
156
154
  Evaluate the dynamic loss at `t`.
157
155
  For stability we implement the dynamic loss in log space.
@@ -160,33 +158,23 @@ class GeneralizedLotkaVolterra(ODE):
160
158
  ---------
161
159
  t
162
160
  A time point
163
- u_dict
164
- A dictionary of PINNS. Must have the same keys as `params_dict`
165
- params_dict
166
- The dictionary of dictionaries of parameters of the model. Keys at
167
- top level are "nn_params" and "eq_params"
161
+ u
162
+ A vectorial PINN with as many outputs as there are populations
163
+ params
164
+ The parameters in a Params object
168
165
  """
169
- params_main = params_dict.extract_params(self.key_main)
170
-
171
- u = u_dict[self.key_main]
172
- # need to index with [0] since u output is nec (1,)
173
- du_dt = grad(lambda t: jnp.log(u(t, params_main)[0]), 0)(t)
174
- carrying_term = params_main.eq_params["carrying_capacity"] * u(t, params_main)
175
- # NOTE the following assumes interaction term with oneself is at idx 0
176
- interaction_terms = params_main.eq_params["interactions"][0] * u(t, params_main)
177
-
178
- # TODO write this for loop with tree_util functions?
179
- for i, k in enumerate(self.keys_other):
180
- params_k = params_dict.extract_params(k)
181
- carrying_term += params_main.eq_params["carrying_capacity"] * u_dict[k](
182
- t, params_k
183
- )
184
- interaction_terms += params_main.eq_params["interactions"][i + 1] * u_dict[
185
- k
186
- ](t, params_k)
187
-
188
- return du_dt + self.Tmax * (
189
- -params_main.eq_params["growth_rate"] - interaction_terms + carrying_term
166
+ du_dt = jax.jacrev(lambda t: jnp.log(u(t, params)))(t)
167
+ carrying_term = params.eq_params["carrying_capacities"] * jnp.sum(u(t, params))
168
+ interactions_terms = jax.tree.map(
169
+ lambda interactions_for_i: jnp.sum(
170
+ interactions_for_i * u(t, params).squeeze()
171
+ ),
172
+ params.eq_params["interactions"],
173
+ is_leaf=eqx.is_array,
174
+ )
175
+ interactions_terms = jnp.array([*(interactions_terms)])
176
+ return du_dt.squeeze() + self.Tmax * (
177
+ -params.eq_params["growth_rates"] + interactions_terms + carrying_term
190
178
  )
191
179
 
192
180
 
@@ -216,10 +204,10 @@ class BurgersEquation(PDENonStatio):
216
204
 
217
205
  def equation(
218
206
  self,
219
- t_x: Float[Array, "1+dim"],
220
- u: eqx.Module,
221
- params: Params,
222
- ) -> Float[Array, "1"]:
207
+ t_x: Float[Array, " 1+dim"],
208
+ u: AbstractPINN,
209
+ params: Params[Array],
210
+ ) -> Float[Array, " 1"]:
223
211
  r"""
224
212
  Evaluate the dynamic loss at :math:`(t,x)`.
225
213
 
@@ -312,10 +300,10 @@ class FPENonStatioLoss2D(PDENonStatio):
312
300
 
313
301
  def equation(
314
302
  self,
315
- t_x: Float[Array, "1+dim"],
316
- u: eqx.Module,
317
- params: Params,
318
- ) -> Float[Array, "1"]:
303
+ t_x: Float[Array, " 1+dim"],
304
+ u: AbstractPINN,
305
+ params: Params[Array],
306
+ ) -> Float[Array, " 1"]:
319
307
  r"""
320
308
  Evaluate the dynamic loss at $(t,\mathbf{x})$.
321
309
 
@@ -506,94 +494,36 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
506
494
  jnp.transpose(self.sigma_mat(x, eq_params)),
507
495
  )
508
496
  )
509
- return 0.5 * (
510
- jnp.matmul(
511
- self.sigma_mat(x, eq_params),
512
- jnp.transpose(self.sigma_mat(x, eq_params)),
513
- )[i, j]
497
+ return (
498
+ 0.5
499
+ * (
500
+ jnp.matmul(
501
+ self.sigma_mat(x, eq_params),
502
+ jnp.transpose(self.sigma_mat(x, eq_params)),
503
+ )[i, j]
504
+ )
514
505
  )
515
506
 
516
507
 
517
- class MassConservation2DStatio(PDEStatio):
508
+ class NavierStokesMassConservation2DStatio(PDEStatio):
518
509
  r"""
519
- Returns the so-called mass conservation equation.
510
+ Dynamic loss for the system of equations (stationary Navier Stokes in 2D,
511
+ mass conservation). For a 2D velocity field $\mathbf{u}$ and a 1D pressure
512
+ field $p$, this reads:
520
513
 
521
514
  $$
522
- \nabla \cdot \mathbf{u} = \frac{\partial}{\partial x}u(x,y) +
523
- \frac{\partial}{\partial y}u(x,y) = 0,
524
- $$
525
- where $u$ is a stationary function, i.e., it does not depend on
526
- $t$.
527
-
528
- Parameters
529
- ----------
530
- nn_key
531
- A dictionary key which identifies, in `u_dict` the PINN that
532
- appears in the mass conservation equation.
533
- eq_params_heterogeneity
534
- Default None. A dict with the keys being the same as in eq_params
535
- and the value being `time`, `space`, `both` or None which corresponds to
536
- the heterogeneity of a given parameter. A value can be missing, in
537
- this case there is no heterogeneity (=None). If
538
- eq_params_heterogeneity is None this means there is no
539
- heterogeneity for no parameters.
540
- """
541
-
542
- # an str field should be static (not a valid JAX type)
543
- nn_key: str = eqx.field(static=True)
544
-
545
- def equation(
546
- self,
547
- x: Float[Array, "dim"],
548
- u_dict: Dict[str, eqx.Module],
549
- params_dict: ParamsDict,
550
- ) -> Float[Array, "1"]:
551
- r"""
552
- Evaluate the dynamic loss at `\mathbf{x}`.
553
- For stability we implement the dynamic loss in log space.
554
-
555
- Parameters
556
- ---------
557
- x
558
- A point in $\Omega\subset\mathbb{R}^2$
559
- u_dict
560
- A dictionary of PINNs. Must have the same keys as `params_dict`
561
- params_dict
562
- The dictionary of dictionaries of parameters of the model.
563
- Typically, each sub-dictionary is a dictionary
564
- with keys: `eq_params` and `nn_params`, respectively the
565
- differential equation parameters and the neural network parameter.
566
- Must have the same keys as `u_dict`
567
- """
568
- params = params_dict.extract_params(self.nn_key)
569
-
570
- if isinstance(u_dict[self.nn_key], PINN):
571
- u = u_dict[self.nn_key]
572
-
573
- return divergence_rev(x, u, params)[..., None]
574
-
575
- if isinstance(u_dict[self.nn_key], SPINN):
576
- u = u_dict[self.nn_key]
577
-
578
- return divergence_fwd(x, u, params)[..., None]
579
- raise ValueError("u is not among the recognized types (PINN or SPINN)")
580
-
581
-
582
- class NavierStokes2DStatio(PDEStatio):
583
- r"""
584
- Return the dynamic loss for all the components of the stationary Navier Stokes
585
- equation which is a 2D vectorial PDE.
586
-
587
- $$
588
- (\mathbf{u}\cdot\nabla)\mathbf{u} + \frac{1}{\rho}\nabla p - \theta
589
- \nabla^2\mathbf{u}=0,
515
+ \begin{cases}
516
+ &(\mathbf{u}\cdot\nabla)\mathbf{u} + \frac{1}{\rho}\nabla p - \theta
517
+ \nabla^2\mathbf{u}=0,\\
518
+ &\nabla\cdot\mathbf{u}=0,
519
+ \end{cases}
590
520
  $$
591
521
 
592
522
  or, in 2D,
593
523
 
594
-
595
524
  $$
596
- \begin{pmatrix}u_x\frac{\partial}{\partial x} u_x +
525
+ \begin{cases}
526
+ &\begin{pmatrix}u_x\frac{\partial}{\partial x} u_x +
597
527
  u_y\frac{\partial}{\partial y} u_x, \\
598
528
  u_x\frac{\partial}{\partial x} u_y + u_y\frac{\partial}{\partial y} u_y \end{pmatrix} +
599
529
  \frac{1}{\rho} \begin{pmatrix} \frac{\partial}{\partial x} p, \\ \frac{\partial}{\partial y} p \end{pmatrix}
@@ -602,24 +532,15 @@ class NavierStokes2DStatio(PDEStatio):
602
532
  \frac{\partial^2}{\partial x^2} u_x + \frac{\partial^2}{\partial y^2}
603
533
  u_x, \\
604
534
  \frac{\partial^2}{\partial x^2} u_y + \frac{\partial^2}{\partial y^2} u_y
605
- \end{pmatrix} = 0,
535
+ \end{pmatrix} = 0,\\
536
+ &\frac{\partial}{\partial x}u(x,y) +
537
+ \frac{\partial}{\partial y}u(x,y) = 0,
538
+ \end{cases}
606
539
  $$
607
540
  with $\theta$ the viscosity coefficient and $\rho$ the density coefficient.
608
541
 
609
542
  Parameters
610
543
  ----------
611
- u_key
612
- A dictionary key which indices the NN u in `u_dict`
613
- the PINN with the role of the velocity in the equation.
614
- Its input is bimensional (points in $\Omega\subset\mathbb{R}^2$).
615
- Its output is bimensional as it represents a velocity vector
616
- field
617
- p_key
618
- A dictionary key which indices the NN p in `u_dict`
619
- the PINN with the role of the pressure in the equation.
620
- Its input is bimensional (points in $\Omega\subset\mathbb{R}^2).
621
- Its output is unidimensional as it represents a pressure scalar
622
- field
623
544
  eq_params_heterogeneity
624
545
  Default None. A dict with the keys being the same as in eq_params
625
546
  and the value being `time`, `space`, `both` or None which corresponds to
@@ -629,15 +550,12 @@ class NavierStokes2DStatio(PDEStatio):
629
550
  heterogeneity for no parameters.
630
551
  """
631
552
 
632
- u_key: str = eqx.field(static=True)
633
- p_key: str = eqx.field(static=True)
634
-
635
553
  def equation(
636
554
  self,
637
- x: Float[Array, "dim"],
638
- u_dict: Dict[str, eqx.Module],
639
- params_dict: ParamsDict,
640
- ) -> Float[Array, "1"]:
555
+ x: Float[Array, " dim"],
556
+ u_p: AbstractPINN,
557
+ params: Params[Array],
558
+ ) -> Float[Array, " 3"]:
641
559
  r"""
642
560
  Evaluate the dynamic loss at `x`.
643
561
  For stability we implement the dynamic loss in log space.
@@ -646,72 +564,81 @@ class NavierStokes2DStatio(PDEStatio):
646
564
  ---------
647
565
  x
648
566
  A point in $\Omega\subset\mathbb{R}^2$
649
- u_dict
650
- A dictionary of PINNs. Must have the same keys as `params_dict`
651
- params_dict
652
- The dictionary of dictionaries of parameters of the model.
653
- Typically, each sub-dictionary is a dictionary
654
- with keys: `eq_params` and `nn_params`, respectively the
655
- differential equation parameters and the neural network parameter.
656
- Must have the same keys as `u_dict`
567
+ u_p
568
+ A PINN with 3 outputs of the 2D velocity field and the 1D pressure
569
+ field
570
+ params
571
+ The parameters in a Params object
657
572
  """
658
- u_params = params_dict.extract_params(self.u_key)
659
- p_params = params_dict.extract_params(self.p_key)
660
-
661
- if isinstance(u_dict[self.u_key], PINN):
662
- u = u_dict[self.u_key]
573
+ assert u_p.eq_type != "ODE", "Cannot compute the loss for ODE PINNs"
574
+ if isinstance(u_p, PINN):
575
+ u = lambda x, params: u_p(x, params)[0:2]
576
+ p = lambda x, params: u_p(x, params)[2:3]
663
577
 
664
- u_dot_nabla_x_u = _u_dot_nabla_times_u_rev(x, u, u_params)
578
+ # NAVIER STOKES
579
+ u_dot_nabla_x_u = _u_dot_nabla_times_u_rev(x, u, params)
665
580
 
666
- p = lambda x: u_dict[self.p_key](x, p_params)
667
- jac_p = jax.jacrev(p, 0)(x) # compute the gradient
581
+ jac_p = jax.jacrev(p, 0)(x, params) # compute the gradient
668
582
 
669
- vec_laplacian_u = vectorial_laplacian_rev(x, u, u_params, dim_out=2)
583
+ vec_laplacian_u = vectorial_laplacian_rev(
584
+ x, u, params, dim_out=2, eq_type=u_p.eq_type
585
+ )
670
586
 
671
587
  # dynamic loss on x axis
672
588
  result_x = (
673
589
  u_dot_nabla_x_u[0]
674
- + 1 / params_dict.eq_params["rho"] * jac_p[0, 0]
675
- - params_dict.eq_params["nu"] * vec_laplacian_u[0]
590
+ + 1 / params.eq_params["rho"] * jac_p[0, 0]
591
+ - params.eq_params["nu"] * vec_laplacian_u[0]
676
592
  )
677
593
 
678
594
  # dynamic loss on y axis
679
595
  result_y = (
680
596
  u_dot_nabla_x_u[1]
681
- + 1 / params_dict.eq_params["rho"] * jac_p[0, 1]
682
- - params_dict.eq_params["nu"] * vec_laplacian_u[1]
597
+ + 1 / params.eq_params["rho"] * jac_p[0, 1]
598
+ - params.eq_params["nu"] * vec_laplacian_u[1]
683
599
  )
684
600
 
685
- # output is 2D
686
- return jnp.stack([result_x, result_y], axis=-1)
601
+ # MASS CONSERVATION
687
602
 
688
- if isinstance(u_dict[self.u_key], SPINN):
689
- u = u_dict[self.u_key]
603
+ mc = divergence_rev(x, u, params, eq_type=u_p.eq_type)
690
604
 
691
- u_dot_nabla_x_u = _u_dot_nabla_times_u_fwd(x, u, u_params)
605
+ # output is 3D
606
+ if mc.ndim == 0 and not result_x.ndim == 0:
607
+ mc = mc[None]
608
+ return jnp.stack([result_x, result_y, mc], axis=-1)
692
609
 
693
- p = lambda x: u_dict[self.p_key](x, p_params)
610
+ if isinstance(u_p, SPINN):
611
+ u = lambda x, params: u_p(x, params)[..., 0:2]
612
+ p = lambda x: u_p(x, params)[..., 2:3]
613
+
614
+ # NAVIER STOKES
615
+ u_dot_nabla_x_u = _u_dot_nabla_times_u_fwd(x, u, params)
694
616
 
695
617
  tangent_vec_0 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
696
618
  _, dp_dx = jax.jvp(p, (x,), (tangent_vec_0,))
697
619
  tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
698
620
  _, dp_dy = jax.jvp(p, (x,), (tangent_vec_1,))
699
621
 
700
- vec_laplacian_u = vectorial_laplacian_fwd(x, u, u_params, dim_out=2)
622
+ vec_laplacian_u = vectorial_laplacian_fwd(
623
+ x, u, params, dim_out=2, eq_type=u_p.eq_type
624
+ )
701
625
 
702
626
  # dynamic loss on x axis
703
627
  result_x = (
704
628
  u_dot_nabla_x_u[..., 0]
705
- + 1 / params_dict.eq_params["rho"] * dp_dx.squeeze()
706
- - params_dict.eq_params["nu"] * vec_laplacian_u[..., 0]
629
+ + 1 / params.eq_params["rho"] * dp_dx.squeeze()
630
+ - params.eq_params["nu"] * vec_laplacian_u[..., 0]
707
631
  )
708
632
  # dynamic loss on y axis
709
633
  result_y = (
710
634
  u_dot_nabla_x_u[..., 1]
711
- + 1 / params_dict.eq_params["rho"] * dp_dy.squeeze()
712
- - params_dict.eq_params["nu"] * vec_laplacian_u[..., 1]
635
+ + 1 / params.eq_params["rho"] * dp_dy.squeeze()
636
+ - params.eq_params["nu"] * vec_laplacian_u[..., 1]
713
637
  )
714
638
 
715
- # output is 2D
716
- return jnp.stack([result_x, result_y], axis=-1)
639
+ # MASS CONSERVATION
640
+ mc = divergence_fwd(x, u, params, eq_type=u_p.eq_type)[..., None]
641
+
642
+ # output is (..., 3)
643
+ return jnp.stack([result_x[..., None], result_y[..., None], mc], axis=-1)
717
644
  raise ValueError("u is not among the recognized types (PINN or SPINN)")