jinns 1.7.0-py3-none-any.whl → 1.7.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -4
- jinns/data/_DataGeneratorODE.py +1 -1
- jinns/data/_DataGeneratorObservations.py +498 -90
- jinns/loss/_DynamicLossAbstract.py +3 -1
- jinns/loss/_LossODE.py +103 -65
- jinns/loss/_LossPDE.py +145 -77
- jinns/loss/_abstract_loss.py +64 -6
- jinns/loss/_boundary_conditions.py +6 -6
- jinns/loss/_loss_utils.py +2 -2
- jinns/loss/_loss_weight_updates.py +30 -0
- jinns/loss/_loss_weights.py +4 -0
- jinns/loss/_operators.py +27 -27
- jinns/nn/_abstract_pinn.py +1 -1
- jinns/nn/_hyperpinn.py +6 -6
- jinns/nn/_mlp.py +3 -3
- jinns/nn/_pinn.py +7 -7
- jinns/nn/_ppinn.py +6 -6
- jinns/nn/_spinn.py +4 -4
- jinns/nn/_spinn_mlp.py +7 -7
- jinns/solver/_rar.py +19 -9
- jinns/solver/_solve.py +4 -1
- jinns/solver/_utils.py +17 -11
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/METADATA +14 -4
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/RECORD +28 -28
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/WHEEL +1 -1
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@ from jaxtyping import Float, Array, PyTree
 import jax
 import jax.numpy as jnp
 from jinns.parameters._params import EqParams
+from jinns.nn import SPINN


 # See : https://docs.kidger.site/equinox/api/module/advanced_fields/#equinox.AbstractClassVar--known-issues
@@ -38,6 +39,7 @@ def _decorator_heteregeneous_params(evaluate):
            self._eval_heterogeneous_parameters(
                inputs, u, params, self.eq_params_heterogeneity
            ),
+            is_leaf=lambda x: x is None,
        )
        new_args = args[:-1] + (_params,)
        res = evaluate(*new_args)
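The added is_leaf=lambda x: x is None guard matters because JAX treats None as an empty container by default, so a mapped function never sees None entries. A minimal standalone sketch of that pattern (toy dict, not jinns data structures):

    import jax

    params = {"a": 1.0, "b": None}
    filled = jax.tree.map(
        lambda x: 0.0 if x is None else x,
        params,
        is_leaf=lambda x: x is None,  # without this, the None entry would be silently skipped
    )
    print(filled)  # {'a': 1.0, 'b': 0.0}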
@@ -152,7 +154,7 @@ class DynamicLoss(eqx.Module, Generic[InputDim]):
                "The output of dynamic loss must be vectorial, "
                "i.e. of shape (d,) with d >= 1"
            )
-        if len(evaluation.shape) > 1:
+        if len(evaluation.shape) > 1 and not isinstance(u, SPINN):
            warnings.warn(
                "Return value from DynamicLoss' equation has more "
                "than one dimension. This is in general a mistake (probably from "
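A toy illustration of the relaxed warning (hypothetical shapes and a boolean stand-in for the isinstance(u, SPINN) check, not jinns code):

    import warnings
    import jax.numpy as jnp

    def check_residual(evaluation, u_is_spinn):
        # mirrors the condition above: multi-dimensional residuals only warn for non-SPINN networks
        if len(evaluation.shape) > 1 and not u_is_spinn:
            warnings.warn(
                "Return value from DynamicLoss' equation has more than one dimension."
            )

    check_residual(jnp.zeros((16, 3)), u_is_spinn=True)   # no warning in 1.7.1
    check_residual(jnp.zeros((16, 3)), u_is_spinn=False)  # still warns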
jinns/loss/_LossODE.py CHANGED
@@ -28,13 +28,13 @@ from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysOD
 from jinns.loss._loss_weights import LossWeightsODE
 from jinns.loss._abstract_loss import AbstractLoss
 from jinns.loss._loss_components import ODEComponents
+from jinns.loss import ODE
 from jinns.parameters._params import Params
 from jinns.data._Batchs import ODEBatch

 if TYPE_CHECKING:
     # imports only used in type hints
     from jinns.nn._abstract_pinn import AbstractPINN
-    from jinns.loss import ODE

 InitialConditionUser = (
     tuple[Float[Array, " n_cond "], Float[Array, " n_cond dim"]]
@@ -65,39 +65,42 @@ class LossODE(
     ----------
     u : eqx.Module
         the PINN
-    dynamic_loss : ODE
+    dynamic_loss : tuple[ODE, ...] | ODE | None
         the ODE dynamic part of the loss, basically the differential
         operator $\mathcal{N}[u](t)$. Should implement a method
         `dynamic_loss.evaluate(t, u, params)`.
         Can be None in order to access only some part of the evaluate call.
-    loss_weights : LossWeightsODE, default=None
+    loss_weights : LossWeightsODE | None, default=None
         The loss weights for the differents term : dynamic loss,
         initial condition and eventually observations if any.
         Can be updated according to a specific algorithm. See
         `update_weight_method`
-    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
+    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'] | None, default=None
         Default is None meaning no update for loss weights. Otherwise a string
-
+    keep_initial_loss_weight_scales : bool, default=True
+        Only used if an update weight method is specified. It decides whether
+        the updated loss weights are multiplied by the initial `loss_weights`
+        passed by the user at initialization. This is useful to force some
+        scale difference between the adaptative loss weights even after the
+        update method is applied.
+    derivative_keys : DerivativeKeysODE | None, default=None
         Specify which field of `params` should be differentiated for each
         composant of the total loss. Particularily useful for inverse problems.
         Fields can be "nn_params", "eq_params" or "both". Those that should not
         be updated will have a `jax.lax.stop_gradient` called on them. Default
         is `"nn_params"` for each composant of the loss.
-    initial_condition :
-        Float[Array, "n_cond "],
-        Float[Array, "n_cond dim"]
-        ] |
-        tuple[int | float | Float[Array, " "],
-            int | float | Float[Array, " dim"]
-        ], default=None
+    initial_condition : InitialConditionUser | None, default=None
         Most of the time, a tuple of length 2 with initial condition $(t_0, u_0)$.
         From jinns v1.5.1 we accept tuples of jnp arrays with shape (n_cond, 1) for t0 and (n_cond, dim) for u0. This is useful to include observed conditions at different time points, such as *e.g* final conditions. It was designed to implement $\mathcal{L}^{aux}$ from _Systems biology informed deep learning for inferring parameters and hidden dynamics_, Alireza Yazdani et al., 2020
-    obs_slice : EllipsisType | slice, default=None
+    obs_slice : tuple[EllipsisType | slice, ...] | EllipsisType | slice | None, default=None
         Slice object specifying the begininning/ending
         slice of u output(s) that is observed. This is useful for
         multidimensional PINN, with partially observed outputs.
         Default is None (whole output is observed).
-
+        **Note**: If several observation datasets are passed this arguments need to be set as a
+        tuple of jnp.slice objects with the same length as the number of
+        observation datasets
+    params : InitVar[Params[Array]] | None, default=None
         The main Params object of the problem needed to instanciate the
         DerivativeKeysODE if the latter is not specified.
     Raises
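A hedged usage sketch of the widened signature. The names u, ode_term_1, ode_term_2 and params stand for objects built elsewhere, and the import path assumes the usual re-exports from jinns.loss; this is an illustration of the accepted argument shapes, not a complete jinns program.

    import jax.numpy as jnp
    from jinns.loss import LossODE, LossWeightsODE

    loss = LossODE(
        u=u,                                    # an AbstractPINN instance (built elsewhere)
        dynamic_loss=(ode_term_1, ode_term_2),  # a tuple of ODE terms is now accepted
        initial_condition=(jnp.array([[0.0]]), jnp.array([[1.0, 0.0]])),  # (n_cond, 1), (n_cond, dim)
        obs_slice=(jnp.s_[0:1], jnp.s_[1:2]),   # one slice per observation dataset
        loss_weights=LossWeightsODE(),          # optional; None now falls back to LossWeightsODE()
        params=params,                          # used to build DerivativeKeysODE when not given
    )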
@@ -109,22 +112,26 @@ class LossODE(
     # NOTE static=True only for leaf attributes that are not valid JAX types
     # (ie. jax.Array cannot be static) and that we do not expect to change
     u: AbstractPINN
-    dynamic_loss: ODE | None
+    dynamic_loss: tuple[ODE | None, ...]
+    vmap_in_axes: tuple[int] = eqx.field(static=True)
     derivative_keys: DerivativeKeysODE
     loss_weights: LossWeightsODE
     initial_condition: InitialCondition | None
-    obs_slice: EllipsisType | slice = eqx.field(static=True)
+    obs_slice: tuple[EllipsisType | slice, ...] = eqx.field(static=True)
     params: InitVar[Params[Array] | None]

     def __init__(
         self,
         *,
         u: AbstractPINN,
-        dynamic_loss: ODE | None,
+        dynamic_loss: tuple[ODE, ...] | ODE | None,
         loss_weights: LossWeightsODE | None = None,
         derivative_keys: DerivativeKeysODE | None = None,
         initial_condition: InitialConditionUser | None = None,
-        obs_slice: EllipsisType | slice
+        obs_slice: tuple[EllipsisType | slice, ...]
+        | EllipsisType
+        | slice
+        | None = None,
         params: Params[Array] | None = None,
         **kwargs: Any,  # this is for arguments for super()
     ):
@@ -141,8 +148,6 @@ class LossODE(
                 f"received {derivative_keys=} and {params=}"
             )
             derivative_keys = DerivativeKeysODE(params=params)
-        else:
-            derivative_keys = derivative_keys

         super().__init__(
             loss_weights=self.loss_weights,
@@ -151,7 +156,10 @@ class LossODE(
             **kwargs,
         )
         self.u = u
-
+        if not isinstance(dynamic_loss, tuple):
+            self.dynamic_loss = (dynamic_loss,)
+        else:
+            self.dynamic_loss = dynamic_loss
         if self.update_weight_method is not None and jnp.any(
             jnp.array(jax.tree.leaves(self.loss_weights)) == 0
         ):
@@ -231,10 +239,15 @@ class LossODE(
             self.initial_condition = (t0, u0)

         if obs_slice is None:
-            self.obs_slice = jnp.s_[...]
+            self.obs_slice = (jnp.s_[...],)
+        elif not isinstance(obs_slice, tuple):
+            self.obs_slice = (obs_slice,)
         else:
             self.obs_slice = obs_slice

+        if self.loss_weights is None:
+            self.loss_weights = LossWeightsODE()
+
     def evaluate_by_terms(
         self,
         opt_params: Params[Array],
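Both dynamic_loss (previous hunk) and obs_slice (this hunk) are now normalized to 1-tuples in the constructor. A standalone sketch of that promotion logic, using a hypothetical helper that is not part of jinns:

    import jax.numpy as jnp

    def as_tuple(value, default):
        # mirrors the constructor logic above: None -> (default,), single value -> (value,), tuple kept as-is
        if value is None:
            return (default,)
        if not isinstance(value, tuple):
            return (value,)
        return value

    print(as_tuple(None, jnp.s_[...]))                        # (Ellipsis,)
    print(as_tuple(jnp.s_[0:2], jnp.s_[...]))                 # (slice(0, 2, None),)
    print(as_tuple((jnp.s_[0:1], jnp.s_[1:2]), jnp.s_[...]))  # unchanged 2-tuple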
@@ -280,23 +293,29 @@ class LossODE(
             cast(eqx.Module, batch.param_batch_dict), params
         )

-
-
-
-        dyn_loss_fun: Callable[[Params[Array]], Array] | None = (
-
-
-
-
-
-
+        if self.dynamic_loss != (None,):
+            # Note, for the record, multiple dynamic losses
+            # have been introduced in MR 92
+            dyn_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
+                jax.tree.map(
+                    lambda d: lambda p: dynamic_loss_apply(
+                        d.evaluate,
+                        self.u,
+                        temporal_batch,
+                        _set_derivatives(p, self.derivative_keys.dyn_loss),
+                        self.vmap_in_axes + vmap_in_axes_params,
+                    ),
+                    self.dynamic_loss,
+                    is_leaf=lambda x: isinstance(x, ODE),  # do not traverse
+                    # further than first level
                 )
             )
         else:
             dyn_loss_fun = None

         if self.initial_condition is not None:
-            # initial
+            # Note, for the record, multiple initial conditions for LossODEs
+            # have been introduced in MR 77
             t0, u0 = self.initial_condition

             # first construct the plain init loss no vmaping
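The jax.tree.map call above turns a tuple of dynamic-loss objects into a tuple of per-term closures, with is_leaf stopping the traversal at each object. A self-contained illustration of the same pattern with toy stand-ins (not jinns classes):

    import jax

    class Term:
        # hypothetical stand-in for an ODE dynamic loss exposing an evaluate method
        def __init__(self, scale):
            self.scale = scale

        def evaluate(self, t, p):
            return self.scale * (p - t) ** 2

    terms = (Term(1.0), Term(10.0))
    closures = jax.tree.map(
        lambda d: lambda p: d.evaluate(0.5, p),  # each closure captures its own term d
        terms,
        is_leaf=lambda x: isinstance(x, Term),   # treat each Term as a leaf, do not recurse into it
    )
    print(tuple(f(2.0) for f in closures))       # (2.25, 22.5)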
@@ -331,34 +350,50 @@ class LossODE(
                 # if there is no parameter batch to vmap over we cannot call
                 # vmap because calling vmap must be done with at least one non
                 # None in_axes or out_axes
-                initial_condition_fun = initial_condition_fun_
+                initial_condition_fun = (initial_condition_fun_,)
             else:
-                initial_condition_fun:
+                initial_condition_fun: (
+                    tuple[Callable[[Params[Array]], Array], ...] | None
+                ) = (
                     lambda p: jnp.mean(
                         vmap(initial_condition_fun_, vmap_in_axes_params)(p)
-                )
+                    ),
                 )
+                # Note that since MR 92
+                # initial_condition_fun is formed as a tuple for
+                # consistency with dynamic and observation
+                # losses and more modularity for later
         else:
             initial_condition_fun = None

         if batch.obs_batch_dict is not None:
-            #
-
-
-
-
-
-
-
-
-
-            lambda
-
-
-
-
-
+            # Note, for the record, multiple DGObs
+            # (leading to batch.obs_batch_dict being tuple | None)
+            # have been introduced in MR 92
+            if len(batch.obs_batch_dict) != len(self.obs_slice):
+                raise ValueError(
+                    "There must be the same number of "
+                    "observation datasets as the number of "
+                    "obs_slice"
+                )
+            params_obs = jax.tree.map(
+                lambda d: update_eq_params(params, d["eq_params"]),
+                batch.obs_batch_dict,
+                is_leaf=lambda x: isinstance(x, dict),
+            )
+            obs_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
+                jax.tree.map(
+                    lambda d, slice_: lambda po: observations_loss_apply(
+                        self.u,
+                        d["pinn_in"],
+                        _set_derivatives(po, self.derivative_keys.observations),
+                        self.vmap_in_axes + vmap_in_axes_params,
+                        d["val"],
+                        slice_,
+                    ),
+                    batch.obs_batch_dict,
                     self.obs_slice,
+                    is_leaf=lambda x: isinstance(x, dict),
                 )
             )
         else:
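The observation branch pairs each observation dict with its slice through a two-argument jax.tree.map that stops at dict leaves. A toy sketch of the same pairing, with hypothetical data rather than jinns batch structures:

    import jax
    import jax.numpy as jnp

    obs_batches = (
        {"pinn_in": jnp.ones((4, 1)), "val": jnp.zeros((4, 2))},
        {"pinn_in": jnp.ones((3, 1)), "val": jnp.zeros((3, 2))},
    )
    obs_slices = (jnp.s_[:, 0:1], jnp.s_[:, 1:2])

    per_dataset_mse = jax.tree.map(
        lambda d, s: jnp.mean(d["val"][s] ** 2),  # one scalar per (dataset, slice) pair
        obs_batches,
        obs_slices,
        is_leaf=lambda x: isinstance(x, dict),    # pass each dict whole instead of its leaves
    )
    print(per_dataset_mse)  # (Array(0., dtype=float32), Array(0., dtype=float32))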
@@ -366,33 +401,36 @@ class LossODE(
             obs_loss_fun = None

         # get the unweighted mses for each loss term as well as the gradients
-        all_funs: ODEComponents[Callable[[Params[Array]], Array] | None] = (
+        all_funs: ODEComponents[tuple[Callable[[Params[Array]], Array], ...] | None] = (
             ODEComponents(dyn_loss_fun, initial_condition_fun, obs_loss_fun)
         )
-        all_params: ODEComponents[Params[Array] | None] = ODEComponents(
-            params,
+        all_params: ODEComponents[tuple[Params[Array], ...] | None] = ODEComponents(
+            jax.tree.map(lambda l: params, dyn_loss_fun),
+            jax.tree.map(lambda l: params, initial_condition_fun),
+            params_obs,
         )

-        # Note that the lambda functions below are with type: ignore just
-        # because the lambda are not type annotated, but there is no proper way
-        # to do this and we should assign the lambda to a type hinted variable
-        # before hand: this is not practical, let us not get mad at this
         mses_grads = jax.tree.map(
             self.get_gradients,
             all_funs,
             all_params,
             is_leaf=lambda x: x is None,
         )
-
+        # NOTE the is_leaf below is more complex since it must pass possible the tuple
+        # of dyn_loss and then stops (but also account it should not stop when
+        # the tuple of dyn_loss is of length 2)
         mses = jax.tree.map(
-            lambda leaf: leaf[0],
+            lambda leaf: leaf[0],
             mses_grads,
-            is_leaf=lambda x: isinstance(x, tuple)
+            is_leaf=lambda x: isinstance(x, tuple)
+            and len(x) == 2
+            and isinstance(x[1], Params),
         )
         grads = jax.tree.map(
-            lambda leaf: leaf[1],
+            lambda leaf: leaf[1],
             mses_grads,
-            is_leaf=lambda x: isinstance(x, tuple)
+            is_leaf=lambda x: isinstance(x, tuple)
+            and len(x) == 2
+            and isinstance(x[1], Params),
         )
-
         return mses, grads
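The final is_leaf predicate identifies the (mse, grad) pairs returned by get_gradients so the two trees can be split apart. A standalone sketch with a plain dict standing in for the Params gradient pytree (an assumption, not the jinns type):

    import jax
    import jax.numpy as jnp

    GradTree = dict  # stand-in for Params in the isinstance check below
    mses_grads = {
        "dyn_loss": (jnp.array(0.3), GradTree(w=jnp.array(1.0))),
        "observations": (jnp.array(0.1), GradTree(w=jnp.array(-2.0))),
    }

    is_pair = lambda x: isinstance(x, tuple) and len(x) == 2 and isinstance(x[1], GradTree)
    mses = jax.tree.map(lambda leaf: leaf[0], mses_grads, is_leaf=is_pair)
    grads = jax.tree.map(lambda leaf: leaf[1], mses_grads, is_leaf=is_pair)
    print(mses)   # {'dyn_loss': Array(0.3, ...), 'observations': Array(0.1, ...)}
    print(grads)  # {'dyn_loss': {'w': Array(1., ...)}, 'observations': {'w': Array(-2., ...)}}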