jinns-1.0.0-py3-none-any.whl → jinns-1.2.0-py3-none-any.whl
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +532 -341
- jinns/loss/_DynamicLoss.py +150 -173
- jinns/loss/_DynamicLossAbstract.py +27 -73
- jinns/loss/_LossODE.py +45 -26
- jinns/loss/_LossPDE.py +85 -84
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +148 -279
- jinns/loss/_loss_utils.py +85 -58
- jinns/loss/_operators.py +441 -184
- jinns/parameters/_derivative_keys.py +487 -60
- jinns/plot/_plot.py +111 -98
- jinns/solver/_rar.py +102 -407
- jinns/solver/_solve.py +73 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +3 -1
- jinns/utils/_hyperpinn.py +17 -7
- jinns/utils/_pinn.py +17 -27
- jinns/utils/_ppinn.py +227 -0
- jinns/utils/_save_load.py +13 -13
- jinns/utils/_spinn.py +24 -43
- jinns/utils/_types.py +1 -0
- jinns/utils/_utils.py +40 -12
- jinns-1.2.0.dist-info/AUTHORS +2 -0
- jinns-1.2.0.dist-info/METADATA +127 -0
- jinns-1.2.0.dist-info/RECORD +41 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/WHEEL +1 -1
- jinns-1.0.0.dist-info/METADATA +0 -84
- jinns-1.0.0.dist-info/RECORD +0 -38
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/LICENSE +0 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/top_level.txt +0 -0
jinns/loss/_DynamicLossAbstract.py CHANGED

```diff
@@ -21,61 +21,32 @@ else:
     from equinox import AbstractClassVar
 
 
-def _decorator_heteregeneous_params(evaluate, eq_type):
+def _decorator_heteregeneous_params(evaluate):
 
-    def wrapper_ode(*args):
-        self, t, u, params = args
+    def wrapper(*args):
+        self, inputs, u, params = args
         _params = eqx.tree_at(
             lambda p: p.eq_params,
             params,
             self._eval_heterogeneous_parameters(
-                t, None, u, params, self.eq_params_heterogeneity
+                inputs, u, params, self.eq_params_heterogeneity
             ),
         )
         new_args = args[:-1] + (_params,)
         res = evaluate(*new_args)
         return res
 
-    def wrapper_pde_statio(*args):
-        self, x, u, params = args
-        _params = eqx.tree_at(
-            lambda p: p.eq_params,
-            params,
-            self._eval_heterogeneous_parameters(
-                None, x, u, params, self.eq_params_heterogeneity
-            ),
-        )
-        new_args = args[:-1] + (_params,)
-        res = evaluate(*new_args)
-        return res
-
-    def wrapper_pde_non_statio(*args):
-        self, t, x, u, params = args
-        _params = eqx.tree_at(
-            lambda p: p.eq_params,
-            params,
-            self._eval_heterogeneous_parameters(
-                t, x, u, params, self.eq_params_heterogeneity
-            ),
-        )
-        new_args = args[:-1] + (_params,)
-        res = evaluate(*new_args)
-        return res
-
-    if eq_type == "ODE":
-        return wrapper_ode
-    elif eq_type == "Statio PDE":
-        return wrapper_pde_statio
-    elif eq_type == "Non-statio PDE":
-        return wrapper_pde_non_statio
+    return wrapper
 
 
 class DynamicLoss(eqx.Module):
     r"""
     Abstract base class for dynamic losses. Implements the physical term:
+
     $$
     \mathcal{N}[u](t, x) = 0
     $$
+
     for **one** point $t$, $x$ or $(t, x)$, depending on the context.
 
     Parameters
```
```diff
@@ -108,8 +79,7 @@ class DynamicLoss(eqx.Module):
 
     def _eval_heterogeneous_parameters(
         self,
-        t: Float[Array, "1"],
-        x: Float[Array, "dim"],
+        inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
         u: eqx.Module,
         params: Params | ParamsDict,
         eq_params_heterogeneity: Dict[str, Callable | None] = None,
```
```diff
@@ -122,14 +92,7 @@ class DynamicLoss(eqx.Module):
                 if eq_params_heterogeneity[k] is None:
                     eq_params_[k] = p
                 else:
-
-                    # signature will vary according to _eq_type
-                    if self._eq_type == "ODE":
-                        eq_params_[k] = eq_params_heterogeneity[k](t, u, params)
-                    elif self._eq_type == "Statio PDE":
-                        eq_params_[k] = eq_params_heterogeneity[k](x, u, params)
-                    elif self._eq_type == "Non-statio PDE":
-                        eq_params_[k] = eq_params_heterogeneity[k](t, x, u, params)
+                    eq_params_[k] = eq_params_heterogeneity[k](inputs, u, params)
             except KeyError:
                 # we authorize missing eq_params_heterogeneity key
                 # if its heterogeneity is None anyway
```
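In 1.0.0 a heterogeneous equation parameter had a different callable signature per equation type (`(t, u, params)`, `(x, u, params)`, `(t, x, u, params)`); 1.2.0 collapses these into a single `(inputs, u, params)` signature, where `inputs` is `t`, `x`, or the concatenated `t_x` depending on the equation type. A minimal sketch of a heterogeneity callable under the new convention (the function name, the parameter key `"D"`, and the body are hypothetical; only the call signature comes from the diff):

```python
import jax.numpy as jnp

def heterogeneous_diffusivity(inputs, u, params):
    # For a non-stationary PDE, `inputs` is the concatenated (t, x) point,
    # so the spatial coordinates are inputs[1:].
    x = inputs[1:]
    return 1.0 + 0.5 * jnp.sin(jnp.pi * x[0])

# None means the parameter stays homogeneous (used as-is).
eq_params_heterogeneity = {"D": heterogeneous_diffusivity, "r": None}
```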
```diff
@@ -138,22 +101,17 @@ class DynamicLoss(eqx.Module):
 
     def _evaluate(
         self,
-        t: Float[Array, "1"],
-        x: Float[Array, "dim"],
+        inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
         u: eqx.Module,
         params: Params | ParamsDict,
     ) -> float:
-
-        if self._eq_type == "ODE":
-            ans = self.equation(t, u, params)
-        elif self._eq_type == "Statio PDE":
-            ans = self.equation(x, u, params)
-        elif self._eq_type == "Non-statio PDE":
-            ans = self.equation(t, x, u, params)
-        else:
-            raise NotImplementedError("the equation type is not handled.")
-
-        return ans
+        evaluation = self.equation(inputs, u, params)
+        if len(evaluation.shape) == 0:
+            raise ValueError(
+                "The output of dynamic loss must be vectorial, "
+                "i.e. of shape (d,) with d >= 1"
+            )
+        return evaluation
 
     @abc.abstractmethod
     def equation(self, *args, **kwargs):
```
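`_evaluate` no longer dispatches on `_eq_type`; it calls `self.equation(inputs, u, params)` directly and now rejects scalar outputs. Concretely, a residual that used to be returned with shape `()` must gain a leading axis, e.g. via `jnp.atleast_1d` (a plain-JAX illustration of the shapes involved, not jinns code):

```python
import jax.numpy as jnp

residual = jnp.asarray(0.3)                 # shape (): raises ValueError in 1.2.0
vector_residual = jnp.atleast_1d(residual)  # shape (1,): accepted, d >= 1
print(residual.shape, vector_residual.shape)  # () (1,)
```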
```diff
@@ -189,7 +147,7 @@ class ODE(DynamicLoss):
 
     _eq_type: ClassVar[str] = "ODE"
 
-    @partial(_decorator_heteregeneous_params, eq_type="ODE")
+    @partial(_decorator_heteregeneous_params)
     def evaluate(
         self,
         t: Float[Array, "1"],
```
```diff
@@ -197,7 +155,7 @@ class ODE(DynamicLoss):
         params: Params | ParamsDict,
     ) -> float:
         """Here we call DynamicLoss._evaluate with x=None"""
-        return self._evaluate(t, None, u, params)
+        return self._evaluate(t, u, params)
 
     @abc.abstractmethod
     def equation(
```
```diff
@@ -258,12 +216,12 @@ class PDEStatio(DynamicLoss):
 
     _eq_type: ClassVar[str] = "Statio PDE"
 
-    @partial(_decorator_heteregeneous_params, eq_type="Statio PDE")
+    @partial(_decorator_heteregeneous_params)
     def evaluate(
         self, x: Float[Array, "dimension"], u: eqx.Module, params: Params | ParamsDict
     ) -> float:
         """Here we call the DynamicLoss._evaluate with t=None"""
-        return self._evaluate(None, x, u, params)
+        return self._evaluate(x, u, params)
 
     @abc.abstractmethod
     def equation(
```
```diff
@@ -323,23 +281,21 @@ class PDENonStatio(DynamicLoss):
 
     _eq_type: ClassVar[str] = "Non-statio PDE"
 
-    @partial(_decorator_heteregeneous_params, eq_type="Non-statio PDE")
+    @partial(_decorator_heteregeneous_params)
     def evaluate(
         self,
-        t: Float[Array, "1"],
-        x: Float[Array, "dim"],
+        t_x: Float[Array, "1 + dim"],
         u: eqx.Module,
         params: Params | ParamsDict,
     ) -> float:
         """Here we call the DynamicLoss._evaluate with full arguments"""
-        ans = self._evaluate(t, x, u, params)
+        ans = self._evaluate(t_x, u, params)
         return ans
 
     @abc.abstractmethod
     def equation(
         self,
-        t: Float[Array, "1"],
-        x: Float[Array, "dim"],
+        t_x: Float[Array, "1 + dim"],
         u: eqx.Module,
         params: Params | ParamsDict,
     ) -> float:
```
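Non-stationary equations now receive time and space as one concatenated array `t_x` instead of separate `t` and `x` arguments, so user-defined `equation` implementations must split it themselves. A hypothetical residual skeleton under the 1.2.0 signature (only the signature and shapes come from the diff above):

```python
# Sketch of a user-defined residual (method of a PDENonStatio subclass;
# body hypothetical, signature from the diff above):
def equation(self, t_x, u, params):
    t = t_x[0:1]  # shape (1,): slicing with 0:1 keeps the leading axis
    x = t_x[1:]   # shape (dim,)
    ...           # build the residual from t, x, u, params; return shape (d,), d >= 1
```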
```diff
@@ -351,10 +307,8 @@ class PDENonStatio(DynamicLoss):
 
         Parameters
         ----------
-        t : Float[Array, "1"]
-            A 1 dimensional jnp.array representing the time point.
-        x : Float[Array, "d"]
-            A `d` dimensional jnp.array representing a point in the spatial domain $\Omega$.
+        t_x : Float[Array, "1 + dim"]
+            A jnp array containing the concatenation of a time point and a point in $\Omega$
         u : eqx.Module
             The neural network.
         params : Params | ParamsDict
```
jinns/loss/_LossODE.py CHANGED
```diff
@@ -56,6 +56,9 @@ class _LossODEAbstract(eqx.Module):
         slice of u output(s) that is observed. This is useful for
         multidimensional PINN, with partially observed outputs.
         Default is None (whole output is observed).
+    params : InitVar[Params], default=None
+        The main Params object of the problem needed to instanciate the
+        DerivativeKeysODE if the latter is not specified.
     """
 
     # NOTE static=True only for leaf attributes that are not valid JAX types
```
```diff
@@ -66,13 +69,21 @@ class _LossODEAbstract(eqx.Module):
     initial_condition: tuple | None = eqx.field(kw_only=True, default=None)
     obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
 
-    def __post_init__(self):
+    params: InitVar[Params] = eqx.field(default=None, kw_only=True)
+
+    def __post_init__(self, params=None):
         if self.loss_weights is None:
             self.loss_weights = LossWeightsODE()
 
         if self.derivative_keys is None:
-            # be default we only take gradient wrt nn_params
-            self.derivative_keys = DerivativeKeysODE()
+            try:
+                # be default we only take gradient wrt nn_params
+                self.derivative_keys = DerivativeKeysODE(params=params)
+            except ValueError as exc:
+                raise ValueError(
+                    "Problem at self.derivative_keys initialization "
+                    f"received {self.derivative_keys=} and {params=}"
+                ) from exc
         if self.initial_condition is None:
             warnings.warn(
                 "Initial condition wasn't provided. Be sure to cover for that"
```
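The new `params` attribute is a dataclass `InitVar`: it is consumed by `__post_init__` (to build a default `DerivativeKeysODE`) and never stored on the module. A generic, runnable illustration of that pattern with plain dataclasses (stand-in names; not jinns code):

```python
from dataclasses import dataclass, InitVar

@dataclass
class Loss:
    derivative_keys: dict | None = None
    params: InitVar[dict | None] = None

    def __post_init__(self, params=None):
        if self.derivative_keys is None:
            # stand-in for DerivativeKeysODE(params=params)
            self.derivative_keys = {"dyn_loss": params}

loss = Loss(params={"nn_params": 0})
print(loss.derivative_keys)  # {'dyn_loss': {'nn_params': 0}}
# `params` is consumed at init time; it is not in dataclasses.fields(Loss)
```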
```diff
@@ -131,6 +142,9 @@ class LossODE(_LossODEAbstract):
         slice of u output(s) that is observed. This is useful for
         multidimensional PINN, with partially observed outputs.
         Default is None (whole output is observed).
+    params : InitVar[Params], default=None
+        The main Params object of the problem needed to instanciate the
+        DerivativeKeysODE if the latter is not specified.
     u : eqx.Module
         the PINN
     dynamic_loss : DynamicLoss
```
```diff
@@ -152,8 +166,10 @@ class LossODE(_LossODEAbstract):
 
     vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
 
-    def __post_init__(self):
-        super().__post_init__()  # because __init__ or __post_init__ of Base
+    def __post_init__(self, params=None):
+        super().__post_init__(
+            params=params
+        ) # because __init__ or __post_init__ of Base
         # class is not automatically called
 
         self.vmap_in_axes = (0,)
```
```diff
@@ -193,7 +209,7 @@ class LossODE(_LossODEAbstract):
         mse_dyn_loss = dynamic_loss_apply(
             self.dynamic_loss.evaluate,
             self.u,
-
+            temporal_batch,
             _set_derivatives(params, self.derivative_keys.dyn_loss),
             self.vmap_in_axes + vmap_in_axes_params,
             self.loss_weights.dyn_loss,
```
```diff
@@ -211,7 +227,7 @@ class LossODE(_LossODEAbstract):
         else:
             v_u = vmap(self.u, (None,) + vmap_in_axes_params)
             t0, u0 = self.initial_condition # pylint: disable=unpacking-non-sequence
-            t0 = jnp.array(t0)
+            t0 = jnp.array([t0])
             u0 = jnp.array(u0)
             mse_initial_condition = jnp.mean(
                 self.loss_weights.initial_condition
```
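The initial time point is now wrapped as `jnp.array([t0])`, giving it shape `(1,)` to match the `t: Float[Array, "1"]` convention used by `ODE.evaluate` above, rather than the 0-d scalar produced by `jnp.array(t0)`:

```python
import jax.numpy as jnp

t0 = 0.0
print(jnp.array(t0).shape)    # (): 1.0.0 behaviour, a 0-d scalar
print(jnp.array([t0]).shape)  # (1,): 1.2.0 behaviour, a 1-d time point
```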
```diff
@@ -300,6 +316,9 @@ class SystemLossODE(eqx.Module):
         PINNs. Default is None. But if a value is given, all the entries of
         `u_dict` must be represented here with default value `jnp.s_[...]`
         if no particular slice is to be given.
+    params_dict : InitVar[ParamsDict], default=None
+        The main Params object of the problem needed to instanciate the
+        DerivativeKeysODE if the latter is not specified.
 
     Raises
     ------
```
```diff
@@ -332,14 +351,16 @@ class SystemLossODE(eqx.Module):
     loss_weights: InitVar[LossWeightsODEDict | None] = eqx.field(
         kw_only=True, default=None
     )
+    params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
+
     u_constraints_dict: Dict[str, LossODE] = eqx.field(init=False)
-    derivative_keys_dyn_loss_dict: Dict[str, DerivativeKeysODE] = eqx.field(init=False)
+    derivative_keys_dyn_loss: DerivativeKeysODE = eqx.field(init=False)
 
     u_dict_with_none: Dict[str, None] = eqx.field(init=False)
     # internally the loss weights are handled with a dictionary
     _loss_weights: Dict[str, dict] = eqx.field(init=False)
 
-    def __post_init__(self, loss_weights):
+    def __post_init__(self, loss_weights=None, params_dict=None):
         # a dictionary that will be useful at different places
         self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
         if self.initial_condition_dict is None:
```
```diff
@@ -369,14 +390,14 @@ class SystemLossODE(eqx.Module):
         # iterating on dynamic_loss_dict. So each time we will require dome
         # derivative_keys_dict
 
-        #
-        #
-        #
-
-        # default values
-        for k in self.dynamic_loss_dict.keys():
+        # derivative keys for the u_constraints. Note that we create missing
+        # DerivativeKeysODE around a Params object and not ParamsDict
+        # this works because u_dict.keys == params_dict.nn_params.keys()
+        for k in self.u_dict.keys():
             if self.derivative_keys_dict[k] is None:
-                self.derivative_keys_dict[k] = DerivativeKeysODE()
+                self.derivative_keys_dict[k] = DerivativeKeysODE(
+                    params=params_dict.extract_params(k)
+                )
 
         self._loss_weights = self.set_loss_weights(loss_weights)
 
```
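Note the asymmetry the two new comments describe: each per-network constraint gets a `DerivativeKeysODE` built on the `Params` returned by `params_dict.extract_params(k)`, while the shared dynamic loss keeps the whole `ParamsDict`. A stand-in illustration of what a per-key extraction means (plain dicts; `extract_params`'s real behaviour is jinns internals, this only mirrors the comment `u_dict.keys == params_dict.nn_params.keys()`):

```python
params_dict = {"nn_params": {"u1": "weights-1", "u2": "weights-2"},
               "eq_params": {"alpha": 0.1}}

def extract_params(k):
    # stand-in: pick one network's weights, keep the shared eq_params
    return {"nn_params": params_dict["nn_params"][k],
            "eq_params": params_dict["eq_params"]}

print(extract_params("u1"))
# {'nn_params': 'weights-1', 'eq_params': {'alpha': 0.1}}
```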
```diff
@@ -397,12 +418,11 @@ class SystemLossODE(eqx.Module):
                 obs_slice=self.obs_slice_dict[i],
             )
 
-        #
-
-
-
-
-        }
+        # derivative keys for the dynamic loss. Note that we create a
+        # DerivativeKeysODE around a ParamsDict object because a whole
+        # params_dict is feed to DynamicLoss.evaluate functions (extract_params
+        # happen inside it)
+        self.derivative_keys_dyn_loss = DerivativeKeysODE(params=params_dict)
 
     def set_loss_weights(self, loss_weights_init):
         """
```
```diff
@@ -497,13 +517,13 @@ class SystemLossODE(eqx.Module):
             batch.param_batch_dict, params_dict
         )
 
-        def dyn_loss_for_one_key(dyn_loss, derivative_keys_dyn_loss, loss_weight):
+        def dyn_loss_for_one_key(dyn_loss, loss_weight):
             """This function is used in tree_map"""
             return dynamic_loss_apply(
                 dyn_loss.evaluate,
                 self.u_dict,
-
-                _set_derivatives(params_dict, derivative_keys_dyn_loss.dyn_loss),
+                temporal_batch,
+                _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
                 vmap_in_axes_t + vmap_in_axes_params,
                 loss_weight,
                 u_type=PINN,
```
```diff
@@ -512,7 +532,6 @@ class SystemLossODE(eqx.Module):
         dyn_loss_mse_dict = jax.tree_util.tree_map(
             dyn_loss_for_one_key,
             self.dynamic_loss_dict,
-            self.derivative_keys_dyn_loss_dict,
             self._loss_weights["dyn_loss"],
             is_leaf=lambda x: isinstance(x, ODE), # before when dynamic losses
             # where plain (unregister pytree) node classes, we could not traverse
```