jinns 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +7 -7
- jinns/data/_CubicMeshPDENonStatio.py +156 -28
- jinns/data/_CubicMeshPDEStatio.py +132 -24
- jinns/loss/_DynamicLossAbstract.py +30 -2
- jinns/loss/_LossODE.py +177 -64
- jinns/loss/_LossPDE.py +146 -68
- jinns/loss/__init__.py +4 -0
- jinns/loss/_abstract_loss.py +116 -3
- jinns/loss/_loss_components.py +43 -0
- jinns/loss/_loss_utils.py +34 -24
- jinns/loss/_loss_weight_updates.py +202 -0
- jinns/loss/_loss_weights.py +72 -16
- jinns/parameters/_params.py +8 -0
- jinns/solver/_solve.py +141 -46
- jinns/utils/_containers.py +5 -2
- jinns/utils/_types.py +12 -0
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/METADATA +5 -2
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/RECORD +22 -20
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/WHEEL +1 -1
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/top_level.txt +0 -0
jinns/loss/_LossODE.py
CHANGED
|
@@ -7,7 +7,7 @@ from __future__ import (
|
|
|
7
7
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
8
|
|
|
9
9
|
from dataclasses import InitVar
|
|
10
|
-
from typing import TYPE_CHECKING, TypedDict
|
|
10
|
+
from typing import TYPE_CHECKING, TypedDict, Callable
|
|
11
11
|
from types import EllipsisType
|
|
12
12
|
import abc
|
|
13
13
|
import warnings
|
|
@@ -19,6 +19,7 @@ from jaxtyping import Float, Array
|
|
|
19
19
|
from jinns.loss._loss_utils import (
|
|
20
20
|
dynamic_loss_apply,
|
|
21
21
|
observations_loss_apply,
|
|
22
|
+
initial_condition_check,
|
|
22
23
|
)
|
|
23
24
|
from jinns.parameters._params import (
|
|
24
25
|
_get_vmap_in_axes_params,
|
|
@@ -27,10 +28,11 @@ from jinns.parameters._params import (
|
|
|
27
28
|
from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
|
|
28
29
|
from jinns.loss._loss_weights import LossWeightsODE
|
|
29
30
|
from jinns.loss._abstract_loss import AbstractLoss
|
|
31
|
+
from jinns.loss._loss_components import ODEComponents
|
|
32
|
+
from jinns.parameters._params import Params
|
|
30
33
|
|
|
31
34
|
if TYPE_CHECKING:
|
|
32
35
|
# imports only used in type hints
|
|
33
|
-
from jinns.parameters._params import Params
|
|
34
36
|
from jinns.data._Batchs import ODEBatch
|
|
35
37
|
from jinns.nn._abstract_pinn import AbstractPINN
|
|
36
38
|
from jinns.loss import ODE
|
|
@@ -42,22 +44,32 @@ if TYPE_CHECKING:
|
|
|
42
44
|
|
|
43
45
|
|
|
44
46
|
class _LossODEAbstract(AbstractLoss):
|
|
45
|
-
"""
|
|
47
|
+
r"""
|
|
46
48
|
Parameters
|
|
47
49
|
----------
|
|
48
50
|
|
|
49
51
|
loss_weights : LossWeightsODE, default=None
|
|
50
52
|
The loss weights for the differents term : dynamic loss,
|
|
51
|
-
initial condition and eventually observations if any.
|
|
52
|
-
|
|
53
|
+
initial condition and eventually observations if any.
|
|
54
|
+
Can be updated according to a specific algorithm. See
|
|
55
|
+
`update_weight_method`
|
|
56
|
+
update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
|
|
57
|
+
Default is None meaning no update for loss weights. Otherwise a string
|
|
53
58
|
derivative_keys : DerivativeKeysODE, default=None
|
|
54
59
|
Specify which field of `params` should be differentiated for each
|
|
55
60
|
composant of the total loss. Particularily useful for inverse problems.
|
|
56
61
|
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
57
62
|
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
58
63
|
is `"nn_params"` for each composant of the loss.
|
|
59
|
-
initial_condition : tuple[
|
|
60
|
-
|
|
64
|
+
initial_condition : tuple[
|
|
65
|
+
Float[Array, "n_cond "],
|
|
66
|
+
Float[Array, "n_cond dim"]
|
|
67
|
+
] |
|
|
68
|
+
tuple[int | float | Float[Array, " "],
|
|
69
|
+
int | float | Float[Array, " dim"]
|
|
70
|
+
] | None, default=None
|
|
71
|
+
Most of the time, a tuple of length 2 with initial condition $(t_0, u_0)$.
|
|
72
|
+
From jinns v1.5.1 we accept tuples of jnp arrays with shape (n_cond, 1) for t0 and (n_cond, dim) for u0. This is useful to include observed conditions at different time points, such as *e.g* final conditions. It was designed to implement $\mathcal{L}^{aux}$ from _Systems biology informed deep learning for inferring parameters and hidden dynamics_, Alireza Yazdani et al., 2020
|
|
61
73
|
obs_slice : EllipsisType | slice | None, default=None
|
|
62
74
|
Slice object specifying the begininning/ending
|
|
63
75
|
slice of u output(s) that is observed. This is useful for
|
|
@@ -74,7 +86,9 @@ class _LossODEAbstract(AbstractLoss):
|
|
|
74
86
|
derivative_keys: DerivativeKeysODE | None = eqx.field(kw_only=True, default=None)
|
|
75
87
|
loss_weights: LossWeightsODE | None = eqx.field(kw_only=True, default=None)
|
|
76
88
|
initial_condition: (
|
|
77
|
-
tuple[
|
|
89
|
+
tuple[Float[Array, " n_cond 1"], Float[Array, " n_cond dim"]]
|
|
90
|
+
| tuple[int | float | Float[Array, " "], int | float | Float[Array, " dim"]]
|
|
91
|
+
| None
|
|
78
92
|
) = eqx.field(kw_only=True, default=None)
|
|
79
93
|
obs_slice: EllipsisType | slice | None = eqx.field(
|
|
80
94
|
kw_only=True, default=None, static=True
|
|
@@ -108,18 +122,60 @@ class _LossODEAbstract(AbstractLoss):
|
|
|
108
122
|
"Initial condition should be a tuple of len 2 with (t0, u0), "
|
|
109
123
|
f"{self.initial_condition} was passed."
|
|
110
124
|
)
|
|
111
|
-
# some checks/reshaping for t0
|
|
125
|
+
# some checks/reshaping for t0 and u0
|
|
112
126
|
t0, u0 = self.initial_condition
|
|
113
127
|
if isinstance(t0, Array):
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
128
|
+
# at the end we want to end up with t0 of shape (:, 1) to account for
|
|
129
|
+
# possibly several data points
|
|
130
|
+
if t0.ndim <= 1:
|
|
131
|
+
# in this case we assume t0 belongs one (initial)
|
|
132
|
+
# condition
|
|
133
|
+
t0 = initial_condition_check(t0, dim_size=1)[
|
|
134
|
+
None, :
|
|
135
|
+
] # make a (1, 1) here
|
|
136
|
+
if t0.ndim > 2:
|
|
137
|
+
raise ValueError(
|
|
138
|
+
"It t0 is an Array, it represents n_cond"
|
|
139
|
+
" imposed conditions and must be of shape (n_cond, 1)"
|
|
140
|
+
)
|
|
141
|
+
else:
|
|
142
|
+
# in this case t0 clearly represents one (initial) condition
|
|
143
|
+
t0 = initial_condition_check(t0, dim_size=1)[
|
|
144
|
+
None, :
|
|
145
|
+
] # make a (1, 1) here
|
|
146
|
+
if isinstance(u0, Array):
|
|
147
|
+
# at the end we want to end up with u0 of shape (:, dim) to account for
|
|
148
|
+
# possibly several data points
|
|
149
|
+
if not u0.shape:
|
|
150
|
+
# in this case we assume u0 belongs to one (initial)
|
|
151
|
+
# condition
|
|
152
|
+
u0 = initial_condition_check(u0, dim_size=1)[
|
|
153
|
+
None, :
|
|
154
|
+
] # make a (1, 1) here
|
|
155
|
+
elif u0.ndim == 1:
|
|
156
|
+
# in this case we assume u0 belongs to one (initial)
|
|
157
|
+
# condition
|
|
158
|
+
u0 = initial_condition_check(u0, dim_size=u0.shape[0])[
|
|
159
|
+
None, :
|
|
160
|
+
] # make a (1, dim) here
|
|
161
|
+
if u0.ndim > 2:
|
|
117
162
|
raise ValueError(
|
|
118
|
-
|
|
119
|
-
|
|
163
|
+
"It u0 is an Array, it represents n_cond "
|
|
164
|
+
"imposed conditions and must be of shape (n_cond, dim)"
|
|
120
165
|
)
|
|
121
|
-
|
|
122
|
-
|
|
166
|
+
else:
|
|
167
|
+
# at the end we want to end up with u0 of shape (:, dim) to account for
|
|
168
|
+
# possibly several data points
|
|
169
|
+
u0 = initial_condition_check(u0, dim_size=None)[
|
|
170
|
+
None, :
|
|
171
|
+
] # make a (1, 1) here
|
|
172
|
+
|
|
173
|
+
if t0.shape[0] != u0.shape[0] or t0.ndim != u0.ndim:
|
|
174
|
+
raise ValueError(
|
|
175
|
+
"t0 and u0 must represent a same number of initial"
|
|
176
|
+
" conditial conditions"
|
|
177
|
+
)
|
|
178
|
+
|
|
123
179
|
self.initial_condition = (t0, u0)
|
|
124
180
|
|
|
125
181
|
if self.obs_slice is None:
|
|
@@ -154,8 +210,11 @@ class LossODE(_LossODEAbstract):
|
|
|
154
210
|
----------
|
|
155
211
|
loss_weights : LossWeightsODE, default=None
|
|
156
212
|
The loss weights for the differents term : dynamic loss,
|
|
157
|
-
initial condition and eventually observations if any.
|
|
158
|
-
|
|
213
|
+
initial condition and eventually observations if any.
|
|
214
|
+
Can be updated according to a specific algorithm. See
|
|
215
|
+
`update_weight_method`
|
|
216
|
+
update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
|
|
217
|
+
Default is None meaning no update for loss weights. Otherwise a string
|
|
159
218
|
derivative_keys : DerivativeKeysODE, default=None
|
|
160
219
|
Specify which field of `params` should be differentiated for each
|
|
161
220
|
composant of the total loss. Particularily useful for inverse problems.
|
|
@@ -204,21 +263,26 @@ class LossODE(_LossODEAbstract):
|
|
|
204
263
|
def __call__(self, *args, **kwargs):
|
|
205
264
|
return self.evaluate(*args, **kwargs)
|
|
206
265
|
|
|
207
|
-
def
|
|
266
|
+
def evaluate_by_terms(
|
|
208
267
|
self, params: Params[Array], batch: ODEBatch
|
|
209
|
-
) -> tuple[
|
|
268
|
+
) -> tuple[
|
|
269
|
+
ODEComponents[Float[Array, " "] | None], ODEComponents[Float[Array, " "] | None]
|
|
270
|
+
]:
|
|
210
271
|
"""
|
|
211
272
|
Evaluate the loss function at a batch of points for given parameters.
|
|
212
273
|
|
|
274
|
+
We retrieve two PyTrees with loss values and gradients for each term
|
|
213
275
|
|
|
214
276
|
Parameters
|
|
215
277
|
---------
|
|
216
278
|
params
|
|
217
279
|
Parameters at which the loss is evaluated
|
|
218
280
|
batch
|
|
219
|
-
Composed of a batch of
|
|
220
|
-
|
|
221
|
-
|
|
281
|
+
Composed of a batch of points in the
|
|
282
|
+
domain, a batch of points in the domain
|
|
283
|
+
border and an optional additional batch of parameters (eg. for
|
|
284
|
+
metamodeling) and an optional additional batch of observed
|
|
285
|
+
inputs/outputs/parameters
|
|
222
286
|
"""
|
|
223
287
|
temporal_batch = batch.temporal_batch
|
|
224
288
|
|
|
@@ -233,71 +297,120 @@ class LossODE(_LossODEAbstract):
|
|
|
233
297
|
|
|
234
298
|
## dynamic part
|
|
235
299
|
if self.dynamic_loss is not None:
|
|
236
|
-
|
|
237
|
-
self.dynamic_loss.evaluate,
|
|
300
|
+
dyn_loss_fun = lambda p: dynamic_loss_apply(
|
|
301
|
+
self.dynamic_loss.evaluate, # type: ignore
|
|
238
302
|
self.u,
|
|
239
303
|
temporal_batch,
|
|
240
|
-
_set_derivatives(
|
|
304
|
+
_set_derivatives(p, self.derivative_keys.dyn_loss), # type: ignore
|
|
241
305
|
self.vmap_in_axes + vmap_in_axes_params,
|
|
242
|
-
self.loss_weights.dyn_loss, # type: ignore
|
|
243
306
|
)
|
|
244
307
|
else:
|
|
245
|
-
|
|
308
|
+
dyn_loss_fun = None
|
|
246
309
|
|
|
247
310
|
# initial condition
|
|
248
311
|
if self.initial_condition is not None:
|
|
249
|
-
vmap_in_axes = (None,) + vmap_in_axes_params
|
|
250
|
-
if not jax.tree_util.tree_leaves(vmap_in_axes):
|
|
251
|
-
# test if only None in vmap_in_axes to avoid the value error:
|
|
252
|
-
# `vmap must have at least one non-None value in in_axes`
|
|
253
|
-
v_u = self.u
|
|
254
|
-
else:
|
|
255
|
-
v_u = vmap(self.u, (None,) + vmap_in_axes_params)
|
|
256
312
|
t0, u0 = self.initial_condition
|
|
257
313
|
u0 = jnp.array(u0)
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
)
|
|
269
|
-
- u0
|
|
314
|
+
|
|
315
|
+
# first construct the plain init loss no vmaping
|
|
316
|
+
initial_condition_fun__ = lambda t, u, p: jnp.sum(
|
|
317
|
+
(
|
|
318
|
+
self.u(
|
|
319
|
+
t,
|
|
320
|
+
_set_derivatives(
|
|
321
|
+
p,
|
|
322
|
+
self.derivative_keys.initial_condition, # type: ignore
|
|
323
|
+
),
|
|
270
324
|
)
|
|
271
|
-
|
|
272
|
-
axis=-1,
|
|
325
|
+
- u
|
|
273
326
|
)
|
|
327
|
+
** 2,
|
|
328
|
+
axis=0,
|
|
329
|
+
)
|
|
330
|
+
# now vmap over the number of conditions (first dim of t0 and u0)
|
|
331
|
+
# and take the mean
|
|
332
|
+
initial_condition_fun_ = lambda p: jnp.mean(
|
|
333
|
+
vmap(initial_condition_fun__, (0, 0, None))(t0, u0, p)
|
|
274
334
|
)
|
|
335
|
+
# now vmap over the the possible batch of parameters and take the
|
|
336
|
+
# average. Note that we then finally have a cartesian product
|
|
337
|
+
# between the batch of parameters (if any) and the number of
|
|
338
|
+
# conditions (if any)
|
|
339
|
+
if not jax.tree_util.tree_leaves(vmap_in_axes_params):
|
|
340
|
+
# if there is no parameter batch to vmap over we cannot call
|
|
341
|
+
# vmap because calling vmap must be done with at least one non
|
|
342
|
+
# None in_axes or out_axes
|
|
343
|
+
initial_condition_fun = initial_condition_fun_
|
|
344
|
+
else:
|
|
345
|
+
initial_condition_fun = lambda p: jnp.mean(
|
|
346
|
+
vmap(initial_condition_fun_, vmap_in_axes_params)(p)
|
|
347
|
+
)
|
|
275
348
|
else:
|
|
276
|
-
|
|
349
|
+
initial_condition_fun = None
|
|
277
350
|
|
|
278
351
|
if batch.obs_batch_dict is not None:
|
|
279
352
|
# update params with the batches of observed params
|
|
280
|
-
|
|
353
|
+
params_obs = _update_eq_params_dict(
|
|
354
|
+
params, batch.obs_batch_dict["eq_params"]
|
|
355
|
+
)
|
|
281
356
|
|
|
282
357
|
# MSE loss wrt to an observed batch
|
|
283
|
-
|
|
358
|
+
obs_loss_fun = lambda po: observations_loss_apply(
|
|
284
359
|
self.u,
|
|
285
360
|
batch.obs_batch_dict["pinn_in"],
|
|
286
|
-
_set_derivatives(
|
|
361
|
+
_set_derivatives(po, self.derivative_keys.observations), # type: ignore
|
|
287
362
|
self.vmap_in_axes + vmap_in_axes_params,
|
|
288
363
|
batch.obs_batch_dict["val"],
|
|
289
|
-
self.loss_weights.observations, # type: ignore
|
|
290
364
|
self.obs_slice,
|
|
291
365
|
)
|
|
292
366
|
else:
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
}
|
|
367
|
+
params_obs = None
|
|
368
|
+
obs_loss_fun = None
|
|
369
|
+
|
|
370
|
+
# get the unweighted mses for each loss term as well as the gradients
|
|
371
|
+
all_funs: ODEComponents[Callable[[Params[Array]], Array] | None] = (
|
|
372
|
+
ODEComponents(dyn_loss_fun, initial_condition_fun, obs_loss_fun)
|
|
373
|
+
)
|
|
374
|
+
all_params: ODEComponents[Params[Array] | None] = ODEComponents(
|
|
375
|
+
params, params, params_obs
|
|
303
376
|
)
|
|
377
|
+
mses_grads = jax.tree.map(
|
|
378
|
+
lambda fun, params: self.get_gradients(fun, params),
|
|
379
|
+
all_funs,
|
|
380
|
+
all_params,
|
|
381
|
+
is_leaf=lambda x: x is None,
|
|
382
|
+
)
|
|
383
|
+
|
|
384
|
+
mses = jax.tree.map(
|
|
385
|
+
lambda leaf: leaf[0], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
|
|
386
|
+
)
|
|
387
|
+
grads = jax.tree.map(
|
|
388
|
+
lambda leaf: leaf[1], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
|
|
389
|
+
)
|
|
390
|
+
|
|
391
|
+
return mses, grads
|
|
392
|
+
|
|
393
|
+
def evaluate(
|
|
394
|
+
self, params: Params[Array], batch: ODEBatch
|
|
395
|
+
) -> tuple[Float[Array, " "], ODEComponents[Float[Array, " "] | None]]:
|
|
396
|
+
"""
|
|
397
|
+
Evaluate the loss function at a batch of points for given parameters.
|
|
398
|
+
|
|
399
|
+
We retrieve the total value itself and a PyTree with loss values for each term
|
|
400
|
+
|
|
401
|
+
Parameters
|
|
402
|
+
---------
|
|
403
|
+
params
|
|
404
|
+
Parameters at which the loss is evaluated
|
|
405
|
+
batch
|
|
406
|
+
Composed of a batch of points in the
|
|
407
|
+
domain, a batch of points in the domain
|
|
408
|
+
border and an optional additional batch of parameters (eg. for
|
|
409
|
+
metamodeling) and an optional additional batch of observed
|
|
410
|
+
inputs/outputs/parameters
|
|
411
|
+
"""
|
|
412
|
+
loss_terms, _ = self.evaluate_by_terms(params, batch)
|
|
413
|
+
|
|
414
|
+
loss_val = self.ponderate_and_sum_loss(loss_terms)
|
|
415
|
+
|
|
416
|
+
return loss_val, loss_terms
|