jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +322 -167
  9. jinns/loss/_LossODE.py +324 -322
  10. jinns/loss/_LossPDE.py +652 -1027
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +521 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. jinns-1.1.0.dist-info/AUTHORS +2 -0
  33. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
  34. jinns-1.1.0.dist-info/RECORD +39 -0
  35. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
  36. jinns/experimental/_sinuspinn.py +0 -135
  37. jinns/experimental/_spectralpinn.py +0 -87
  38. jinns/solver/_seq2seq.py +0 -157
  39. jinns/utils/_optim.py +0 -147
  40. jinns/utils/_utils_uspinn.py +0 -727
  41. jinns-0.9.0.dist-info/RECORD +0 -36
  42. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
  43. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossODE.py CHANGED
@@ -1,143 +1,185 @@
1
+ # pylint: disable=unsubscriptable-object, no-member
1
2
  """
2
3
  Main module to implement a ODE loss in jinns
3
4
  """
5
+ from __future__ import (
6
+ annotations,
7
+ ) # https://docs.python.org/3/library/typing.html#constant
4
8
 
9
+ from dataclasses import InitVar, fields
10
+ from typing import TYPE_CHECKING, Dict
11
+ import abc
5
12
  import warnings
6
13
  import jax
7
14
  import jax.numpy as jnp
8
15
  from jax import vmap
9
- from jax.tree_util import register_pytree_node_class
10
- from jinns.utils._utils import (
11
- _get_vmap_in_axes_params,
12
- _set_derivatives,
13
- _update_eq_params_dict,
14
- )
15
- from jinns.loss._Losses import (
16
+ import equinox as eqx
17
+ from jaxtyping import Float, Array, Int
18
+ from jinns.data._DataGenerators import append_obs_batch
19
+ from jinns.loss._loss_utils import (
16
20
  dynamic_loss_apply,
17
21
  constraints_system_loss_apply,
18
22
  observations_loss_apply,
19
23
  )
24
+ from jinns.parameters._params import (
25
+ _get_vmap_in_axes_params,
26
+ _update_eq_params_dict,
27
+ )
28
+ from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
29
+ from jinns.loss._loss_weights import LossWeightsODE, LossWeightsODEDict
30
+ from jinns.loss._DynamicLossAbstract import ODE
20
31
  from jinns.utils._pinn import PINN
21
32
 
22
- _LOSS_WEIGHT_KEYS_ODE = ["observations", "dyn_loss", "initial_condition"]
23
-
24
-
25
- @register_pytree_node_class
26
- class LossODE:
27
- r"""Loss object for an ordinary differential equation
33
+ if TYPE_CHECKING:
34
+ from jinns.utils._types import *
28
35
 
29
- .. math::
30
- \mathcal{N}[u](t) = 0, \forall t \in I
31
36
 
32
- where :math:`\mathcal{N}[\cdot]` is a differential operator and the
33
- initial condition is :math:`u(t_0)=u_0`.
34
-
35
-
36
- **Note:** LossODE is jittable. Hence it implements the tree_flatten() and
37
- tree_unflatten methods.
37
+ class _LossODEAbstract(eqx.Module):
38
+ """
39
+ Parameters
40
+ ----------
41
+
42
+ loss_weights : LossWeightsODE, default=None
43
+ The loss weights for the different terms: dynamic loss,
44
+ initial condition and, if any, observations. All fields are
45
+ set to 1.0 by default.
46
+ derivative_keys : DerivativeKeysODE, default=None
47
+ Specify which field of `params` should be differentiated for each
48
+ component of the total loss. Particularly useful for inverse problems.
49
+ Fields can be "nn_params", "eq_params" or "both". Those that should not
50
+ be updated will have a `jax.lax.stop_gradient` called on them. Default
51
+ is `"nn_params"` for each composant of the loss.
52
+ initial_condition : tuple, default=None
53
+ tuple of length 2 with initial condition $(t_0, u_0)$.
54
+ obs_slice : Slice, default=None
55
+ Slice object specifying the beginning/ending
56
+ slice of u output(s) that is observed. This is useful for
57
+ multidimensional PINN, with partially observed outputs.
58
+ Default is None (whole output is observed).
59
+ params : InitVar[Params], default=None
60
+ The main Params object of the problem, needed to instantiate the
61
+ DerivativeKeysODE if the latter is not specified.
38
62
  """
39
63
 
40
- def __init__(
41
- self,
42
- u,
43
- loss_weights,
44
- dynamic_loss,
45
- derivative_keys=None,
46
- initial_condition=None,
47
- obs_slice=None,
48
- ):
49
- r"""
50
- Parameters
51
- ----------
52
- u :
53
- the PINN
54
- loss_weights :
55
- a dictionary with values used to ponderate each term in the loss
56
- function. Valid keys are `dyn_loss`, `initial_condition` and `observations`
57
- Note that we can have jnp.arrays with the same dimension of
58
- `u` which then ponderates each output of `u`
59
- dynamic_loss :
60
- the ODE dynamic part of the loss, basically the differential
61
- operator :math:`\mathcal{N}[u](t)`. Should implement a method
62
- `dynamic_loss.evaluate(t, u, params)`.
63
- Can be None in order to
64
- access only some part of the evaluate call results.
65
- derivative_keys
66
- A dict of lists of strings. In the dict, the key must correspond to
67
- the loss term keywords. Then each of the values must correspond to keys in the parameter
68
- dictionary (*at top level only of the parameter dictionary*).
69
- It enables selecting the set of parameters
70
- with respect to which the gradients of the dynamic
71
- loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
72
- keywords, this is what is typically
73
- done in solving forward problems, when we only estimate the
74
- equation solution with a PINN. If some loss terms keywords are
75
- missing we set their value to ["nn_params"] by default for the same
76
- reason
77
- initial_condition :
78
- tuple of length 2 with initial condition :math:`(t0, u0)`.
79
- Can be None in order to
80
- access only some part of the evaluate call results.
81
- obs_slice:
82
- slice object specifying the begininning/ending
83
- slice of u output(s) that is observed (this is then useful for
84
- multidim PINN). Default is None.
85
-
86
- Raises
87
- ------
88
- ValueError
89
- if initial condition is not a tuple.
90
- """
91
- self.dynamic_loss = dynamic_loss
92
- self.u = u
93
- if derivative_keys is None:
94
- # be default we only take gradient wrt nn_params
95
- derivative_keys = {
96
- k: ["nn_params"]
97
- for k in [
98
- "dyn_loss",
99
- "initial_condition",
100
- "observations",
101
- ]
102
- }
103
- if isinstance(derivative_keys, list):
104
- # if the user only provided a list, this defines the gradient taken
105
- # for all the loss entries
106
- derivative_keys = {
107
- k: derivative_keys
108
- for k in [
109
- "dyn_loss",
110
- "initial_condition",
111
- "observations",
112
- ]
113
- }
114
-
115
- self.derivative_keys = derivative_keys
116
-
117
- if initial_condition is None:
64
+ # NOTE static=True only for leaf attributes that are not valid JAX types
65
+ # (ie. jax.Array cannot be static) and that we do not expect to change
66
+ # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
67
+ derivative_keys: DerivativeKeysODE | None = eqx.field(kw_only=True, default=None)
68
+ loss_weights: LossWeightsODE | None = eqx.field(kw_only=True, default=None)
69
+ initial_condition: tuple | None = eqx.field(kw_only=True, default=None)
70
+ obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
71
+
72
+ params: InitVar[Params] = eqx.field(default=None, kw_only=True)
73
+
74
+ def __post_init__(self, params=None):
75
+ if self.loss_weights is None:
76
+ self.loss_weights = LossWeightsODE()
77
+
78
+ if self.derivative_keys is None:
79
+ try:
80
+ # by default we only take gradients wrt nn_params
81
+ self.derivative_keys = DerivativeKeysODE(params=params)
82
+ except ValueError as exc:
83
+ raise ValueError(
84
+ "Problem at self.derivative_keys initialization "
85
+ f"received {self.derivative_keys=} and {params=}"
86
+ ) from exc
87
+ if self.initial_condition is None:
118
88
  warnings.warn(
119
89
  "Initial condition wasn't provided. Be sure to cover for that"
120
90
  "case (e.g by. hardcoding it into the PINN output)."
121
91
  )
122
92
  else:
123
- if not isinstance(initial_condition, tuple) or len(initial_condition) != 2:
93
+ if (
94
+ not isinstance(self.initial_condition, tuple)
95
+ or len(self.initial_condition) != 2
96
+ ):
124
97
  raise ValueError(
125
- f"Initial condition should be a tuple of len 2 with (t0, u0), {initial_condition} was passed."
98
+ "Initial condition should be a tuple of len 2 with (t0, u0), "
99
+ f"{self.initial_condition} was passed."
126
100
  )
127
- self.initial_condition = initial_condition
128
- self.loss_weights = loss_weights
129
- self.obs_slice = obs_slice
101
+
130
102
  if self.obs_slice is None:
131
103
  self.obs_slice = jnp.s_[...]
132
104
 
133
- for k in _LOSS_WEIGHT_KEYS_ODE:
134
- if k not in self.loss_weights.keys():
135
- self.loss_weights[k] = 0
105
+ if self.loss_weights is None:
106
+ self.loss_weights = LossWeightsODE()
107
+
108
+ @abc.abstractmethod
109
+ def evaluate(
110
+ self: eqx.Module, params: Params, batch: ODEBatch
111
+ ) -> tuple[Float, dict]:
112
+ raise NotImplementedError
113
+
114
+
115
+ class LossODE(_LossODEAbstract):
116
+ r"""Loss object for an ordinary differential equation
117
+
118
+ $$
119
+ \mathcal{N}[u](t) = 0, \forall t \in I
120
+ $$
121
+
122
+ where $\mathcal{N}[\cdot]$ is a differential operator and the
123
+ initial condition is $u(t_0)=u_0$.
124
+
125
+
126
+ Parameters
127
+ ----------
128
+ loss_weights : LossWeightsODE, default=None
129
+ The loss weights for the different terms: dynamic loss,
130
+ initial condition and, if any, observations. All fields are
131
+ set to 1.0 by default.
132
+ derivative_keys : DerivativeKeysODE, default=None
133
+ Specify which field of `params` should be differentiated for each
134
+ component of the total loss. Particularly useful for inverse problems.
135
+ Fields can be "nn_params", "eq_params" or "both". Those that should not
136
+ be updated will have a `jax.lax.stop_gradient` called on them. Default
137
+ is `"nn_params"` for each composant of the loss.
138
+ initial_condition : tuple, default=None
139
+ tuple of length 2 with initial condition $(t_0, u_0)$.
140
+ obs_slice : Slice, default=None
141
+ Slice object specifying the beginning/ending
142
+ slice of u output(s) that is observed. This is useful for
143
+ multidimensional PINN, with partially observed outputs.
144
+ Default is None (whole output is observed).
145
+ params : InitVar[Params], default=None
146
+ The main Params object of the problem, needed to instantiate the
147
+ DerivativeKeysODE if the latter is not specified.
148
+ u : eqx.Module
149
+ the PINN
150
+ dynamic_loss : DynamicLoss
151
+ the ODE dynamic part of the loss, basically the differential
152
+ operator $\mathcal{N}[u](t)$. Should implement a method
153
+ `dynamic_loss.evaluate(t, u, params)`.
154
+ Can be None in order to access only some part of the evaluate call.
155
+
156
+ Raises
157
+ ------
158
+ ValueError
159
+ if initial condition is not a tuple.
160
+ """
161
+
162
+ # NOTE static=True only for leaf attributes that are not valid JAX types
163
+ # (ie. jax.Array cannot be static) and that we do not expect to change
164
+ u: eqx.Module
165
+ dynamic_loss: DynamicLoss | None
166
+
167
+ vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
168
+
169
+ def __post_init__(self, params=None):
170
+ super().__post_init__(
171
+ params=params
172
+ ) # because __init__ or __post_init__ of Base
173
+ # class is not automatically called
174
+
175
+ self.vmap_in_axes = (0,)
136
176
 
137
177
  def __call__(self, *args, **kwargs):
138
178
  return self.evaluate(*args, **kwargs)
139
179
 
140
- def evaluate(self, params, batch):
180
+ def evaluate(
181
+ self, params: Params, batch: ODEBatch
182
+ ) -> tuple[Float[Array, "1"], dict[str, float]]:
141
183
  """
142
184
  Evaluate the loss function at a batch of points for given parameters.
143
185
 
@@ -145,21 +187,14 @@ class LossODE:
145
187
  Parameters
146
188
  ---------
147
189
  params
148
- The dictionary of parameters of the model.
149
- Typically, it is a dictionary of
150
- dictionaries: `eq_params` and `nn_params``, respectively the
151
- differential equation parameters and the neural network parameter
190
+ Parameters at which the loss is evaluated
152
191
  batch
153
- A ODEBatch object.
154
- Such a named tuple is composed of a batch of time points
155
- at which to evaluate an optional additional batch of parameters (eg. for
156
- metamodeling) and an optional additional batch of observed
157
- inputs/outputs/parameters
192
+ Composed of a batch of time points
193
+ at which to evaluate the differential operator. An optional additional batch of parameters (e.g. for metamodeling) and an optional additional batch of observed inputs/outputs/parameters can
194
+ be supplied.
158
195
  """
159
196
  temporal_batch = batch.temporal_batch
160
197
 
161
- vmap_in_axes_t = (0,)
162
-
163
198
  # Retrieve the optional eq_params_batch
164
199
  # and update eq_params with the latter
165
200
  # and update vmap_in_axes
@@ -170,21 +205,19 @@ class LossODE:
170
205
  vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
171
206
 
172
207
  ## dynamic part
173
- params_ = _set_derivatives(params, "dyn_loss", self.derivative_keys)
174
208
  if self.dynamic_loss is not None:
175
209
  mse_dyn_loss = dynamic_loss_apply(
176
210
  self.dynamic_loss.evaluate,
177
211
  self.u,
178
212
  (temporal_batch,),
179
- params_,
180
- vmap_in_axes_t + vmap_in_axes_params,
181
- self.loss_weights["dyn_loss"],
213
+ _set_derivatives(params, self.derivative_keys.dyn_loss),
214
+ self.vmap_in_axes + vmap_in_axes_params,
215
+ self.loss_weights.dyn_loss,
182
216
  )
183
217
  else:
184
218
  mse_dyn_loss = jnp.array(0.0)
185
219
 
186
220
  # initial condition
187
- params_ = _set_derivatives(params, "initial_condition", self.derivative_keys)
188
221
  if self.initial_condition is not None:
189
222
  vmap_in_axes = (None,) + vmap_in_axes_params
190
223
  if not jax.tree_util.tree_leaves(vmap_in_axes):
@@ -193,12 +226,24 @@ class LossODE:
193
226
  v_u = self.u
194
227
  else:
195
228
  v_u = vmap(self.u, (None,) + vmap_in_axes_params)
196
- t0, u0 = self.initial_condition
229
+ t0, u0 = self.initial_condition # pylint: disable=unpacking-non-sequence
197
230
  t0 = jnp.array(t0)
198
231
  u0 = jnp.array(u0)
199
232
  mse_initial_condition = jnp.mean(
200
- self.loss_weights["initial_condition"]
201
- * jnp.sum((v_u(t0, params_) - u0) ** 2, axis=-1)
233
+ self.loss_weights.initial_condition
234
+ * jnp.sum(
235
+ (
236
+ v_u(
237
+ t0,
238
+ _set_derivatives(
239
+ params, self.derivative_keys.initial_condition
240
+ ),
241
+ )
242
+ - u0
243
+ )
244
+ ** 2,
245
+ axis=-1,
246
+ )
202
247
  )
203
248
  else:
204
249
  mse_initial_condition = jnp.array(0.0)
@@ -208,14 +253,13 @@ class LossODE:
208
253
  params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])
209
254
 
210
255
  # MSE loss wrt to an observed batch
211
- params_ = _set_derivatives(params, "observations", self.derivative_keys)
212
256
  mse_observation_loss = observations_loss_apply(
213
257
  self.u,
214
258
  (batch.obs_batch_dict["pinn_in"],),
215
- params_,
216
- vmap_in_axes_t + vmap_in_axes_params,
259
+ _set_derivatives(params, self.derivative_keys.observations),
260
+ self.vmap_in_axes + vmap_in_axes_params,
217
261
  batch.obs_batch_dict["val"],
218
- self.loss_weights["observations"],
262
+ self.loss_weights.observations,
219
263
  self.obs_slice,
220
264
  )
221
265
  else:
@@ -231,193 +275,178 @@ class LossODE:
231
275
  }
232
276
  )
233
277
 
234
- def tree_flatten(self):
235
- children = (self.initial_condition, self.loss_weights)
236
- aux_data = {
237
- "u": self.u,
238
- "dynamic_loss": self.dynamic_loss,
239
- "obs_slice": self.obs_slice,
240
- "derivative_keys": self.derivative_keys,
241
- }
242
- return (children, aux_data)
243
-
244
- @classmethod
245
- def tree_unflatten(cls, aux_data, children):
246
- (initial_condition, loss_weights) = children
247
- loss_ode = cls(
248
- loss_weights=loss_weights,
249
- initial_condition=initial_condition,
250
- **aux_data,
251
- )
252
- return loss_ode
253
278
 
254
-
255
- @register_pytree_node_class
256
- class SystemLossODE:
257
- """
279
+ class SystemLossODE(eqx.Module):
280
+ r"""
258
281
  Class to implement a system of ODEs.
259
282
  The goal is to give maximum freedom to the user. The class is created with
260
- a dict of dynamic loss and a dict of initial conditions. When then iterate
261
- over the dynamic losses that compose the system. All the PINNs with all the
262
- parameter dictionaries are passed as arguments to each dynamic loss
263
- evaluate functions; it is inside the dynamic loss that specification are
264
- performed.
283
+ a dict of dynamic loss and a dict of initial conditions. Then, it iterates
284
+ over the dynamic losses that compose the system. All PINNs are passed as
285
+ arguments to each dynamic loss evaluate functions, along with all the
286
+ parameter dictionaries. All specification is left to the responsability
287
+ of the user, inside the dynamic loss.
265
288
 
266
289
  **Note:** All the dictionaries (except `dynamic_loss_dict`) must have the same keys.
267
290
  Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
268
291
  solution.
269
292
 
270
- **Note:** SystemLossODE is jittable. Hence it implements the tree_flatten() and
271
- tree_unflatten methods.
293
+ Parameters
294
+ ----------
295
+ u_dict : Dict[str, eqx.Module]
296
+ dict of PINNs
297
+ loss_weights : LossWeightsODEDict
298
+ A dictionary of LossWeightsODE
299
+ derivative_keys_dict : Dict[str, DerivativeKeysODE], default=None
300
+ A dictionary of DerivativeKeysODE specifying which fields of `params`
301
+ should be used during gradient computations for each of the terms of
302
+ the total loss, for each of the losses in the system. Default is
303
+ `"nn_params`" everywhere.
304
+ initial_condition_dict : Dict[str, tuple], default=None
305
+ dict of tuples of length 2 with initial condition $(t_0, u_0)$.
306
+ Must share the keys of `u_dict`. Default is None: omitting the initial
307
+ condition is permitted when it is hardcoded in
308
+ the PINN architecture, for example.
309
+ dynamic_loss_dict : Dict[str, ODE]
310
+ dict of dynamic part of the loss, basically the differential
311
+ operator $\mathcal{N}[u](t)$. Should implement a method
312
+ `dynamic_loss.evaluate(t, u, params)`
313
+ obs_slice_dict : Dict[str, Slice]
314
+ dict of obs_slice, with keys from `u_dict` to designate the
315
+ output channel(s) that are observed, for each
316
+ PINN. Default is None. But if a value is given, all the entries of
317
+ `u_dict` must be represented here with default value `jnp.s_[...]`
318
+ if no particular slice is to be given.
319
+ params_dict : InitVar[ParamsDict], default=None
320
+ The main ParamsDict object of the problem, needed to instantiate the
321
+ DerivativeKeysODE if the latter is not specified.
322
+
323
+ Raises
324
+ ------
325
+ ValueError
326
+ if initial condition is not a dict of tuple.
327
+ ValueError
328
+ if the dictionaries that should share the keys of u_dict do not.
272
329
  """
273
330
 
274
- def __init__(
275
- self,
276
- u_dict,
277
- loss_weights,
278
- dynamic_loss_dict,
279
- derivative_keys_dict=None,
280
- initial_condition_dict=None,
281
- obs_slice_dict=None,
282
- ):
283
- r"""
284
- Parameters
285
- ----------
286
- u_dict
287
- dict of PINNs
288
- loss_weights
289
- A dictionary of dictionaries with values used to
290
- ponderate each term in the loss
291
- function. Valid keys in the first dictionary are `dyn_loss`,
292
- `initial_condition` and `observations`. The keys of the nested
293
- dictionaries must share the keys of `u_dict`. Note that the values
294
- at the leaf level can have jnp.arrays with the same dimension of
295
- `u` which then ponderates each output of `u`
296
- derivative_keys_dict
297
- A dict of derivative keys as defined in LossODE. The key of this
298
- dict must be that of `dynamic_loss_dict` at least and specify how
299
- to compute gradient for the `dyn_loss` loss term at least (see the
300
- check at the beginning of the present `__init__` function.
301
- Other keys of this dict might be that of `u_dict` to specify how to
302
- compute gradients for all the different constraints. If those keys
303
- are not specified then the default behaviour for `derivative_keys`
304
- of LossODE is used
305
- initial_condition_dict
306
- dict of tuple of length 2 with initial condition :math:`(t_0, u_0)`
307
- Must share the keys of `u_dict`. Default is None. No initial
308
- condition is permitted when the initial condition is hardcoded in
309
- the PINN architecture for example
310
- dynamic_loss_dict
311
- dict of dynamic part of the loss, basically the differential
312
- operator :math:`\mathcal{N}[u](t)`. Should implement a method
313
- `dynamic_loss.evaluate(t, u, params)`
314
- obs_slice_dict
315
- dict of obs_slice, with keys from `u_dict` to designate the
316
- output(s) channels that are forced to observed values, for each
317
- PINNs. Default is None. But if a value is given, all the entries of
318
- `u_dict` must be represented here with default value `jnp.s_[...]`
319
- if no particular slice is to be given
320
-
321
- Raises
322
- ------
323
- ValueError
324
- if initial condition is not a dict of tuple
325
- ValueError
326
- if the dictionaries that should share the keys of u_dict do not
327
- """
328
-
331
+ # NOTE static=True only for leaf attributes that are not valid JAX types
332
+ # (ie. jax.Array cannot be static) and that we do not expect to change
333
+ u_dict: Dict[str, eqx.Module]
334
+ dynamic_loss_dict: Dict[str, ODE]
335
+ derivative_keys_dict: Dict[str, DerivativeKeysODE | None] | None = eqx.field(
336
+ kw_only=True, default=None
337
+ )
338
+ initial_condition_dict: Dict[str, tuple] | None = eqx.field(
339
+ kw_only=True, default=None
340
+ )
341
+
342
+ obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
343
+ kw_only=True, default=None, static=True
344
+ ) # We are at an "leaf" attribute here (slice, not valid JAX type). Since
345
+ # we do not expect it to change, we put static=True here. But note that
346
+ # this is the only static field among the SystemLossODE attributes, since all
347
+ # others are composed of more complex structures ("non-leaf")
348
+
349
+ # For the user loss_weights are passed as a LossWeightsODEDict (with internal
350
+ # dictionary having keys in u_dict and / or dynamic_loss_dict)
351
+ loss_weights: InitVar[LossWeightsODEDict | None] = eqx.field(
352
+ kw_only=True, default=None
353
+ )
354
+ params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
355
+
356
+ u_constraints_dict: Dict[str, LossODE] = eqx.field(init=False)
357
+ derivative_keys_dyn_loss: DerivativeKeysODE = eqx.field(init=False)
358
+
359
+ u_dict_with_none: Dict[str, None] = eqx.field(init=False)
360
+ # internally the loss weights are handled with a dictionary
361
+ _loss_weights: Dict[str, dict] = eqx.field(init=False)
362
+
363
+ def __post_init__(self, loss_weights=None, params_dict=None):
329
364
  # a dictionary that will be useful at different places
330
- self.u_dict_with_none = {k: None for k in u_dict.keys()}
331
- if initial_condition_dict is None:
365
+ self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
366
+ if self.initial_condition_dict is None:
332
367
  self.initial_condition_dict = self.u_dict_with_none
333
368
  else:
334
- self.initial_condition_dict = initial_condition_dict
335
- if u_dict.keys() != initial_condition_dict.keys():
369
+ if self.u_dict.keys() != self.initial_condition_dict.keys():
336
370
  raise ValueError(
337
371
  "initial_condition_dict should have same keys as u_dict"
338
372
  )
339
- if obs_slice_dict is None:
340
- self.obs_slice_dict = {k: jnp.s_[...] for k in u_dict.keys()}
373
+ if self.obs_slice_dict is None:
374
+ self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
341
375
  else:
342
- self.obs_slice_dict = obs_slice_dict
343
- if u_dict.keys() != obs_slice_dict.keys():
376
+ if self.u_dict.keys() != self.obs_slice_dict.keys():
344
377
  raise ValueError("obs_slice_dict should have same keys as u_dict")
345
378
 
346
- if derivative_keys_dict is None:
379
+ if self.derivative_keys_dict is None:
347
380
  self.derivative_keys_dict = {
348
381
  k: None
349
- for k in set(list(dynamic_loss_dict.keys()) + list(u_dict.keys()))
382
+ for k in set(
383
+ list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
384
+ )
350
385
  }
351
386
  # set() because we can have duplicate entries and in this case we
352
387
  # say it corresponds to the same derivative_keys_dict entry
353
- else:
354
- self.derivative_keys_dict = derivative_keys_dict
355
-
356
- # but then if the user did not provide anything, we must at least have
357
- # a default value for the dynamic_loss_dict keys entries in
358
- # self.derivative_keys_dict since the computation of dynamic losses is
359
- # made without create a lossODE object that would provide the
360
- # default values
361
- for k in dynamic_loss_dict.keys():
388
+ # we need both because the constraints (all but dyn_loss) will be
389
+ # done by iterating on u_dict while the dyn_loss will be by
390
+ # iterating on dynamic_loss_dict. So each time we will require dome
391
+ # derivative_keys_dict
392
+
393
+ # derivative keys for the u_constraints. Note that we create missing
394
+ # DerivativeKeysODE around a Params object and not ParamsDict
395
+ # this works because u_dict.keys == params_dict.nn_params.keys()
396
+ for k in self.u_dict.keys():
362
397
  if self.derivative_keys_dict[k] is None:
363
- self.derivative_keys_dict[k] = {"dyn_loss": ["nn_params"]}
364
-
365
- self.dynamic_loss_dict = dynamic_loss_dict
366
- self.u_dict = u_dict
398
+ self.derivative_keys_dict[k] = DerivativeKeysODE(
399
+ params=params_dict.extract_params(k)
400
+ )
367
401
 
368
- self.loss_weights = loss_weights # We call the setter
369
- # note that self.initial_condition_dict must be
370
- # initialized beforehand
402
+ self._loss_weights = self.set_loss_weights(loss_weights)
371
403
 
372
404
  # The constaints on the solutions will be implemented by reusing a
373
405
  # LossODE class without dynamic loss term
374
406
  self.u_constraints_dict = {}
375
407
  for i in self.u_dict.keys():
376
408
  self.u_constraints_dict[i] = LossODE(
377
- u=u_dict[i],
378
- loss_weights={
379
- "dyn_loss": 0.0,
380
- "initial_condition": 1.0,
381
- "observations": 1.0,
382
- },
409
+ u=self.u_dict[i],
410
+ loss_weights=LossWeightsODE(
411
+ dyn_loss=0.0,
412
+ initial_condition=1.0,
413
+ observations=1.0,
414
+ ),
383
415
  dynamic_loss=None,
384
416
  derivative_keys=self.derivative_keys_dict[i],
385
417
  initial_condition=self.initial_condition_dict[i],
386
418
  obs_slice=self.obs_slice_dict[i],
387
419
  )
388
420
 
389
- # for convenience in the tree_map of evaluate,
390
- # we separate the two derivative keys dict
391
- self.derivative_keys_dyn_loss_dict = {
392
- k: self.derivative_keys_dict[k]
393
- for k in self.dynamic_loss_dict.keys() & self.derivative_keys_dict.keys()
394
- }
395
- self.derivative_keys_u_dict = {
396
- k: self.derivative_keys_dict[k]
397
- for k in self.u_dict.keys() & self.derivative_keys_dict.keys()
398
- }
399
-
400
- @property
401
- def loss_weights(self):
402
- return self._loss_weights
421
+ # derivative keys for the dynamic loss. Note that we create a
422
+ # DerivativeKeysODE around a ParamsDict object because a whole
423
+ # params_dict is fed to DynamicLoss.evaluate functions (extract_params
424
+ # happens inside it)
425
+ self.derivative_keys_dyn_loss = DerivativeKeysODE(params=params_dict)
403
426
 
404
- @loss_weights.setter
405
- def loss_weights(self, value):
406
- self._loss_weights = {}
407
- for k, v in value.items():
427
+ def set_loss_weights(self, loss_weights_init):
428
+ """
429
+ This rather complex function enables the user to specify a simple
430
+ loss_weights=LossWeightsODEDict(dyn_loss=1., initial_condition=Tmax)
431
+ for ponderating values being applied to all the equations of the
432
+ system... So all the transformations are handled here
433
+ """
434
+ _loss_weights = {}
435
+ for k in fields(loss_weights_init):
436
+ v = getattr(loss_weights_init, k.name)
408
437
  if isinstance(v, dict):
409
- for kk, vv in v.items():
438
+ for vv in v.values():
410
439
  if not isinstance(vv, (int, float)) and not (
411
- isinstance(vv, jnp.ndarray)
440
+ isinstance(vv, Array)
412
441
  and ((vv.shape == (1,) or len(vv.shape) == 0))
413
442
  ):
414
443
  # TODO improve that
415
444
  raise ValueError(
416
445
  f"loss values cannot be vectorial here, got {vv}"
417
446
  )
418
- if k == "dyn_loss":
447
+ if k.name == "dyn_loss":
419
448
  if v.keys() == self.dynamic_loss_dict.keys():
420
- self._loss_weights[k] = v
449
+ _loss_weights[k.name] = v
421
450
  else:
422
451
  raise ValueError(
423
452
  "Keys in nested dictionary of loss_weights"
@@ -425,48 +454,41 @@ class SystemLossODE:
425
454
  )
426
455
  else:
427
456
  if v.keys() == self.u_dict.keys():
428
- self._loss_weights[k] = v
457
+ _loss_weights[k.name] = v
429
458
  else:
430
459
  raise ValueError(
431
460
  "Keys in nested dictionary of loss_weights"
432
461
  " do not match u_dict keys"
433
462
  )
463
+ elif v is None:
464
+ _loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
434
465
  else:
435
466
  if not isinstance(v, (int, float)) and not (
436
- isinstance(v, jnp.ndarray)
437
- and ((v.shape == (1,) or len(v.shape) == 0))
467
+ isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
438
468
  ):
439
469
  # TODO improve that
440
470
  raise ValueError(f"loss values cannot be vectorial here, got {v}")
441
- if k == "dyn_loss":
442
- self._loss_weights[k] = {
471
+ if k.name == "dyn_loss":
472
+ _loss_weights[k.name] = {
443
473
  kk: v for kk in self.dynamic_loss_dict.keys()
444
474
  }
445
475
  else:
446
- self._loss_weights[k] = {kk: v for kk in self.u_dict.keys()}
447
- if all(v is None for k, v in self.initial_condition_dict.items()):
448
- self._loss_weights["initial_condition"] = {k: 0 for k in self.u_dict.keys()}
449
- if "observations" not in value.keys():
450
- self._loss_weights["observations"] = {k: 0 for k in self.u_dict.keys()}
476
+ _loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
477
+
478
+ return _loss_weights
451
479
 
452
480
  def __call__(self, *args, **kwargs):
453
481
  return self.evaluate(*args, **kwargs)
454
482
 
455
- def evaluate(self, params_dict, batch):
483
+ def evaluate(self, params_dict: ParamsDict, batch: ODEBatch) -> Float[Array, "1"]:
456
484
  """
457
485
  Evaluate the loss function at a batch of points for given parameters.
458
486
 
459
487
 
460
488
  Parameters
461
489
  ---------
462
- params_dict
463
- A dictionary of dictionaries of parameters of the model.
464
- Typically, it is a dictionary of dictionaries of
465
- dictionaries: `eq_params` and `nn_params``, respectively the
466
- differential equation parameters and the neural network parameter.
467
- Note that params_dict["nn_params"] need not be a dictionary anymore
468
- but can directly be the parameters. It is useful when working with
469
- neural networks sharing the same parameters
490
+ params_dict
491
+ A ParamsDict object
470
492
  batch
471
493
  A ODEBatch object.
472
494
  Such a named tuple is composed of a batch of time points
@@ -475,10 +497,10 @@ class SystemLossODE:
475
497
  inputs/outputs/parameters
476
498
  """
477
499
  if (
478
- isinstance(params_dict["nn_params"], dict)
479
- and self.u_dict.keys() != params_dict["nn_params"].keys()
500
+ isinstance(params_dict.nn_params, dict)
501
+ and self.u_dict.keys() != params_dict.nn_params.keys()
480
502
  ):
481
- raise ValueError("u_dict and params_dict[nn_params] should have same keys ")
503
+ raise ValueError("u_dict and params_dict.nn_params should have same keys ")
482
504
 
483
505
  temporal_batch = batch.temporal_batch
484
506
 
@@ -489,20 +511,19 @@ class SystemLossODE:
489
511
  # and update vmap_in_axes
490
512
  if batch.param_batch_dict is not None:
491
513
  # update params with the batches of generated params
492
- params_dict = _update_eq_params_dict(params_dict, batch.param_batch_dict)
514
+ params = _update_eq_params_dict(params, batch.param_batch_dict)
493
515
 
494
516
  vmap_in_axes_params = _get_vmap_in_axes_params(
495
517
  batch.param_batch_dict, params_dict
496
518
  )
497
519
 
498
- def dyn_loss_for_one_key(dyn_loss, derivative_key, loss_weight):
520
+ def dyn_loss_for_one_key(dyn_loss, loss_weight):
499
521
  """This function is used in tree_map"""
500
- params_dict_ = _set_derivatives(params_dict, "dyn_loss", derivative_key)
501
522
  return dynamic_loss_apply(
502
523
  dyn_loss.evaluate,
503
524
  self.u_dict,
504
525
  (temporal_batch,),
505
- params_dict_,
526
+ _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
506
527
  vmap_in_axes_t + vmap_in_axes_params,
507
528
  loss_weight,
508
529
  u_type=PINN,
@@ -511,8 +532,12 @@ class SystemLossODE:
511
532
  dyn_loss_mse_dict = jax.tree_util.tree_map(
512
533
  dyn_loss_for_one_key,
513
534
  self.dynamic_loss_dict,
514
- self.derivative_keys_dyn_loss_dict,
515
535
  self._loss_weights["dyn_loss"],
536
+ is_leaf=lambda x: isinstance(x, ODE), # before when dynamic losses
537
+ # were plain (unregistered pytree) node classes, we could not traverse
538
+ # this level. Now that dynamic losses are eqx.Module they can be
539
+ # traversed by tree map recursion. Hence we need to specify that
540
+ # we want to stop at this level
516
541
  )
517
542
  mse_dyn_loss = jax.tree_util.tree_reduce(
518
543
  lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
@@ -527,7 +552,8 @@ class SystemLossODE:
527
552
 
528
553
  # we need to do the following for the tree_mapping to work
529
554
  if batch.obs_batch_dict is None:
530
- batch = batch._replace(obs_batch_dict=self.u_dict_with_none)
555
+ batch = append_obs_batch(batch, self.u_dict_with_none)
556
+
531
557
  total_loss, res_dict = constraints_system_loss_apply(
532
558
  self.u_constraints_dict,
533
559
  batch,
@@ -540,27 +566,3 @@ class SystemLossODE:
540
566
  total_loss += mse_dyn_loss
541
567
  res_dict["dyn_loss"] += mse_dyn_loss
542
568
  return total_loss, res_dict
543
-
544
- def tree_flatten(self):
545
- children = (
546
- self.initial_condition_dict,
547
- self._loss_weights,
548
- )
549
- aux_data = {
550
- "u_dict": self.u_dict,
551
- "dynamic_loss_dict": self.dynamic_loss_dict,
552
- "derivative_keys_dict": self.derivative_keys_dict,
553
- "obs_slice_dict": self.obs_slice_dict,
554
- }
555
- return (children, aux_data)
556
-
557
- @classmethod
558
- def tree_unflatten(cls, aux_data, children):
559
- (initial_condition_dict, loss_weights) = children
560
- loss_ode = cls(
561
- loss_weights=loss_weights,
562
- initial_condition_dict=initial_condition_dict,
563
- **aux_data,
564
- )
565
-
566
- return loss_ode
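
The 1.1.0 classes above replace the dict-based `loss_weights` and `derivative_keys` arguments of 0.9.0 with equinox-module containers (`LossWeightsODE`, `DerivativeKeysODE`) and a `Params` object. The sketch below illustrates how the new `LossODE` constructor documented in this diff might be used. It is only a sketch, not the package's documented quick-start: it assumes `LossODE` and `LossWeightsODE` are exported from `jinns.loss` and `Params` from `jinns.parameters` (the changed `__init__.py` files suggest so, but the exact exports are not shown in this diff), that `Params` accepts `nn_params` and `eq_params` keyword arguments, and it uses a hypothetical `DecayResidual` class that only follows the documented `dynamic_loss.evaluate(t, u, params)` contract rather than subclassing the package's `ODE` base class. A real jinns PINN built from `jinns/utils/_pinn.py` would consume `params.nn_params`; the toy module here ignores it for brevity.

import jax
import jax.numpy as jnp
import equinox as eqx

from jinns.loss import LossODE, LossWeightsODE   # assumed exports
from jinns.parameters import Params               # assumed export


class DecayResidual:
    # Hypothetical stand-in for a DynamicLoss: residual of du/dt + k * u = 0,
    # implementing only the documented evaluate(t, u, params) contract.
    def evaluate(self, t, u, params):
        du_dt = jax.grad(lambda t_: u(t_, params).squeeze())(t)
        return du_dt + params.eq_params["k"] * u(t, params)


class ToyPINN(eqx.Module):
    # Toy network following the u(t, params) calling convention seen in
    # LossODE.evaluate; a real jinns PINN would rebuild itself from
    # params.nn_params instead of ignoring it.
    mlp: eqx.nn.MLP

    def __call__(self, t, params):
        return self.mlp(jnp.atleast_1d(t))


u = ToyPINN(mlp=eqx.nn.MLP(1, 1, 16, 2, key=jax.random.PRNGKey(0)))
params = Params(  # assumed keyword arguments
    nn_params=eqx.filter(u, eqx.is_array),
    eq_params={"k": jnp.array(0.5)},
)

loss = LossODE(
    u=u,
    dynamic_loss=DecayResidual(),
    loss_weights=LossWeightsODE(dyn_loss=1.0, initial_condition=10.0),
    initial_condition=(0.0, 1.0),
    params=params,  # used to build the default DerivativeKeysODE
)
# total, per_term = loss(params, batch)  # batch would be an ODEBatch from jinns/data/_Batchs.py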