jinns-1.5.1-py3-none-any.whl → jinns-1.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. jinns/data/_AbstractDataGenerator.py +1 -1
  2. jinns/data/_Batchs.py +47 -13
  3. jinns/data/_CubicMeshPDENonStatio.py +55 -34
  4. jinns/data/_CubicMeshPDEStatio.py +63 -35
  5. jinns/data/_DataGeneratorODE.py +48 -22
  6. jinns/data/_DataGeneratorObservations.py +75 -32
  7. jinns/data/_DataGeneratorParameter.py +152 -101
  8. jinns/data/__init__.py +2 -1
  9. jinns/data/_utils.py +22 -10
  10. jinns/loss/_DynamicLoss.py +21 -20
  11. jinns/loss/_DynamicLossAbstract.py +51 -36
  12. jinns/loss/_LossODE.py +139 -184
  13. jinns/loss/_LossPDE.py +440 -358
  14. jinns/loss/_abstract_loss.py +60 -25
  15. jinns/loss/_loss_components.py +4 -25
  16. jinns/loss/_loss_weight_updates.py +6 -7
  17. jinns/loss/_loss_weights.py +34 -35
  18. jinns/nn/_abstract_pinn.py +0 -2
  19. jinns/nn/_hyperpinn.py +34 -23
  20. jinns/nn/_mlp.py +5 -4
  21. jinns/nn/_pinn.py +1 -16
  22. jinns/nn/_ppinn.py +5 -16
  23. jinns/nn/_save_load.py +11 -4
  24. jinns/nn/_spinn.py +1 -16
  25. jinns/nn/_spinn_mlp.py +5 -5
  26. jinns/nn/_utils.py +33 -38
  27. jinns/parameters/__init__.py +3 -1
  28. jinns/parameters/_derivative_keys.py +99 -41
  29. jinns/parameters/_params.py +50 -25
  30. jinns/solver/_solve.py +3 -3
  31. jinns/utils/_DictToModuleMeta.py +66 -0
  32. jinns/utils/_ItemizableModule.py +19 -0
  33. jinns/utils/__init__.py +2 -1
  34. jinns/utils/_types.py +25 -15
  35. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
  36. jinns-1.6.0.dist-info/RECORD +57 -0
  37. jinns-1.5.1.dist-info/RECORD +0 -55
  38. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
  39. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
  40. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
  41. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossODE.py CHANGED
@@ -7,9 +7,8 @@ from __future__ import (
 )  # https://docs.python.org/3/library/typing.html#constant

 from dataclasses import InitVar
-from typing import TYPE_CHECKING, TypedDict, Callable
+from typing import TYPE_CHECKING, Callable, Any, cast
 from types import EllipsisType
-import abc
 import warnings
 import jax
 import jax.numpy as jnp
@@ -23,31 +22,51 @@ from jinns.loss._loss_utils import (
 )
 from jinns.parameters._params import (
     _get_vmap_in_axes_params,
-    _update_eq_params_dict,
+    update_eq_params,
 )
 from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
 from jinns.loss._loss_weights import LossWeightsODE
 from jinns.loss._abstract_loss import AbstractLoss
 from jinns.loss._loss_components import ODEComponents
 from jinns.parameters._params import Params
+from jinns.data._Batchs import ODEBatch

 if TYPE_CHECKING:
     # imports only used in type hints
-    from jinns.data._Batchs import ODEBatch
     from jinns.nn._abstract_pinn import AbstractPINN
     from jinns.loss import ODE

-class LossDictODE(TypedDict):
-    dyn_loss: Float[Array, " "]
-    initial_condition: Float[Array, " "]
-    observations: Float[Array, " "]
+InitialConditionUser = (
+    tuple[Float[Array, " n_cond "], Float[Array, " n_cond dim"]]
+    | tuple[int | float | Float[Array, " "], int | float | Float[Array, " dim"]]
+)
+
+InitialCondition = (
+    tuple[Float[Array, " n_cond "], Float[Array, " n_cond dim"]]
+    | tuple[Float[Array, " "], Float[Array, " dim"]]
+)
+
+
+class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]]):
+    r"""Loss object for an ordinary differential equation
+
+    $$
+    \mathcal{N}[u](t) = 0, \forall t \in I
+    $$
+
+    where $\mathcal{N}[\cdot]$ is a differential operator and the
+    initial condition is $u(t_0)=u_0$.


-class _LossODEAbstract(AbstractLoss):
-    r"""
     Parameters
     ----------
-
+    u : eqx.Module
+        the PINN
+    dynamic_loss : ODE
+        the ODE dynamic part of the loss, basically the differential
+        operator $\mathcal{N}[u](t)$. Should implement a method
+        `dynamic_loss.evaluate(t, u, params)`.
+        Can be None in order to access only some part of the evaluate call.
     loss_weights : LossWeightsODE, default=None
         The loss weights for the differents term : dynamic loss,
         initial condition and eventually observations if any.
@@ -67,10 +86,10 @@ class _LossODEAbstract(AbstractLoss):
         ] |
         tuple[int | float | Float[Array, " "],
         int | float | Float[Array, " dim"]
-        ] | None, default=None
+        ], default=None
         Most of the time, a tuple of length 2 with initial condition $(t_0, u_0)$.
         From jinns v1.5.1 we accept tuples of jnp arrays with shape (n_cond, 1) for t0 and (n_cond, dim) for u0. This is useful to include observed conditions at different time points, such as *e.g* final conditions. It was designed to implement $\mathcal{L}^{aux}$ from _Systems biology informed deep learning for inferring parameters and hidden dynamics_, Alireza Yazdani et al., 2020
-    obs_slice : EllipsisType | slice | None, default=None
+    obs_slice : EllipsisType | slice, default=None
         Slice object specifying the begininning/ending
         slice of u output(s) that is observed. This is useful for
         multidimensional PINN, with partially observed outputs.
@@ -78,52 +97,69 @@ class _LossODEAbstract(AbstractLoss):
     params : InitVar[Params[Array]], default=None
         The main Params object of the problem needed to instanciate the
         DerivativeKeysODE if the latter is not specified.
+    Raises
+    ------
+    ValueError
+        if initial condition is not a tuple.
     """

     # NOTE static=True only for leaf attributes that are not valid JAX types
     # (ie. jax.Array cannot be static) and that we do not expect to change
-    # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
-    derivative_keys: DerivativeKeysODE | None = eqx.field(kw_only=True, default=None)
-    loss_weights: LossWeightsODE | None = eqx.field(kw_only=True, default=None)
-    initial_condition: (
-        tuple[Float[Array, " n_cond 1"], Float[Array, " n_cond dim"]]
-        | tuple[int | float | Float[Array, " "], int | float | Float[Array, " dim"]]
-        | None
-    ) = eqx.field(kw_only=True, default=None)
-    obs_slice: EllipsisType | slice | None = eqx.field(
-        kw_only=True, default=None, static=True
-    )
-
-    params: InitVar[Params[Array]] = eqx.field(default=None, kw_only=True)
-
-    def __post_init__(self, params: Params[Array] | None = None):
-        if self.loss_weights is None:
+    u: AbstractPINN
+    dynamic_loss: ODE | None
+    vmap_in_axes: tuple[int] = eqx.field(static=True)
+    derivative_keys: DerivativeKeysODE
+    loss_weights: LossWeightsODE
+    initial_condition: InitialCondition | None
+    obs_slice: EllipsisType | slice = eqx.field(static=True)
+    params: InitVar[Params[Array] | None]
+
+    def __init__(
+        self,
+        *,
+        u: AbstractPINN,
+        dynamic_loss: ODE | None,
+        loss_weights: LossWeightsODE | None = None,
+        derivative_keys: DerivativeKeysODE | None = None,
+        initial_condition: InitialConditionUser | None = None,
+        obs_slice: EllipsisType | slice | None = None,
+        params: Params[Array] | None = None,
+        **kwargs: Any,  # this is for arguments for super()
+    ):
+        if loss_weights is None:
             self.loss_weights = LossWeightsODE()
+        else:
+            self.loss_weights = loss_weights

-        if self.derivative_keys is None:
+        super().__init__(loss_weights=self.loss_weights, **kwargs)
+        self.u = u
+        self.dynamic_loss = dynamic_loss
+        self.vmap_in_axes = (0,)
+        if derivative_keys is None:
             # by default we only take gradient wrt nn_params
             if params is None:
                 raise ValueError(
-                    "Problem at self.derivative_keys initialization "
-                    f"received {self.derivative_keys=} and {params=}"
+                    "Problem at derivative_keys initialization "
+                    f"received {derivative_keys=} and {params=}"
                 )
             self.derivative_keys = DerivativeKeysODE(params=params)
-        if self.initial_condition is None:
+        else:
+            self.derivative_keys = derivative_keys
+
+        if initial_condition is None:
             warnings.warn(
                 "Initial condition wasn't provided. Be sure to cover for that"
                 "case (e.g by. hardcoding it into the PINN output)."
             )
+            self.initial_condition = initial_condition
         else:
-            if (
-                not isinstance(self.initial_condition, tuple)
-                or len(self.initial_condition) != 2
-            ):
+            if len(initial_condition) != 2:
                 raise ValueError(
                     "Initial condition should be a tuple of len 2 with (t0, u0), "
-                    f"{self.initial_condition} was passed."
+                    f"{initial_condition} was passed."
                 )
             # some checks/reshaping for t0 and u0
-            t0, u0 = self.initial_condition
+            t0, u0 = initial_condition
             if isinstance(t0, Array):
                 # at the end we want to end up with t0 of shape (:, 1) to account for
                 # possibly several data points
@@ -178,90 +214,10 @@ class _LossODEAbstract(AbstractLoss):

             self.initial_condition = (t0, u0)

-        if self.obs_slice is None:
+        if obs_slice is None:
             self.obs_slice = jnp.s_[...]
-
-        if self.loss_weights is None:
-            self.loss_weights = LossWeightsODE()
-
-    @abc.abstractmethod
-    def __call__(self, *_, **__):
-        pass
-
-    @abc.abstractmethod
-    def evaluate(
-        self: eqx.Module, params: Params[Array], batch: ODEBatch
-    ) -> tuple[Float[Array, " "], LossDictODE]:
-        raise NotImplementedError
-
-
-class LossODE(_LossODEAbstract):
-    r"""Loss object for an ordinary differential equation
-
-    $$
-    \mathcal{N}[u](t) = 0, \forall t \in I
-    $$
-
-    where $\mathcal{N}[\cdot]$ is a differential operator and the
-    initial condition is $u(t_0)=u_0$.
-
-
-    Parameters
-    ----------
-    loss_weights : LossWeightsODE, default=None
-        The loss weights for the differents term : dynamic loss,
-        initial condition and eventually observations if any.
-        Can be updated according to a specific algorithm. See
-        `update_weight_method`
-    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
-        Default is None meaning no update for loss weights. Otherwise a string
-    derivative_keys : DerivativeKeysODE, default=None
-        Specify which field of `params` should be differentiated for each
-        composant of the total loss. Particularily useful for inverse problems.
-        Fields can be "nn_params", "eq_params" or "both". Those that should not
-        be updated will have a `jax.lax.stop_gradient` called on them. Default
-        is `"nn_params"` for each composant of the loss.
-    initial_condition : tuple[float | Float[Array, " 1"]], default=None
-        tuple of length 2 with initial condition $(t_0, u_0)$.
-    obs_slice : EllipsisType | slice | None, default=None
-        Slice object specifying the begininning/ending
-        slice of u output(s) that is observed. This is useful for
-        multidimensional PINN, with partially observed outputs.
-        Default is None (whole output is observed).
-    params : InitVar[Params[Array]], default=None
-        The main Params object of the problem needed to instanciate the
-        DerivativeKeysODE if the latter is not specified.
-    u : eqx.Module
-        the PINN
-    dynamic_loss : ODE
-        the ODE dynamic part of the loss, basically the differential
-        operator $\mathcal{N}[u](t)$. Should implement a method
-        `dynamic_loss.evaluate(t, u, params)`.
-        Can be None in order to access only some part of the evaluate call.
-
-    Raises
-    ------
-    ValueError
-        if initial condition is not a tuple.
-    """
-
-    # NOTE static=True only for leaf attributes that are not valid JAX types
-    # (ie. jax.Array cannot be static) and that we do not expect to change
-    u: AbstractPINN
-    dynamic_loss: ODE | None
-
-    vmap_in_axes: tuple[int] = eqx.field(init=False, static=True)
-
-    def __post_init__(self, params: Params[Array] | None = None):
-        super().__post_init__(
-            params=params
-        )  # because __init__ or __post_init__ of Base
-        # class is not automatically called
-
-        self.vmap_in_axes = (0,)
-
-    def __call__(self, *args, **kwargs):
-        return self.evaluate(*args, **kwargs)
+        else:
+            self.obs_slice = obs_slice

     def evaluate_by_terms(
         self, params: Params[Array], batch: ODEBatch
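
The hunks above replace the `_LossODEAbstract` base class and its dataclass-style `__post_init__` with a single `LossODE` whose optional arguments are normalised in an explicit keyword-only `__init__`. A minimal, self-contained sketch of that construction pattern, using a toy class rather than the jinns API (it only assumes `equinox` and `jax` are installed):

from types import EllipsisType

import equinox as eqx
import jax.numpy as jnp


class ToyLoss(eqx.Module):
    # static=True for leaf attributes that are not valid JAX array types
    vmap_in_axes: tuple[int] = eqx.field(static=True)
    obs_slice: EllipsisType | slice = eqx.field(static=True)

    def __init__(self, *, obs_slice: EllipsisType | slice | None = None):
        # optional arguments are normalised once, at construction time,
        # instead of in a __post_init__ hook
        self.vmap_in_axes = (0,)
        self.obs_slice = jnp.s_[...] if obs_slice is None else obs_slice


full = ToyLoss()                          # keyword-only call, defaults filled in
partial = ToyLoss(obs_slice=jnp.s_[0:1])  # observe only the first output

The same normalisation shows up in the diff: `loss_weights`, `derivative_keys` and `obs_slice` may be `None` at call time but are never `None` on the stored instance.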
@@ -291,46 +247,54 @@ class LossODE(_LossODEAbstract):
         # and update vmap_in_axes
         if batch.param_batch_dict is not None:
             # update params with the batches of generated params
-            params = _update_eq_params_dict(params, batch.param_batch_dict)
+            params = update_eq_params(params, batch.param_batch_dict)

-        vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
+        vmap_in_axes_params = _get_vmap_in_axes_params(
+            cast(eqx.Module, batch.param_batch_dict), params
+        )

         ## dynamic part
         if self.dynamic_loss is not None:
-            dyn_loss_fun = lambda p: dynamic_loss_apply(
-                self.dynamic_loss.evaluate,  # type: ignore
-                self.u,
-                temporal_batch,
-                _set_derivatives(p, self.derivative_keys.dyn_loss),  # type: ignore
-                self.vmap_in_axes + vmap_in_axes_params,
+            dyn_loss_eval = self.dynamic_loss.evaluate
+            dyn_loss_fun: Callable[[Params[Array]], Array] | None = (
+                lambda p: dynamic_loss_apply(
+                    dyn_loss_eval,
+                    self.u,
+                    temporal_batch,
+                    _set_derivatives(p, self.derivative_keys.dyn_loss),
+                    self.vmap_in_axes + vmap_in_axes_params,
+                )
             )
         else:
             dyn_loss_fun = None

-        # initial condition
         if self.initial_condition is not None:
+            # initial condition
             t0, u0 = self.initial_condition
-            u0 = jnp.array(u0)

             # first construct the plain init loss no vmaping
-            initial_condition_fun__ = lambda t, u, p: jnp.sum(
-                (
-                    self.u(
-                        t,
-                        _set_derivatives(
-                            p,
-                            self.derivative_keys.initial_condition,  # type: ignore
-                        ),
+            initial_condition_fun__: Callable[[Array, Array, Params[Array]], Array] = (
+                lambda t, u, p: jnp.sum(
+                    (
+                        self.u(
+                            t,
+                            _set_derivatives(
+                                p,
+                                self.derivative_keys.initial_condition,
+                            ),
+                        )
+                        - u
                     )
-                    - u
+                    ** 2,
+                    axis=0,
                 )
-                ** 2,
-                axis=0,
             )
             # now vmap over the number of conditions (first dim of t0 and u0)
             # and take the mean
-            initial_condition_fun_ = lambda p: jnp.mean(
-                vmap(initial_condition_fun__, (0, 0, None))(t0, u0, p)
+            initial_condition_fun_: Callable[[Params[Array]], Array] = (
+                lambda p: jnp.mean(
+                    vmap(initial_condition_fun__, (0, 0, None))(t0, u0, p)
+                )
             )
             # now vmap over the the possible batch of parameters and take the
             # average. Note that we then finally have a cartesian product
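
In the hunk above, the initial-condition term computes a squared error per condition and then `vmap`s over the leading `n_cond` axis of `t0` (shape `(n_cond, 1)`) and `u0` (shape `(n_cond, dim)`) before averaging. A standalone sketch of that computation with a hypothetical stand-in for the PINN (plain JAX; `u`, `params` and the numbers are illustrative only):

import jax.numpy as jnp
from jax import vmap


def u(t, params):
    # toy stand-in for the PINN: affine map from t of shape (1,) to dim = 2
    return params["w"] * t + params["b"]


params = {"w": jnp.array([0.5, -1.0]), "b": jnp.array([1.0, 0.0])}
t0 = jnp.array([[0.0], [2.0]])              # (n_cond, 1): e.g. initial and final time
u0 = jnp.array([[1.0, 0.0], [2.0, -2.0]])   # (n_cond, dim): target values at t0

# squared error for one condition, then vmap over the n_cond axis and average
per_condition = lambda t, u_target, p: jnp.sum((u(t, p) - u_target) ** 2, axis=0)
initial_condition_mse = jnp.mean(vmap(per_condition, (0, 0, None))(t0, u0, params))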
@@ -342,26 +306,33 @@ class LossODE(_LossODEAbstract):
                 # None in_axes or out_axes
                 initial_condition_fun = initial_condition_fun_
             else:
-                initial_condition_fun = lambda p: jnp.mean(
-                    vmap(initial_condition_fun_, vmap_in_axes_params)(p)
+                initial_condition_fun: Callable[[Params[Array]], Array] | None = (
+                    lambda p: jnp.mean(
+                        vmap(initial_condition_fun_, vmap_in_axes_params)(p)
+                    )
                 )
         else:
             initial_condition_fun = None

         if batch.obs_batch_dict is not None:
             # update params with the batches of observed params
-            params_obs = _update_eq_params_dict(
-                params, batch.obs_batch_dict["eq_params"]
-            )
+            params_obs = update_eq_params(params, batch.obs_batch_dict["eq_params"])

-            # MSE loss wrt to an observed batch
-            obs_loss_fun = lambda po: observations_loss_apply(
-                self.u,
+            pinn_in, val = (
                 batch.obs_batch_dict["pinn_in"],
-                _set_derivatives(po, self.derivative_keys.observations),  # type: ignore
-                self.vmap_in_axes + vmap_in_axes_params,
                 batch.obs_batch_dict["val"],
-                self.obs_slice,
+            )  # the reason for this intruction is https://github.com/microsoft/pyright/discussions/8340
+
+            # MSE loss wrt to an observed batch
+            obs_loss_fun: Callable[[Params[Array]], Array] | None = (
+                lambda po: observations_loss_apply(
+                    self.u,
+                    pinn_in,
+                    _set_derivatives(po, self.derivative_keys.observations),
+                    self.vmap_in_axes + vmap_in_axes_params,
+                    val,
+                    self.obs_slice,
+                )
             )
         else:
             params_obs = None
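
The observation branch above evaluates the network at the observed inputs `pinn_in` and compares only the `obs_slice` slice of its outputs to the observed values `val`. A minimal sketch of such a partially observed MSE with a toy two-output network (plain JAX; `u`, `pinn_in` and `val` are made up for illustration):

import jax.numpy as jnp
from jax import vmap


def u(t, params):
    # toy stand-in for the PINN with a 2-dimensional output
    return jnp.concatenate([jnp.sin(params * t), jnp.cos(params * t)])


params = jnp.array([1.5])
obs_slice = jnp.s_[0:1]                        # only the first output is observed
pinn_in = jnp.linspace(0.0, 1.0, 8)[:, None]   # (n_obs, 1) observed time points
val = jnp.sin(1.5 * pinn_in)                   # (n_obs, 1) observed values

preds = vmap(u, (0, None))(pinn_in, params)    # (n_obs, 2) full network outputs
obs_mse = jnp.mean((preds[:, obs_slice] - val) ** 2)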
@@ -374,43 +345,27 @@ class LossODE(_LossODEAbstract):
         all_params: ODEComponents[Params[Array] | None] = ODEComponents(
             params, params, params_obs
         )
+
+        # Note that the lambda functions below are with type: ignore just
+        # because the lambda are not type annotated, but there is no proper way
+        # to do this and we should assign the lambda to a type hinted variable
+        # before hand: this is not practical, let us not get mad at this
         mses_grads = jax.tree.map(
-            lambda fun, params: self.get_gradients(fun, params),
+            self.get_gradients,
             all_funs,
             all_params,
             is_leaf=lambda x: x is None,
         )

         mses = jax.tree.map(
-            lambda leaf: leaf[0], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
+            lambda leaf: leaf[0],  # type: ignore
+            mses_grads,
+            is_leaf=lambda x: isinstance(x, tuple),
         )
         grads = jax.tree.map(
-            lambda leaf: leaf[1], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
+            lambda leaf: leaf[1],  # type: ignore
+            mses_grads,
+            is_leaf=lambda x: isinstance(x, tuple),
         )

         return mses, grads
-
-    def evaluate(
-        self, params: Params[Array], batch: ODEBatch
-    ) -> tuple[Float[Array, " "], ODEComponents[Float[Array, " "] | None]]:
-        """
-        Evaluate the loss function at a batch of points for given parameters.
-
-        We retrieve the total value itself and a PyTree with loss values for each term
-
-        Parameters
-        ---------
-        params
-            Parameters at which the loss is evaluated
-        batch
-            Composed of a batch of points in the
-            domain, a batch of points in the domain
-            border and an optional additional batch of parameters (eg. for
-            metamodeling) and an optional additional batch of observed
-            inputs/outputs/parameters
-        """
-        loss_terms, _ = self.evaluate_by_terms(params, batch)
-
-        loss_val = self.ponderate_and_sum_loss(loss_terms)
-
-        return loss_val, loss_terms
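
The last hunk maps `get_gradients` over a PyTree of per-term loss functions in which absent terms are `None` leaves, then splits the resulting `(value, gradient)` pairs into two PyTrees with `is_leaf=lambda x: isinstance(x, tuple)`. A standalone sketch of that pattern (plain JAX; the loss functions and `get_gradients` here are toy stand-ins, not the jinns implementations):

import jax
import jax.numpy as jnp


def get_gradients(fun, params):
    # absent terms stay None; present terms yield (value, gradient)
    if fun is None:
        return None
    return jax.value_and_grad(fun)(params)


params = jnp.array([1.0, 2.0])
terms = {
    "dyn_loss": lambda p: jnp.sum(p**2),
    "initial_condition": lambda p: jnp.sum((p - 1.0) ** 2),
    "observations": None,  # no observation term in this batch
}
per_term_params = {name: params for name in terms}

# None leaves are kept (is_leaf) so get_gradients decides what to do with them
values_grads = jax.tree.map(
    get_gradients, terms, per_term_params, is_leaf=lambda x: x is None
)
# split the (value, gradient) tuples; None entries pass through untouched
values = jax.tree.map(
    lambda leaf: leaf[0], values_grads, is_leaf=lambda x: isinstance(x, tuple)
)
grads = jax.tree.map(
    lambda leaf: leaf[1], values_grads, is_leaf=lambda x: isinstance(x, tuple)
)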