jinns 1.5.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. jinns/__init__.py +7 -7
  2. jinns/data/_AbstractDataGenerator.py +1 -1
  3. jinns/data/_Batchs.py +47 -13
  4. jinns/data/_CubicMeshPDENonStatio.py +203 -54
  5. jinns/data/_CubicMeshPDEStatio.py +190 -54
  6. jinns/data/_DataGeneratorODE.py +48 -22
  7. jinns/data/_DataGeneratorObservations.py +75 -32
  8. jinns/data/_DataGeneratorParameter.py +152 -101
  9. jinns/data/__init__.py +2 -1
  10. jinns/data/_utils.py +22 -10
  11. jinns/loss/_DynamicLoss.py +21 -20
  12. jinns/loss/_DynamicLossAbstract.py +51 -36
  13. jinns/loss/_LossODE.py +210 -191
  14. jinns/loss/_LossPDE.py +441 -368
  15. jinns/loss/_abstract_loss.py +60 -25
  16. jinns/loss/_loss_components.py +4 -25
  17. jinns/loss/_loss_utils.py +23 -0
  18. jinns/loss/_loss_weight_updates.py +6 -7
  19. jinns/loss/_loss_weights.py +34 -35
  20. jinns/nn/_abstract_pinn.py +0 -2
  21. jinns/nn/_hyperpinn.py +34 -23
  22. jinns/nn/_mlp.py +5 -4
  23. jinns/nn/_pinn.py +1 -16
  24. jinns/nn/_ppinn.py +5 -16
  25. jinns/nn/_save_load.py +11 -4
  26. jinns/nn/_spinn.py +1 -16
  27. jinns/nn/_spinn_mlp.py +5 -5
  28. jinns/nn/_utils.py +33 -38
  29. jinns/parameters/__init__.py +3 -1
  30. jinns/parameters/_derivative_keys.py +99 -41
  31. jinns/parameters/_params.py +58 -25
  32. jinns/solver/_solve.py +14 -8
  33. jinns/utils/_DictToModuleMeta.py +66 -0
  34. jinns/utils/_ItemizableModule.py +19 -0
  35. jinns/utils/__init__.py +2 -1
  36. jinns/utils/_types.py +25 -15
  37. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
  38. jinns-1.6.0.dist-info/RECORD +57 -0
  39. jinns-1.5.0.dist-info/RECORD +0 -55
  40. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
  41. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
  42. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
  43. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
@@ -85,9 +85,9 @@ class FisherKPP(PDENonStatio):
85
85
  lap = laplacian_rev(t_x, u, params)[..., None]
86
86
 
87
87
  return du_dt + self.Tmax * (
88
- -params.eq_params["D"] * lap
88
+ -params.eq_params.D * lap
89
89
  - u(t_x, params)
90
- * (params.eq_params["r"] - params.eq_params["g"] * u(t_x, params))
90
+ * (params.eq_params.r - params.eq_params.g * u(t_x, params))
91
91
  )
92
92
  if isinstance(u, SPINN):
93
93
  s = jnp.zeros((1, self.dim_x + 1))
@@ -101,8 +101,8 @@ class FisherKPP(PDENonStatio):
101
101
  lap = laplacian_fwd(t_x, u, params)
102
102
 
103
103
  return du_dt + self.Tmax * (
104
- -params.eq_params["D"] * lap
105
- - u_tx * (params.eq_params["r"] - params.eq_params["g"] * u_tx)
104
+ -params.eq_params.D * lap
105
+ - u_tx * (params.eq_params.r - params.eq_params.g * u_tx)
106
106
  )
107
107
  raise ValueError("u is not among the recognized types (PINN or SPINN)")
108
108
 
@@ -164,17 +164,17 @@ class GeneralizedLotkaVolterra(ODE):
164
164
  The parameters in a Params object
165
165
  """
166
166
  du_dt = jax.jacrev(lambda t: jnp.log(u(t, params)))(t)
167
- carrying_term = params.eq_params["carrying_capacities"] * jnp.sum(u(t, params))
167
+ carrying_term = params.eq_params.carrying_capacities * jnp.sum(u(t, params))
168
168
  interactions_terms = jax.tree.map(
169
169
  lambda interactions_for_i: jnp.sum(
170
170
  interactions_for_i * u(t, params).squeeze()
171
171
  ),
172
- params.eq_params["interactions"],
172
+ params.eq_params.interactions,
173
173
  is_leaf=eqx.is_array,
174
174
  )
175
175
  interactions_terms = jnp.array([*(interactions_terms)])
176
176
  return du_dt.squeeze() + self.Tmax * (
177
- -params.eq_params["growth_rates"] + interactions_terms + carrying_term
177
+ -params.eq_params.growth_rates + interactions_terms + carrying_term
178
178
  )
179
179
 
180
180
 
@@ -229,7 +229,7 @@ class BurgersEquation(PDENonStatio):
229
229
 
230
230
  return du_dtx_values[0:1] + self.Tmax * (
231
231
  u_(t_x) * du_dtx_values[1:2]
232
- - params.eq_params["nu"] * d2u_dx_dtx(t_x)[1:2]
232
+ - params.eq_params.nu * d2u_dx_dtx(t_x)[1:2]
233
233
  )
234
234
 
235
235
  if isinstance(u, SPINN):
@@ -258,7 +258,7 @@ class BurgersEquation(PDENonStatio):
258
258
  )[1]
259
259
  _, d2u_dx2 = jax.jvp(du_dx_fun, (t_x,), (v1,))
260
260
  # Note that ones_like(x) works because x is Bx1 !
261
- return du_dt + self.Tmax * (u_tx * du_dx - params.eq_params["nu"] * d2u_dx2)
261
+ return du_dt + self.Tmax * (u_tx * du_dx - params.eq_params.nu * d2u_dx2)
262
262
  raise ValueError("u is not among the recognized types (PINN or SPINN)")
263
263
 
264
264
 
@@ -458,7 +458,7 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
458
458
  eq_params
459
459
  A dictionary containing the equation parameters
460
460
  """
461
- return eq_params["alpha"] * (eq_params["mu"] - x)
461
+ return eq_params.alpha * (eq_params.mu - x)
462
462
 
463
463
  def sigma_mat(self, x, eq_params):
464
464
  r"""
@@ -473,7 +473,7 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
473
473
  A dictionary containing the equation parameters
474
474
  """
475
475
 
476
- return jnp.diag(eq_params["sigma"])
476
+ return jnp.diag(eq_params.sigma)
477
477
 
478
478
  def diffusion(self, x, eq_params, i=None, j=None):
479
479
  r"""
@@ -587,15 +587,15 @@ class NavierStokesMassConservation2DStatio(PDEStatio):
587
587
  # dynamic loss on x axis
588
588
  result_x = (
589
589
  u_dot_nabla_x_u[0]
590
- + 1 / params.eq_params["rho"] * jac_p[0, 0]
591
- - params.eq_params["nu"] * vec_laplacian_u[0]
590
+ + 1 / params.eq_params.rho * jac_p[0, 0]
591
+ - params.eq_params.nu * vec_laplacian_u[0]
592
592
  )
593
593
 
594
594
  # dynamic loss on y axis
595
595
  result_y = (
596
596
  u_dot_nabla_x_u[1]
597
- + 1 / params.eq_params["rho"] * jac_p[0, 1]
598
- - params.eq_params["nu"] * vec_laplacian_u[1]
597
+ + 1 / params.eq_params.rho * jac_p[0, 1]
598
+ - params.eq_params.nu * vec_laplacian_u[1]
599
599
  )
600
600
 
601
601
  # MASS CONSERVATION
@@ -605,7 +605,8 @@ class NavierStokesMassConservation2DStatio(PDEStatio):
605
605
  # output is 3D
606
606
  if mc.ndim == 0 and not result_x.ndim == 0:
607
607
  mc = mc[None]
608
- return jnp.stack([result_x, result_y, mc], axis=-1)
608
+
609
+ return jnp.stack([result_x, result_y, mc], axis=-1).squeeze()
609
610
 
610
611
  if isinstance(u_p, SPINN):
611
612
  u = lambda x, params: u_p(x, params)[..., 0:2]
@@ -626,14 +627,14 @@ class NavierStokesMassConservation2DStatio(PDEStatio):
626
627
  # dynamic loss on x axis
627
628
  result_x = (
628
629
  u_dot_nabla_x_u[..., 0]
629
- + 1 / params.eq_params["rho"] * dp_dx.squeeze()
630
- - params.eq_params["nu"] * vec_laplacian_u[..., 0]
630
+ + 1 / params.eq_params.rho * dp_dx.squeeze()
631
+ - params.eq_params.nu * vec_laplacian_u[..., 0]
631
632
  )
632
633
  # dynamic loss on y axis
633
634
  result_y = (
634
635
  u_dot_nabla_x_u[..., 1]
635
- + 1 / params.eq_params["rho"] * dp_dy.squeeze()
636
- - params.eq_params["nu"] * vec_laplacian_u[..., 1]
636
+ + 1 / params.eq_params.rho * dp_dy.squeeze()
637
+ - params.eq_params.nu * vec_laplacian_u[..., 1]
637
638
  )
638
639
 
639
640
  # MASS CONSERVATION
@@ -9,10 +9,13 @@ from __future__ import (
9
9
  import warnings
10
10
  import abc
11
11
  from functools import partial
12
- from typing import Callable, TYPE_CHECKING, ClassVar, Generic, TypeVar
12
+ from dataclasses import InitVar
13
+ from typing import Callable, TYPE_CHECKING, ClassVar, Generic, TypeVar, Any
13
14
  import equinox as eqx
14
- from jaxtyping import Float, Array
15
+ from jaxtyping import Float, Array, PyTree
16
+ import jax
15
17
  import jax.numpy as jnp
18
+ from jinns.parameters._params import EqParams
16
19
 
17
20
 
18
21
  # See : https://docs.kidger.site/equinox/api/module/advanced_fields/#equinox.AbstractClassVar--known-issues
@@ -59,7 +62,7 @@ class DynamicLoss(eqx.Module, Generic[InputDim]):
59
62
  Tmax needs to be given when the PINN time input is normalized in
60
63
  [0, 1], ie. we have performed renormalization of the differential
61
64
  equation
62
- eq_params_heterogeneity : dict[str, Callable | None], default=None
65
+ eq_params_heterogeneity : dict[str, Callable[[InputDim, AbstractPINN, Params[Array]], Array] | None], default=None
63
66
  A dict with the same keys as eq_params and the value being either None
64
67
  (no heterogeneity) or a function which encodes for the spatio-temporal
65
68
  heterogeneity of the parameter.
@@ -71,7 +74,10 @@ class DynamicLoss(eqx.Module, Generic[InputDim]):
71
74
  `eq_params` too.
72
75
  A value can be missing, in this case there is no heterogeneity (=None).
73
76
  Default None, meaning there is no heterogeneity in the equation
74
- parameters.
77
+ parameters. Note that since 1.6.0, this is handled internally as a
78
+ `PyTree[Callable[[InputDim, AbstractPINN, Params[Array]], Array] |
79
+ None] | None` (`Params.eq_params` is not a dict
80
+ anymore).
75
81
  vectorial_dyn_loss_ponderation : Float[Array, " dim"], default=None
76
82
  Add a different ponderation weight to each of the dimension to the
77
83
  dynamic loss. This array must have the same dimension as the output of
@@ -86,40 +92,49 @@ class DynamicLoss(eqx.Module, Generic[InputDim]):
86
92
 
87
93
  _eq_type = AbstractClassVar[str] # class variable denoting the type of
88
94
  # differential equation
89
- Tmax: Float = eqx.field(kw_only=True, default=1)
90
- eq_params_heterogeneity: dict[str, Callable | None] | None = eqx.field(
91
- kw_only=True, default=None, static=True
92
- )
95
+ Tmax: float = eqx.field(kw_only=True, default=1)
96
+ eq_params_heterogeneity: (
97
+ PyTree[Callable[[InputDim, AbstractPINN, Params[Array]], Array] | None] | None
98
+ ) = eqx.field(kw_only=True, default=None, static=True)
93
99
  vectorial_dyn_loss_ponderation: Float[Array, " dim"] | None = eqx.field(
94
- kw_only=True, default=None
100
+ kw_only=True, default_factory=lambda: jnp.array(1.0)
95
101
  )
96
-
97
- def __post_init__(self):
98
- if self.vectorial_dyn_loss_ponderation is None:
99
- self.vectorial_dyn_loss_ponderation = jnp.array(1.0)
102
+ params: InitVar[Params[Array]] = eqx.field(default=None)
103
+
104
+ def __post_init__(self, params: Params[Array] | None = None):
105
+ if isinstance(self.eq_params_heterogeneity, dict): # type: ignore
106
+ # we cannot use the same converter as in Params.eq_params
107
+ # we don't want to create a new type but use the same type as
108
+ # Params.eq_params which already exists.
109
+ if params is None:
110
+ raise ValueError(
111
+ "When `self.eq_params_heterogeneity` is "
112
+ "provided, `params` must be specified at init"
113
+ )
114
+ self.eq_params_heterogeneity = EqParams(
115
+ self.eq_params_heterogeneity,
116
+ "EqParams", # type: ignore
117
+ )
100
118
 
101
119
  def _eval_heterogeneous_parameters(
102
120
  self,
103
121
  inputs: InputDim,
104
122
  u: AbstractPINN,
105
123
  params: Params[Array],
106
- eq_params_heterogeneity: dict[str, Callable | None] | None = None,
107
- ) -> dict[str, Array]:
108
- eq_params_ = {}
124
+ eq_params_heterogeneity: PyTree[
125
+ Callable[[InputDim, AbstractPINN, Params[Array]], Array] | None
126
+ ]
127
+ | None = None,
128
+ ) -> PyTree[Array]:
109
129
  if eq_params_heterogeneity is None:
110
130
  return params.eq_params
111
-
112
- for k, p in params.eq_params.items():
113
- try:
114
- if eq_params_heterogeneity[k] is not None:
115
- eq_params_[k] = eq_params_heterogeneity[k](inputs, u, params) # type: ignore don't know why pyright says
116
- # eq_params_heterogeneity[k] can be None here
117
- else:
118
- eq_params_[k] = p
119
- except KeyError:
120
- # we authorize missing eq_params_heterogeneity key
121
- # if its heterogeneity is None anyway
122
- eq_params_[k] = p
131
+ eq_params_ = jax.tree.map(
132
+ lambda p, fun: ( # type: ignore
133
+ fun(inputs, u, params) if fun is not None else p
134
+ ),
135
+ params.eq_params,
136
+ eq_params_heterogeneity,
137
+ )
123
138
  return eq_params_
124
139
 
125
140
  @partial(_decorator_heteregeneous_params)
@@ -128,7 +143,7 @@ class DynamicLoss(eqx.Module, Generic[InputDim]):
128
143
  inputs: InputDim,
129
144
  u: AbstractPINN,
130
145
  params: Params[Array],
131
- ) -> float:
146
+ ) -> Float[Array, " eq_dim"]:
132
147
  evaluation = self.vectorial_dyn_loss_ponderation * self.equation(
133
148
  inputs, u, params
134
149
  )
@@ -147,7 +162,7 @@ class DynamicLoss(eqx.Module, Generic[InputDim]):
147
162
  return evaluation
148
163
 
149
164
  @abc.abstractmethod
150
- def equation(self, *args, **kwargs):
165
+ def equation(self, *args: Any, **kwargs: Any) -> Float[Array, " eq_dim"]:
151
166
  # TO IMPLEMENT
152
167
  # Point-wise evaluation of the differential equation N[u](.)
153
168
  raise NotImplementedError("You should implement your equation.")
@@ -183,7 +198,7 @@ class ODE(DynamicLoss[Float[Array, " 1"]]):
183
198
  @abc.abstractmethod
184
199
  def equation(
185
200
  self, t: Float[Array, " 1"], u: AbstractPINN, params: Params[Array]
186
- ) -> float:
201
+ ) -> Float[Array, " eq_dim"]:
187
202
  r"""
188
203
  The differential operator defining the ODE.
189
204
 
@@ -202,7 +217,7 @@ class ODE(DynamicLoss[Float[Array, " 1"]]):
202
217
 
203
218
  Returns
204
219
  -------
205
- float
220
+ Float[Array, "eq_dim"]
206
221
  The residual, *i.e.* the differential operator $\mathcal{N}_\theta[u_\nu](t)$ evaluated at point `t`.
207
222
 
208
223
  Raises
@@ -242,7 +257,7 @@ class PDEStatio(DynamicLoss[Float[Array, " dim"]]):
242
257
  @abc.abstractmethod
243
258
  def equation(
244
259
  self, x: Float[Array, " dim"], u: AbstractPINN, params: Params[Array]
245
- ) -> float:
260
+ ) -> Float[Array, " eq_dim"]:
246
261
  r"""The differential operator defining the stationary PDE.
247
262
 
248
263
  !!! warning
@@ -260,7 +275,7 @@ class PDEStatio(DynamicLoss[Float[Array, " dim"]]):
260
275
 
261
276
  Returns
262
277
  -------
263
- float
278
+ Float[Array, "eq_dim"]
264
279
  The residual, *i.e.* the differential operator $\mathcal{N}_\theta[u_\nu](x)$ evaluated at point `x`.
265
280
 
266
281
  Raises
@@ -303,7 +318,7 @@ class PDENonStatio(DynamicLoss[Float[Array, " 1 + dim"]]):
303
318
  t_x: Float[Array, " 1 + dim"],
304
319
  u: AbstractPINN,
305
320
  params: Params[Array],
306
- ) -> float:
321
+ ) -> Float[Array, " eq_dim"]:
307
322
  r"""The differential operator defining the non-stationary PDE.
308
323
 
309
324
  !!! warning
@@ -320,7 +335,7 @@ class PDENonStatio(DynamicLoss[Float[Array, " 1 + dim"]]):
320
335
  The parameters of the equation and the networks, $\theta$ and $\nu$ respectively.
321
336
  Returns
322
337
  -------
323
- float
338
+ Float[Array, "eq_dim"]
324
339
  The residual, *i.e.* the differential operator $\mathcal{N}_\theta[u_\nu](t, x)$ evaluated at point `(t, x)`.
325
340
 
326
341