jinns 1.5.0-py3-none-any.whl → 1.6.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +7 -7
- jinns/data/_AbstractDataGenerator.py +1 -1
- jinns/data/_Batchs.py +47 -13
- jinns/data/_CubicMeshPDENonStatio.py +203 -54
- jinns/data/_CubicMeshPDEStatio.py +190 -54
- jinns/data/_DataGeneratorODE.py +48 -22
- jinns/data/_DataGeneratorObservations.py +75 -32
- jinns/data/_DataGeneratorParameter.py +152 -101
- jinns/data/__init__.py +2 -1
- jinns/data/_utils.py +22 -10
- jinns/loss/_DynamicLoss.py +21 -20
- jinns/loss/_DynamicLossAbstract.py +51 -36
- jinns/loss/_LossODE.py +210 -191
- jinns/loss/_LossPDE.py +441 -368
- jinns/loss/_abstract_loss.py +60 -25
- jinns/loss/_loss_components.py +4 -25
- jinns/loss/_loss_utils.py +23 -0
- jinns/loss/_loss_weight_updates.py +6 -7
- jinns/loss/_loss_weights.py +34 -35
- jinns/nn/_abstract_pinn.py +0 -2
- jinns/nn/_hyperpinn.py +34 -23
- jinns/nn/_mlp.py +5 -4
- jinns/nn/_pinn.py +1 -16
- jinns/nn/_ppinn.py +5 -16
- jinns/nn/_save_load.py +11 -4
- jinns/nn/_spinn.py +1 -16
- jinns/nn/_spinn_mlp.py +5 -5
- jinns/nn/_utils.py +33 -38
- jinns/parameters/__init__.py +3 -1
- jinns/parameters/_derivative_keys.py +99 -41
- jinns/parameters/_params.py +58 -25
- jinns/solver/_solve.py +14 -8
- jinns/utils/_DictToModuleMeta.py +66 -0
- jinns/utils/_ItemizableModule.py +19 -0
- jinns/utils/__init__.py +2 -1
- jinns/utils/_types.py +25 -15
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
- jinns-1.6.0.dist-info/RECORD +57 -0
- jinns-1.5.0.dist-info/RECORD +0 -55
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
jinns/loss/_DynamicLoss.py
CHANGED
@@ -85,9 +85,9 @@ class FisherKPP(PDENonStatio):
             lap = laplacian_rev(t_x, u, params)[..., None]
 
             return du_dt + self.Tmax * (
-                -params.eq_params["D"] * lap
+                -params.eq_params.D * lap
                 - u(t_x, params)
-                * (params.eq_params["r"] - params.eq_params["g"] * u(t_x, params))
+                * (params.eq_params.r - params.eq_params.g * u(t_x, params))
             )
         if isinstance(u, SPINN):
             s = jnp.zeros((1, self.dim_x + 1))
@@ -101,8 +101,8 @@ class FisherKPP(PDENonStatio):
             lap = laplacian_fwd(t_x, u, params)
 
             return du_dt + self.Tmax * (
-                -params.eq_params["D"] * lap
-                - u_tx * (params.eq_params["r"] - params.eq_params["g"] * u_tx)
+                -params.eq_params.D * lap
+                - u_tx * (params.eq_params.r - params.eq_params.g * u_tx)
             )
         raise ValueError("u is not among the recognized types (PINN or SPINN)")
 
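The recurring change in this file is the switch from dict-style to attribute-style access on `eq_params`. A minimal before/after sketch of user-side code (the parameter values and the `nn_params=None` placeholder are hypothetical; this assumes `Params` still accepts a plain dict for `eq_params` and converts it via the `EqParams` machinery referenced later in this diff):

```python
import jax.numpy as jnp
from jinns.parameters import Params

params = Params(
    nn_params=None,  # hypothetical: a real setup passes the network parameters
    eq_params={"D": jnp.array(1.0), "r": jnp.array(2.0), "g": jnp.array(3.0)},
)

# jinns 1.5.0: eq_params was a plain dict
# D = params.eq_params["D"]

# jinns 1.6.0: eq_params is a module-like PyTree with one attribute per key
D = params.eq_params.D
```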
@@ -164,17 +164,17 @@ class GeneralizedLotkaVolterra(ODE):
             The parameters in a Params object
         """
         du_dt = jax.jacrev(lambda t: jnp.log(u(t, params)))(t)
-        carrying_term = params.eq_params["carrying_capacities"] * jnp.sum(u(t, params))
+        carrying_term = params.eq_params.carrying_capacities * jnp.sum(u(t, params))
         interactions_terms = jax.tree.map(
             lambda interactions_for_i: jnp.sum(
                 interactions_for_i * u(t, params).squeeze()
             ),
-            params.eq_params["interactions"],
+            params.eq_params.interactions,
             is_leaf=eqx.is_array,
         )
         interactions_terms = jnp.array([*(interactions_terms)])
         return du_dt.squeeze() + self.Tmax * (
-            -params.eq_params["growth_rates"] + interactions_terms + carrying_term
+            -params.eq_params.growth_rates + interactions_terms + carrying_term
         )
 
 
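For reference, one consistent reading of the residual assembled above, with $r$ = `growth_rates`, $A$ = `interactions`, $c$ = `carrying_capacities` (the log-derivative form follows from the `jnp.log` inside `du_dt`):

```latex
% Setting the residual to zero recovers, up to the Tmax time rescaling:
\frac{\mathrm{d}\,\log u_i}{\mathrm{d}t}
  = r_i - \sum_j A_{ij}\, u_j - c_i \sum_j u_j
```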
@@ -229,7 +229,7 @@ class BurgersEquation(PDENonStatio):
 
             return du_dtx_values[0:1] + self.Tmax * (
                 u_(t_x) * du_dtx_values[1:2]
-                - params.eq_params["nu"] * d2u_dx_dtx(t_x)[1:2]
+                - params.eq_params.nu * d2u_dx_dtx(t_x)[1:2]
             )
 
         if isinstance(u, SPINN):
@@ -258,7 +258,7 @@ class BurgersEquation(PDENonStatio):
             )[1]
             _, d2u_dx2 = jax.jvp(du_dx_fun, (t_x,), (v1,))
             # Note that ones_like(x) works because x is Bx1 !
-            return du_dt + self.Tmax * (u_tx * du_dx - params.eq_params["nu"] * d2u_dx2)
+            return du_dt + self.Tmax * (u_tx * du_dx - params.eq_params.nu * d2u_dx2)
         raise ValueError("u is not among the recognized types (PINN or SPINN)")
 
 
@@ -458,7 +458,7 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
         eq_params
             A dictionary containing the equation parameters
         """
-        return eq_params["alpha"] * (eq_params["mu"] - x)
+        return eq_params.alpha * (eq_params.mu - x)
 
     def sigma_mat(self, x, eq_params):
         r"""
@@ -473,7 +473,7 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
             A dictionary containing the equation parameters
         """
 
-        return jnp.diag(eq_params["sigma"])
+        return jnp.diag(eq_params.sigma)
 
     def diffusion(self, x, eq_params, i=None, j=None):
         r"""
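`drift` and `sigma_mat` above encode the standard 2D Ornstein-Uhlenbeck dynamics underlying this Fokker-Planck loss:

```latex
% Ornstein-Uhlenbeck SDE with drift alpha * (mu - x) and diagonal diffusion:
\mathrm{d}X_t = \alpha \,(\mu - X_t)\,\mathrm{d}t
  + \operatorname{diag}(\sigma)\,\mathrm{d}W_t
```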
@@ -587,15 +587,15 @@ class NavierStokesMassConservation2DStatio(PDEStatio):
             # dynamic loss on x axis
             result_x = (
                 u_dot_nabla_x_u[0]
-                + 1 / params.eq_params["rho"] * jac_p[0, 0]
-                - params.eq_params["nu"] * vec_laplacian_u[0]
+                + 1 / params.eq_params.rho * jac_p[0, 0]
+                - params.eq_params.nu * vec_laplacian_u[0]
             )
 
             # dynamic loss on y axis
             result_y = (
                 u_dot_nabla_x_u[1]
-                + 1 / params.eq_params["rho"] * jac_p[0, 1]
-                - params.eq_params["nu"] * vec_laplacian_u[1]
+                + 1 / params.eq_params.rho * jac_p[0, 1]
+                - params.eq_params.nu * vec_laplacian_u[1]
             )
 
             # MASS CONSERVATION
@@ -605,7 +605,8 @@ class NavierStokesMassConservation2DStatio(PDEStatio):
             # output is 3D
             if mc.ndim == 0 and not result_x.ndim == 0:
                 mc = mc[None]
-            return jnp.stack([result_x, result_y, mc], axis=-1).squeeze()
+
+            return jnp.stack([result_x, result_y, mc], axis=-1).squeeze()
 
         if isinstance(u_p, SPINN):
             u = lambda x, params: u_p(x, params)[..., 0:2]
@@ -626,14 +627,14 @@ class NavierStokesMassConservation2DStatio(PDEStatio):
             # dynamic loss on x axis
             result_x = (
                 u_dot_nabla_x_u[..., 0]
-                + 1 / params.eq_params["rho"] * dp_dx.squeeze()
-                - params.eq_params["nu"] * vec_laplacian_u[..., 0]
+                + 1 / params.eq_params.rho * dp_dx.squeeze()
+                - params.eq_params.nu * vec_laplacian_u[..., 0]
             )
             # dynamic loss on y axis
             result_y = (
                 u_dot_nabla_x_u[..., 1]
-                + 1 / params.eq_params["rho"] * dp_dy.squeeze()
-                - params.eq_params["nu"] * vec_laplacian_u[..., 1]
+                + 1 / params.eq_params.rho * dp_dy.squeeze()
+                - params.eq_params.nu * vec_laplacian_u[..., 1]
             )
 
             # MASS CONSERVATION
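The `result_x`/`result_y` residuals together with the mass-conservation term above implement the steady incompressible Navier-Stokes system:

```latex
% Momentum residual (per axis) and mass conservation, with density rho and
% kinematic viscosity nu read from eq_params:
(u \cdot \nabla)\, u + \frac{1}{\rho} \nabla p - \nu\, \Delta u = 0,
\qquad \nabla \cdot u = 0
```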
jinns/loss/_DynamicLossAbstract.py
CHANGED

@@ -9,10 +9,13 @@ from __future__ import (
 import warnings
 import abc
 from functools import partial
-from typing import Callable, TYPE_CHECKING, ClassVar, Generic, TypeVar
+from dataclasses import InitVar
+from typing import Callable, TYPE_CHECKING, ClassVar, Generic, TypeVar, Any
 import equinox as eqx
-from jaxtyping import Float, Array
+from jaxtyping import Float, Array, PyTree
+import jax
 import jax.numpy as jnp
+from jinns.parameters._params import EqParams
 
 
 # See : https://docs.kidger.site/equinox/api/module/advanced_fields/#equinox.AbstractClassVar--known-issues
@@ -59,7 +62,7 @@ class DynamicLoss(eqx.Module, Generic[InputDim]):
         Tmax needs to be given when the PINN time input is normalized in
         [0, 1], i.e. we have performed renormalization of the differential
         equation
-    eq_params_heterogeneity : dict[str, Callable | None], default=None
+    eq_params_heterogeneity : dict[str, Callable[[InputDim, AbstractPINN, Params[Array]], Array] | None], default=None
         A dict with the same keys as eq_params and the value being either None
         (no heterogeneity) or a function which encodes for the spatio-temporal
         heterogeneity of the parameter.
@@ -71,7 +74,10 @@ class DynamicLoss(eqx.Module, Generic[InputDim]):
         `eq_params` too.
         A value can be missing, in this case there is no heterogeneity (=None).
         Default None, meaning there is no heterogeneity in the equation
-        parameters.
+        parameters. Note that since 1.6.0, this is handled internally as a
+        `PyTree[Callable[[InputDim, AbstractPINN, Params[Array]], Array] |
+        None] | None` (`Params.eq_params` is not a dict
+        anymore).
     vectorial_dyn_loss_ponderation : Float[Array, " dim"], default=None
         Add a different ponderation weight to each of the dimension to the
         dynamic loss. This array must have the same dimension as the output of
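A minimal sketch of declaring a heterogeneous parameter under the 1.6.0 API. The `heterogeneous_D` function and all parameter values are hypothetical, and any FisherKPP-specific constructor arguments are elided; note that `params` must now be passed at construction so the dict can be converted to the same `EqParams` type (see `__post_init__` below):

```python
import jax.numpy as jnp
from jinns.loss import FisherKPP  # assuming the same public export as in 1.5.0
from jinns.parameters import Params

def heterogeneous_D(t_x, u, params):
    # Hypothetical spatially varying diffusivity built on the scalar D
    return params.eq_params.D * (1.0 + 0.5 * jnp.sin(t_x[1]))

params = Params(
    nn_params=None,  # hypothetical: a real setup passes the network parameters
    eq_params={"D": jnp.array(1.0), "r": jnp.array(2.0), "g": jnp.array(3.0)},
)

dyn_loss = FisherKPP(
    Tmax=1.0,
    eq_params_heterogeneity={"D": heterogeneous_D, "r": None, "g": None},
    params=params,  # required: the heterogeneity dict is converted via EqParams
)
```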
@@ -86,40 +92,49 @@ class DynamicLoss(eqx.Module, Generic[InputDim]):
 
     _eq_type = AbstractClassVar[str]  # class variable denoting the type of
     # differential equation
-    Tmax: float = eqx.field(kw_only=True)
-    eq_params_heterogeneity: dict[str, Callable | None] | None = eqx.field(
-        kw_only=True, default=None, static=True
-    )
+    Tmax: float = eqx.field(kw_only=True, default=1)
+    eq_params_heterogeneity: (
+        PyTree[Callable[[InputDim, AbstractPINN, Params[Array]], Array] | None] | None
+    ) = eqx.field(kw_only=True, default=None, static=True)
     vectorial_dyn_loss_ponderation: Float[Array, " dim"] | None = eqx.field(
-        kw_only=True,
+        kw_only=True, default_factory=lambda: jnp.array(1.0)
     )
-
-    def __post_init__(self):
-        if self.vectorial_dyn_loss_ponderation is None:
-            self.vectorial_dyn_loss_ponderation = jnp.array(1.0)
+    params: InitVar[Params[Array]] = eqx.field(default=None)
+
+    def __post_init__(self, params: Params[Array] | None = None):
+        if isinstance(self.eq_params_heterogeneity, dict):  # type: ignore
+            # we cannot use the same converter as in Params.eq_params
+            # we don't want to create a new type but use the same type as
+            # Params.eq_params which already exists.
+            if params is None:
+                raise ValueError(
+                    "When `self.eq_params_heterogeneity` is "
+                    "provided, `params` must be specified at init"
+                )
+            self.eq_params_heterogeneity = EqParams(
+                self.eq_params_heterogeneity,
+                "EqParams",  # type: ignore
+            )
 
     def _eval_heterogeneous_parameters(
         self,
         inputs: InputDim,
         u: AbstractPINN,
         params: Params[Array],
-        eq_params_heterogeneity: dict[str, Callable | None]
-        | None = None,
-    ) -> dict:
+        eq_params_heterogeneity: PyTree[
+            Callable[[InputDim, AbstractPINN, Params[Array]], Array] | None
+        ]
+        | None = None,
+    ) -> PyTree[Array]:
         if eq_params_heterogeneity is None:
             return params.eq_params
-        eq_params_ = {}
-        for k, p in params.eq_params.items():
-            try:
-                heterogeneity = eq_params_heterogeneity[k]
-                if heterogeneity is not None:
-                    eq_params_[k] = heterogeneity(inputs, u, params)
-                else:
-                    eq_params_[k] = p
-            except KeyError:
-                # we authorize missing eq_params_heterogeneity key
-                # if its heterogeneity is None anyway
-                eq_params_[k] = p
+        eq_params_ = jax.tree.map(
+            lambda p, fun: (  # type: ignore
+                fun(inputs, u, params) if fun is not None else p
+            ),
+            params.eq_params,
+            eq_params_heterogeneity,
+        )
         return eq_params_
 
     @partial(_decorator_heteregeneous_params)
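The per-key loop with its `KeyError` fallback is replaced by a structural `jax.tree.map` over two matching PyTrees. A standalone sketch of the same pattern with toy values (no jinns types; the map order and `is_leaf` are adjusted here so that `None` leaves are legal in plain dicts):

```python
import jax
import jax.numpy as jnp

# Toy stand-ins for params.eq_params and eq_params_heterogeneity: two PyTrees
# with the same structure, where a None heterogeneity means "keep the value".
eq_params = {"D": jnp.array(1.0), "r": jnp.array(2.0)}
heterogeneity = {"D": lambda x: 1.0 + 0.5 * jnp.sin(x), "r": None}

x = jnp.array(0.3)
evaluated = jax.tree.map(
    # Evaluate the heterogeneity function when given, keep the raw value otherwise
    lambda fun, p: fun(x) if fun is not None else p,
    heterogeneity,
    eq_params,
    is_leaf=lambda leaf: leaf is None,  # treat None as a leaf, not an empty subtree
)
# evaluated == {"D": Array(1.1477...), "r": Array(2.0)}
```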
@@ -128,7 +143,7 @@ class DynamicLoss(eqx.Module, Generic[InputDim]):
         inputs: InputDim,
         u: AbstractPINN,
         params: Params[Array],
-    ) -> Float[Array, " dim"]:
+    ) -> Float[Array, " eq_dim"]:
         evaluation = self.vectorial_dyn_loss_ponderation * self.equation(
             inputs, u, params
         )
@@ -147,7 +162,7 @@ class DynamicLoss(eqx.Module, Generic[InputDim]):
         return evaluation
 
     @abc.abstractmethod
-    def equation(self, *args, **kwargs):
+    def equation(self, *args: Any, **kwargs: Any) -> Float[Array, " eq_dim"]:
         # TO IMPLEMENT
         # Point-wise evaluation of the differential equation N[u](.)
         raise NotImplementedError("You should implement your equation.")
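With the tightened abstract signatures, a custom dynamic loss now advertises a `Float[Array, " eq_dim"]` residual. A minimal hypothetical subclass (exponential decay, with a made-up `lam` parameter), assuming `ODE` keeps its public export:

```python
import jax
from jaxtyping import Array, Float
from jinns.loss import ODE
from jinns.parameters import Params

class ExponentialDecay(ODE):
    # Hypothetical residual for du/dt = -lam * u, written as N[u](t) = 0
    def equation(
        self, t: Float[Array, " 1"], u, params: Params[Array]
    ) -> Float[Array, " eq_dim"]:
        du_dt = jax.jacrev(lambda t_: u(t_, params))(t).squeeze()
        return du_dt + self.Tmax * params.eq_params.lam * u(t, params).squeeze()
```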
@@ -183,7 +198,7 @@ class ODE(DynamicLoss[Float[Array, " 1"]]):
     @abc.abstractmethod
     def equation(
         self, t: Float[Array, " 1"], u: AbstractPINN, params: Params[Array]
-    ) -> Float[Array, " dim"]:
+    ) -> Float[Array, " eq_dim"]:
         r"""
         The differential operator defining the ODE.
 
@@ -202,7 +217,7 @@ class ODE(DynamicLoss[Float[Array, " 1"]]):
 
         Returns
         -------
-
+        Float[Array, "eq_dim"]
             The residual, *i.e.* the differential operator $\mathcal{N}_\theta[u_\nu](t)$ evaluated at point `t`.
 
         Raises
@@ -242,7 +257,7 @@ class PDEStatio(DynamicLoss[Float[Array, " dim"]]):
     @abc.abstractmethod
     def equation(
         self, x: Float[Array, " dim"], u: AbstractPINN, params: Params[Array]
-    ) -> Float[Array, " dim"]:
+    ) -> Float[Array, " eq_dim"]:
         r"""The differential operator defining the stationary PDE.
 
         !!! warning
@@ -260,7 +275,7 @@ class PDEStatio(DynamicLoss[Float[Array, " dim"]]):
 
         Returns
         -------
-
+        Float[Array, "eq_dim"]
             The residual, *i.e.* the differential operator $\mathcal{N}_\theta[u_\nu](x)$ evaluated at point `x`.
 
         Raises
@@ -303,7 +318,7 @@ class PDENonStatio(DynamicLoss[Float[Array, " 1 + dim"]]):
         t_x: Float[Array, " 1 + dim"],
         u: AbstractPINN,
         params: Params[Array],
-    ) -> Float[Array, " dim"]:
+    ) -> Float[Array, " eq_dim"]:
         r"""The differential operator defining the non-stationary PDE.
 
         !!! warning
@@ -320,7 +335,7 @@ class PDENonStatio(DynamicLoss[Float[Array, " 1 + dim"]]):
         The parameters of the equation and the networks, $\theta$ and $\nu$ respectively.
         Returns
         -------
-
+        Float[Array, "eq_dim"]
             The residual, *i.e.* the differential operator $\mathcal{N}_\theta[u_\nu](t, x)$ evaluated at point `(t, x)`.
 
 