jinns 0.9.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +321 -168
  9. jinns/loss/_LossODE.py +292 -309
  10. jinns/loss/_LossPDE.py +625 -1010
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +95 -44
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +94 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/METADATA +4 -4
  33. jinns-1.0.0.dist-info/RECORD +38 -0
  34. {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
  35. jinns/experimental/_sinuspinn.py +0 -135
  36. jinns/experimental/_spectralpinn.py +0 -87
  37. jinns/solver/_seq2seq.py +0 -157
  38. jinns/utils/_optim.py +0 -147
  39. jinns/utils/_utils_uspinn.py +0 -727
  40. jinns-0.9.0.dist-info/RECORD +0 -36
  41. {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
  42. {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
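The most visible user-facing change in this release is the move of the loss classes from hand-rolled pytree registration to equinox modules, shown in `jinns/loss/_LossODE.py` below: `loss_weights` and `derivative_keys` are no longer plain dicts but the new `LossWeightsODE` and `DerivativeKeysODE` containers (from `jinns/loss/_loss_weights.py` and `jinns/parameters/_derivative_keys.py`). A minimal migration sketch, assuming an existing PINN `u`, an ODE dynamic loss `my_ode` and an initial condition `(t0, u0)` (all placeholders, not taken from the diff):

    # jinns 0.9.0: weights and derivative keys were plain dicts
    loss = LossODE(
        u=u,
        dynamic_loss=my_ode,
        loss_weights={"dyn_loss": 1.0, "initial_condition": 10.0, "observations": 1.0},
        initial_condition=(t0, u0),
    )

    # jinns 1.0.0: typed equinox containers; import paths taken from the
    # new imports in _LossODE.py below
    from jinns.loss._loss_weights import LossWeightsODE
    from jinns.parameters._derivative_keys import DerivativeKeysODE

    loss = LossODE(
        u=u,
        dynamic_loss=my_ode,
        loss_weights=LossWeightsODE(dyn_loss=1.0, initial_condition=10.0),
        derivative_keys=DerivativeKeysODE(),  # defaults to "nn_params" for every term
        initial_condition=(t0, u0),
    )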
jinns/loss/_LossODE.py CHANGED
@@ -1,143 +1,169 @@
+# pylint: disable=unsubscriptable-object, no-member
 """
 Main module to implement a ODE loss in jinns
 """
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant
 
+from dataclasses import InitVar, fields
+from typing import TYPE_CHECKING, Dict
+import abc
 import warnings
 import jax
 import jax.numpy as jnp
 from jax import vmap
-from jax.tree_util import register_pytree_node_class
-from jinns.utils._utils import (
-    _get_vmap_in_axes_params,
-    _set_derivatives,
-    _update_eq_params_dict,
-)
-from jinns.loss._Losses import (
+import equinox as eqx
+from jaxtyping import Float, Array, Int
+from jinns.data._DataGenerators import append_obs_batch
+from jinns.loss._loss_utils import (
     dynamic_loss_apply,
     constraints_system_loss_apply,
     observations_loss_apply,
 )
+from jinns.parameters._params import (
+    _get_vmap_in_axes_params,
+    _update_eq_params_dict,
+)
+from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
+from jinns.loss._loss_weights import LossWeightsODE, LossWeightsODEDict
+from jinns.loss._DynamicLossAbstract import ODE
 from jinns.utils._pinn import PINN
 
-_LOSS_WEIGHT_KEYS_ODE = ["observations", "dyn_loss", "initial_condition"]
-
-
-@register_pytree_node_class
-class LossODE:
-    r"""Loss object for an ordinary differential equation
+if TYPE_CHECKING:
+    from jinns.utils._types import *
 
-    .. math::
-        \mathcal{N}[u](t) = 0, \forall t \in I
 
-    where :math:`\mathcal{N}[\cdot]` is a differential operator and the
-    initial condition is :math:`u(t_0)=u_0`.
-
-
-    **Note:** LossODE is jittable. Hence it implements the tree_flatten() and
-    tree_unflatten methods.
+class _LossODEAbstract(eqx.Module):
+    """
+    Parameters
+    ----------
+
+    loss_weights : LossWeightsODE, default=None
+        The loss weights for the different terms: dynamic loss,
+        initial condition and possibly observations if any. All fields are
+        set to 1.0 by default.
+    derivative_keys : DerivativeKeysODE, default=None
+        Specify which field of `params` should be differentiated for each
+        component of the total loss. Particularly useful for inverse problems.
+        Fields can be "nn_params", "eq_params" or "both". Those that should not
+        be updated will have `jax.lax.stop_gradient` called on them. Default
+        is `"nn_params"` for each component of the loss.
+    initial_condition : tuple, default=None
+        tuple of length 2 with initial condition $(t_0, u_0)$.
+    obs_slice : Slice, default=None
+        Slice object specifying the beginning/ending
+        slice of u output(s) that is observed. This is useful for
+        multidimensional PINNs with partially observed outputs.
+        Default is None (whole output is observed).
     """
 
-    def __init__(
-        self,
-        u,
-        loss_weights,
-        dynamic_loss,
-        derivative_keys=None,
-        initial_condition=None,
-        obs_slice=None,
-    ):
-        r"""
-        Parameters
-        ----------
-        u :
-            the PINN
-        loss_weights :
-            a dictionary with values used to ponderate each term in the loss
-            function. Valid keys are `dyn_loss`, `initial_condition` and `observations`
-            Note that we can have jnp.arrays with the same dimension of
-            `u` which then ponderates each output of `u`
-        dynamic_loss :
-            the ODE dynamic part of the loss, basically the differential
-            operator :math:`\mathcal{N}[u](t)`. Should implement a method
-            `dynamic_loss.evaluate(t, u, params)`.
-            Can be None in order to
-            access only some part of the evaluate call results.
-        derivative_keys
-            A dict of lists of strings. In the dict, the key must correspond to
-            the loss term keywords. Then each of the values must correspond to keys in the parameter
-            dictionary (*at top level only of the parameter dictionary*).
-            It enables selecting the set of parameters
-            with respect to which the gradients of the dynamic
-            loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
-            keywords, this is what is typically
-            done in solving forward problems, when we only estimate the
-            equation solution with a PINN. If some loss terms keywords are
-            missing we set their value to ["nn_params"] by default for the same
-            reason
-        initial_condition :
-            tuple of length 2 with initial condition :math:`(t0, u0)`.
-            Can be None in order to
-            access only some part of the evaluate call results.
-        obs_slice:
-            slice object specifying the begininning/ending
-            slice of u output(s) that is observed (this is then useful for
-            multidim PINN). Default is None.
-
-        Raises
-        ------
-        ValueError
-            if initial condition is not a tuple.
-        """
-        self.dynamic_loss = dynamic_loss
-        self.u = u
-        if derivative_keys is None:
-            # be default we only take gradient wrt nn_params
-            derivative_keys = {
-                k: ["nn_params"]
-                for k in [
-                    "dyn_loss",
-                    "initial_condition",
-                    "observations",
-                ]
-            }
-        if isinstance(derivative_keys, list):
-            # if the user only provided a list, this defines the gradient taken
-            # for all the loss entries
-            derivative_keys = {
-                k: derivative_keys
-                for k in [
-                    "dyn_loss",
-                    "initial_condition",
-                    "observations",
-                ]
-            }
+    # NOTE static=True only for leaf attributes that are not valid JAX types
+    # (ie. jax.Array cannot be static) and that we do not expect to change
+    # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
+    derivative_keys: DerivativeKeysODE | None = eqx.field(kw_only=True, default=None)
+    loss_weights: LossWeightsODE | None = eqx.field(kw_only=True, default=None)
+    initial_condition: tuple | None = eqx.field(kw_only=True, default=None)
+    obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
 
-        self.derivative_keys = derivative_keys
+    def __post_init__(self):
+        if self.loss_weights is None:
+            self.loss_weights = LossWeightsODE()
 
-        if initial_condition is None:
+        if self.derivative_keys is None:
+            # by default we only take gradient wrt nn_params
+            self.derivative_keys = DerivativeKeysODE()
+        if self.initial_condition is None:
             warnings.warn(
                 "Initial condition wasn't provided. Be sure to cover for that"
                 "case (e.g by. hardcoding it into the PINN output)."
             )
         else:
-            if not isinstance(initial_condition, tuple) or len(initial_condition) != 2:
+            if (
+                not isinstance(self.initial_condition, tuple)
+                or len(self.initial_condition) != 2
+            ):
                 raise ValueError(
-                    f"Initial condition should be a tuple of len 2 with (t0, u0), {initial_condition} was passed."
+                    "Initial condition should be a tuple of len 2 with (t0, u0), "
+                    f"{self.initial_condition} was passed."
                 )
-        self.initial_condition = initial_condition
-        self.loss_weights = loss_weights
-        self.obs_slice = obs_slice
+
         if self.obs_slice is None:
             self.obs_slice = jnp.s_[...]
 
-        for k in _LOSS_WEIGHT_KEYS_ODE:
-            if k not in self.loss_weights.keys():
-                self.loss_weights[k] = 0
+        if self.loss_weights is None:
+            self.loss_weights = LossWeightsODE()
+
+    @abc.abstractmethod
+    def evaluate(
+        self: eqx.Module, params: Params, batch: ODEBatch
+    ) -> tuple[Float, dict]:
+        raise NotImplementedError
+
+
+class LossODE(_LossODEAbstract):
+    r"""Loss object for an ordinary differential equation
+
+    $$
+    \mathcal{N}[u](t) = 0, \forall t \in I
+    $$
+
+    where $\mathcal{N}[\cdot]$ is a differential operator and the
+    initial condition is $u(t_0)=u_0$.
+
+
+    Parameters
+    ----------
+    loss_weights : LossWeightsODE, default=None
+        The loss weights for the different terms: dynamic loss,
+        initial condition and possibly observations if any. All fields are
+        set to 1.0 by default.
+    derivative_keys : DerivativeKeysODE, default=None
+        Specify which field of `params` should be differentiated for each
+        component of the total loss. Particularly useful for inverse problems.
+        Fields can be "nn_params", "eq_params" or "both". Those that should not
+        be updated will have `jax.lax.stop_gradient` called on them. Default
+        is `"nn_params"` for each component of the loss.
+    initial_condition : tuple, default=None
+        tuple of length 2 with initial condition $(t_0, u_0)$.
+    obs_slice : Slice, default=None
+        Slice object specifying the beginning/ending
+        slice of u output(s) that is observed. This is useful for
+        multidimensional PINNs with partially observed outputs.
+        Default is None (whole output is observed).
+    u : eqx.Module
+        the PINN
+    dynamic_loss : DynamicLoss
+        the ODE dynamic part of the loss, basically the differential
+        operator $\mathcal{N}[u](t)$. Should implement a method
+        `dynamic_loss.evaluate(t, u, params)`.
+        Can be None in order to access only some part of the evaluate call.
+
+    Raises
+    ------
+    ValueError
+        if initial condition is not a tuple.
+    """
+
+    # NOTE static=True only for leaf attributes that are not valid JAX types
+    # (ie. jax.Array cannot be static) and that we do not expect to change
+    u: eqx.Module
+    dynamic_loss: DynamicLoss | None
+
+    vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
+
+    def __post_init__(self):
+        super().__post_init__()  # because __init__ or __post_init__ of Base
+        # class is not automatically called
+
+        self.vmap_in_axes = (0,)
 
     def __call__(self, *args, **kwargs):
         return self.evaluate(*args, **kwargs)
 
-    def evaluate(self, params, batch):
+    def evaluate(
+        self, params: Params, batch: ODEBatch
+    ) -> tuple[Float[Array, "1"], dict[str, float]]:
         """
         Evaluate the loss function at a batch of points for given parameters.
 
@@ -145,21 +171,14 @@ class LossODE:
         Parameters
         ---------
         params
-            The dictionary of parameters of the model.
-            Typically, it is a dictionary of
-            dictionaries: `eq_params` and `nn_params``, respectively the
-            differential equation parameters and the neural network parameter
+            Parameters at which the loss is evaluated
         batch
-            A ODEBatch object.
-            Such a named tuple is composed of a batch of time points
-            at which to evaluate an optional additional batch of parameters (eg. for
-            metamodeling) and an optional additional batch of observed
-            inputs/outputs/parameters
+            Composed of a batch of time points
+            at which to evaluate the differential operator. An optional additional batch of parameters (eg. for metamodeling) and an optional additional batch of observed inputs/outputs/parameters can
+            be supplied.
         """
         temporal_batch = batch.temporal_batch
 
-        vmap_in_axes_t = (0,)
-
         # Retrieve the optional eq_params_batch
         # and update eq_params with the latter
         # and update vmap_in_axes
@@ -170,21 +189,19 @@ class LossODE:
         vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
 
         ## dynamic part
-        params_ = _set_derivatives(params, "dyn_loss", self.derivative_keys)
         if self.dynamic_loss is not None:
             mse_dyn_loss = dynamic_loss_apply(
                 self.dynamic_loss.evaluate,
                 self.u,
                 (temporal_batch,),
-                params_,
-                vmap_in_axes_t + vmap_in_axes_params,
-                self.loss_weights["dyn_loss"],
+                _set_derivatives(params, self.derivative_keys.dyn_loss),
+                self.vmap_in_axes + vmap_in_axes_params,
+                self.loss_weights.dyn_loss,
             )
         else:
             mse_dyn_loss = jnp.array(0.0)
 
         # initial condition
-        params_ = _set_derivatives(params, "initial_condition", self.derivative_keys)
         if self.initial_condition is not None:
             vmap_in_axes = (None,) + vmap_in_axes_params
             if not jax.tree_util.tree_leaves(vmap_in_axes):
@@ -193,12 +210,24 @@ class LossODE:
                 v_u = self.u
             else:
                 v_u = vmap(self.u, (None,) + vmap_in_axes_params)
-            t0, u0 = self.initial_condition
+            t0, u0 = self.initial_condition  # pylint: disable=unpacking-non-sequence
             t0 = jnp.array(t0)
             u0 = jnp.array(u0)
             mse_initial_condition = jnp.mean(
-                self.loss_weights["initial_condition"]
-                * jnp.sum((v_u(t0, params_) - u0) ** 2, axis=-1)
+                self.loss_weights.initial_condition
+                * jnp.sum(
+                    (
+                        v_u(
+                            t0,
+                            _set_derivatives(
+                                params, self.derivative_keys.initial_condition
+                            ),
+                        )
+                        - u0
+                    )
+                    ** 2,
+                    axis=-1,
+                )
             )
         else:
             mse_initial_condition = jnp.array(0.0)
@@ -208,14 +237,13 @@ class LossODE:
             params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])
 
             # MSE loss wrt to an observed batch
-            params_ = _set_derivatives(params, "observations", self.derivative_keys)
             mse_observation_loss = observations_loss_apply(
                 self.u,
                 (batch.obs_batch_dict["pinn_in"],),
-                params_,
-                vmap_in_axes_t + vmap_in_axes_params,
+                _set_derivatives(params, self.derivative_keys.observations),
+                self.vmap_in_axes + vmap_in_axes_params,
                 batch.obs_batch_dict["val"],
-                self.loss_weights["observations"],
+                self.loss_weights.observations,
                 self.obs_slice,
             )
         else:
@@ -231,193 +259,174 @@ class LossODE:
             }
         )
 
-    def tree_flatten(self):
-        children = (self.initial_condition, self.loss_weights)
-        aux_data = {
-            "u": self.u,
-            "dynamic_loss": self.dynamic_loss,
-            "obs_slice": self.obs_slice,
-            "derivative_keys": self.derivative_keys,
-        }
-        return (children, aux_data)
-
-    @classmethod
-    def tree_unflatten(cls, aux_data, children):
-        (initial_condition, loss_weights) = children
-        loss_ode = cls(
-            loss_weights=loss_weights,
-            initial_condition=initial_condition,
-            **aux_data,
-        )
-        return loss_ode
-
 
-@register_pytree_node_class
-class SystemLossODE:
-    """
+class SystemLossODE(eqx.Module):
+    r"""
     Class to implement a system of ODEs.
     The goal is to give maximum freedom to the user. The class is created with
-    a dict of dynamic loss and a dict of initial conditions. When then iterate
-    over the dynamic losses that compose the system. All the PINNs with all the
-    parameter dictionaries are passed as arguments to each dynamic loss
-    evaluate functions; it is inside the dynamic loss that specification are
-    performed.
+    a dict of dynamic loss and a dict of initial conditions. Then, it iterates
+    over the dynamic losses that compose the system. All PINNs are passed as
+    arguments to each dynamic loss's evaluate function, along with all the
+    parameter dictionaries. All specification is left to the responsibility
+    of the user, inside the dynamic loss.
 
     **Note:** All the dictionaries (except `dynamic_loss_dict`) must have the same keys.
     Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
     solution.
 
-    **Note:** SystemLossODE is jittable. Hence it implements the tree_flatten() and
-    tree_unflatten methods.
+    Parameters
+    ----------
+    u_dict : Dict[str, eqx.Module]
+        dict of PINNs
+    loss_weights : LossWeightsODEDict
+        A dictionary of LossWeightsODE
+    derivative_keys_dict : Dict[str, DerivativeKeysODE], default=None
+        A dictionary of DerivativeKeysODE specifying what field of `params`
+        should be used during gradient computations for each of the terms of
+        the total loss, for each of the losses in the system. Default is
+        `"nn_params"` everywhere.
+    initial_condition_dict : Dict[str, tuple], default=None
+        dict of tuple of length 2 with initial condition $(t_0, u_0)$
+        Must share the keys of `u_dict`. Default is None. No initial
+        condition is permitted when the initial condition is hardcoded in
+        the PINN architecture for example
+    dynamic_loss_dict : Dict[str, ODE]
+        dict of dynamic part of the loss, basically the differential
+        operator $\mathcal{N}[u](t)$. Should implement a method
+        `dynamic_loss.evaluate(t, u, params)`
+    obs_slice_dict : Dict[str, Slice]
+        dict of obs_slice, with keys from `u_dict` to designate the
+        output(s) channels that are observed, for each
+        PINN. Default is None. But if a value is given, all the entries of
+        `u_dict` must be represented here with default value `jnp.s_[...]`
+        if no particular slice is to be given.
+
+    Raises
+    ------
+    ValueError
+        if initial condition is not a dict of tuples.
+    ValueError
+        if the dictionaries that should share the keys of u_dict do not.
     """
 
-    def __init__(
-        self,
-        u_dict,
-        loss_weights,
-        dynamic_loss_dict,
-        derivative_keys_dict=None,
-        initial_condition_dict=None,
-        obs_slice_dict=None,
-    ):
-        r"""
-        Parameters
-        ----------
-        u_dict
-            dict of PINNs
-        loss_weights
-            A dictionary of dictionaries with values used to
-            ponderate each term in the loss
-            function. Valid keys in the first dictionary are `dyn_loss`,
-            `initial_condition` and `observations`. The keys of the nested
-            dictionaries must share the keys of `u_dict`. Note that the values
-            at the leaf level can have jnp.arrays with the same dimension of
-            `u` which then ponderates each output of `u`
-        derivative_keys_dict
-            A dict of derivative keys as defined in LossODE. The key of this
-            dict must be that of `dynamic_loss_dict` at least and specify how
-            to compute gradient for the `dyn_loss` loss term at least (see the
-            check at the beginning of the present `__init__` function.
-            Other keys of this dict might be that of `u_dict` to specify how to
-            compute gradients for all the different constraints. If those keys
-            are not specified then the default behaviour for `derivative_keys`
-            of LossODE is used
-        initial_condition_dict
-            dict of tuple of length 2 with initial condition :math:`(t_0, u_0)`
-            Must share the keys of `u_dict`. Default is None. No initial
-            condition is permitted when the initial condition is hardcoded in
-            the PINN architecture for example
-        dynamic_loss_dict
-            dict of dynamic part of the loss, basically the differential
-            operator :math:`\mathcal{N}[u](t)`. Should implement a method
-            `dynamic_loss.evaluate(t, u, params)`
-        obs_slice_dict
-            dict of obs_slice, with keys from `u_dict` to designate the
-            output(s) channels that are forced to observed values, for each
-            PINNs. Default is None. But if a value is given, all the entries of
-            `u_dict` must be represented here with default value `jnp.s_[...]`
-            if no particular slice is to be given
-
-        Raises
-        ------
-        ValueError
-            if initial condition is not a dict of tuple
-        ValueError
-            if the dictionaries that should share the keys of u_dict do not
-        """
-
+    # NOTE static=True only for leaf attributes that are not valid JAX types
+    # (ie. jax.Array cannot be static) and that we do not expect to change
+    u_dict: Dict[str, eqx.Module]
+    dynamic_loss_dict: Dict[str, ODE]
+    derivative_keys_dict: Dict[str, DerivativeKeysODE | None] | None = eqx.field(
+        kw_only=True, default=None
+    )
+    initial_condition_dict: Dict[str, tuple] | None = eqx.field(
+        kw_only=True, default=None
+    )
+
+    obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )  # We are at a "leaf" attribute here (slice, not a valid JAX type). Since
+    # we do not expect it to change, we put static=True here. But note that
+    # this is the only static field among the SystemLossODE attributes, since all
+    # the others are composed of more complex structures ("non-leaf")
+
+    # For the user loss_weights are passed as a LossWeightsODEDict (with internal
+    # dictionary having keys in u_dict and / or dynamic_loss_dict)
+    loss_weights: InitVar[LossWeightsODEDict | None] = eqx.field(
+        kw_only=True, default=None
+    )
+    u_constraints_dict: Dict[str, LossODE] = eqx.field(init=False)
+    derivative_keys_dyn_loss_dict: Dict[str, DerivativeKeysODE] = eqx.field(init=False)
+
+    u_dict_with_none: Dict[str, None] = eqx.field(init=False)
+    # internally the loss weights are handled with a dictionary
+    _loss_weights: Dict[str, dict] = eqx.field(init=False)
+
+    def __post_init__(self, loss_weights):
         # a dictionary that will be useful at different places
-        self.u_dict_with_none = {k: None for k in u_dict.keys()}
-        if initial_condition_dict is None:
+        self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
+        if self.initial_condition_dict is None:
             self.initial_condition_dict = self.u_dict_with_none
         else:
-            self.initial_condition_dict = initial_condition_dict
-            if u_dict.keys() != initial_condition_dict.keys():
+            if self.u_dict.keys() != self.initial_condition_dict.keys():
                 raise ValueError(
                     "initial_condition_dict should have same keys as u_dict"
                 )
-        if obs_slice_dict is None:
-            self.obs_slice_dict = {k: jnp.s_[...] for k in u_dict.keys()}
+        if self.obs_slice_dict is None:
+            self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
         else:
-            self.obs_slice_dict = obs_slice_dict
-            if u_dict.keys() != obs_slice_dict.keys():
+            if self.u_dict.keys() != self.obs_slice_dict.keys():
                 raise ValueError("obs_slice_dict should have same keys as u_dict")
 
-        if derivative_keys_dict is None:
+        if self.derivative_keys_dict is None:
             self.derivative_keys_dict = {
                 k: None
-                for k in set(list(dynamic_loss_dict.keys()) + list(u_dict.keys()))
+                for k in set(
+                    list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
+                )
             }
             # set() because we can have duplicate entries and in this case we
             # say it corresponds to the same derivative_keys_dict entry
-        else:
-            self.derivative_keys_dict = derivative_keys_dict
+            # we need both because the constraints (all but dyn_loss) will be
+            # done by iterating on u_dict while the dyn_loss will be by
+            # iterating on dynamic_loss_dict. So each time we will require some
+            # derivative_keys_dict
 
         # but then if the user did not provide anything, we must at least have
         # a default value for the dynamic_loss_dict keys entries in
         # self.derivative_keys_dict since the computation of dynamic losses is
         # made without create a lossODE object that would provide the
         # default values
-        for k in dynamic_loss_dict.keys():
+        for k in self.dynamic_loss_dict.keys():
             if self.derivative_keys_dict[k] is None:
-                self.derivative_keys_dict[k] = {"dyn_loss": ["nn_params"]}
+                self.derivative_keys_dict[k] = DerivativeKeysODE()
 
-        self.dynamic_loss_dict = dynamic_loss_dict
-        self.u_dict = u_dict
-
-        self.loss_weights = loss_weights  # We call the setter
-        # note that self.initial_condition_dict must be
-        # initialized beforehand
+        self._loss_weights = self.set_loss_weights(loss_weights)
 
         # The constaints on the solutions will be implemented by reusing a
         # LossODE class without dynamic loss term
         self.u_constraints_dict = {}
         for i in self.u_dict.keys():
             self.u_constraints_dict[i] = LossODE(
-                u=u_dict[i],
-                loss_weights={
-                    "dyn_loss": 0.0,
-                    "initial_condition": 1.0,
-                    "observations": 1.0,
-                },
+                u=self.u_dict[i],
+                loss_weights=LossWeightsODE(
+                    dyn_loss=0.0,
+                    initial_condition=1.0,
+                    observations=1.0,
                ),
                 dynamic_loss=None,
                 derivative_keys=self.derivative_keys_dict[i],
                 initial_condition=self.initial_condition_dict[i],
                 obs_slice=self.obs_slice_dict[i],
             )
 
-        # for convenience in the tree_map of evaluate,
-        # we separate the two derivative keys dict
+        # for convenience in the tree_map of evaluate
         self.derivative_keys_dyn_loss_dict = {
             k: self.derivative_keys_dict[k]
-            for k in self.dynamic_loss_dict.keys() & self.derivative_keys_dict.keys()
-        }
-        self.derivative_keys_u_dict = {
-            k: self.derivative_keys_dict[k]
-            for k in self.u_dict.keys() & self.derivative_keys_dict.keys()
+            for k in self.dynamic_loss_dict.keys()  # & self.derivative_keys_dict.keys()
+            # commented out because the intersection is necessarily fulfilled, right?
         }
 
-    @property
-    def loss_weights(self):
-        return self._loss_weights
-
-    @loss_weights.setter
-    def loss_weights(self, value):
-        self._loss_weights = {}
-        for k, v in value.items():
+    def set_loss_weights(self, loss_weights_init):
+        """
+        This rather complex function enables the user to specify a simple
+        loss_weights=LossWeightsODEDict(dyn_loss=1., initial_condition=Tmax)
+        whose weighting values are applied to all the equations of the
+        system. All the transformations are handled here.
+        """
+        _loss_weights = {}
+        for k in fields(loss_weights_init):
+            v = getattr(loss_weights_init, k.name)
             if isinstance(v, dict):
-                for kk, vv in v.items():
+                for vv in v.values():
                     if not isinstance(vv, (int, float)) and not (
-                        isinstance(vv, jnp.ndarray)
+                        isinstance(vv, Array)
                         and ((vv.shape == (1,) or len(vv.shape) == 0))
                     ):
                         # TODO improve that
                         raise ValueError(
                             f"loss values cannot be vectorial here, got {vv}"
                         )
-                if k == "dyn_loss":
+                if k.name == "dyn_loss":
                     if v.keys() == self.dynamic_loss_dict.keys():
-                        self._loss_weights[k] = v
+                        _loss_weights[k.name] = v
                     else:
                         raise ValueError(
                             "Keys in nested dictionary of loss_weights"
@@ -425,48 +434,41 @@ class SystemLossODE:
                         )
                 else:
                     if v.keys() == self.u_dict.keys():
-                        self._loss_weights[k] = v
+                        _loss_weights[k.name] = v
                     else:
                         raise ValueError(
                             "Keys in nested dictionary of loss_weights"
                             " do not match u_dict keys"
                         )
+            elif v is None:
+                _loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
             else:
                 if not isinstance(v, (int, float)) and not (
-                    isinstance(v, jnp.ndarray)
-                    and ((v.shape == (1,) or len(v.shape) == 0))
+                    isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
                 ):
                     # TODO improve that
                     raise ValueError(f"loss values cannot be vectorial here, got {v}")
-                if k == "dyn_loss":
-                    self._loss_weights[k] = {
+                if k.name == "dyn_loss":
+                    _loss_weights[k.name] = {
                         kk: v for kk in self.dynamic_loss_dict.keys()
                     }
                 else:
-                    self._loss_weights[k] = {kk: v for kk in self.u_dict.keys()}
-        if all(v is None for k, v in self.initial_condition_dict.items()):
-            self._loss_weights["initial_condition"] = {k: 0 for k in self.u_dict.keys()}
-        if "observations" not in value.keys():
-            self._loss_weights["observations"] = {k: 0 for k in self.u_dict.keys()}
+                    _loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
+
+        return _loss_weights
 
     def __call__(self, *args, **kwargs):
         return self.evaluate(*args, **kwargs)
 
-    def evaluate(self, params_dict, batch):
+    def evaluate(self, params_dict: ParamsDict, batch: ODEBatch) -> Float[Array, "1"]:
         """
         Evaluate the loss function at a batch of points for given parameters.
 
 
         Parameters
         ---------
-        params_dict
-            A dictionary of dictionaries of parameters of the model.
-            Typically, it is a dictionary of dictionaries of
-            dictionaries: `eq_params` and `nn_params``, respectively the
-            differential equation parameters and the neural network parameter.
-            Note that params_dict["nn_params"] need not be a dictionary anymore
-            but can directly be the parameters. It is useful when working with
-            neural networks sharing the same parameters
+        params_dict
+            A ParamsDict object
         batch
             A ODEBatch object.
             Such a named tuple is composed of a batch of time points
@@ -475,10 +477,10 @@ class SystemLossODE:
             inputs/outputs/parameters
         """
         if (
-            isinstance(params_dict["nn_params"], dict)
-            and self.u_dict.keys() != params_dict["nn_params"].keys()
+            isinstance(params_dict.nn_params, dict)
+            and self.u_dict.keys() != params_dict.nn_params.keys()
         ):
-            raise ValueError("u_dict and params_dict[nn_params] should have same keys ")
+            raise ValueError("u_dict and params_dict.nn_params should have same keys ")
 
         temporal_batch = batch.temporal_batch
 
@@ -489,7 +491,7 @@ class SystemLossODE:
         # and update vmap_in_axes
         if batch.param_batch_dict is not None:
             # update params with the batches of generated params
-            params_dict = _update_eq_params_dict(params_dict, batch.param_batch_dict)
+            params = _update_eq_params_dict(params, batch.param_batch_dict)
 
         vmap_in_axes_params = _get_vmap_in_axes_params(
             batch.param_batch_dict, params_dict
@@ -497,12 +499,11 @@ class SystemLossODE:
 
         def dyn_loss_for_one_key(dyn_loss, derivative_key, loss_weight):
             """This function is used in tree_map"""
-            params_dict_ = _set_derivatives(params_dict, "dyn_loss", derivative_key)
             return dynamic_loss_apply(
                 dyn_loss.evaluate,
                 self.u_dict,
                 (temporal_batch,),
-                params_dict_,
+                _set_derivatives(params_dict, derivative_key.dyn_loss),
                 vmap_in_axes_t + vmap_in_axes_params,
                 loss_weight,
                 u_type=PINN,
@@ -513,6 +514,11 @@ class SystemLossODE:
             self.dynamic_loss_dict,
             self.derivative_keys_dyn_loss_dict,
             self._loss_weights["dyn_loss"],
+            is_leaf=lambda x: isinstance(x, ODE),  # before, when dynamic losses
+            # were plain (unregistered pytree) node classes, we could not traverse
+            # this level. Now that dynamic losses are eqx.Module they can be
+            # traversed by tree map recursion. Hence we need to specify that
+            # we want to stop at this level
         )
         mse_dyn_loss = jax.tree_util.tree_reduce(
             lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
@@ -527,7 +533,8 @@ class SystemLossODE:
 
         # we need to do the following for the tree_mapping to work
         if batch.obs_batch_dict is None:
-            batch = batch._replace(obs_batch_dict=self.u_dict_with_none)
+            batch = append_obs_batch(batch, self.u_dict_with_none)
+
         total_loss, res_dict = constraints_system_loss_apply(
             self.u_constraints_dict,
             batch,
@@ -540,27 +547,3 @@ class SystemLossODE:
         total_loss += mse_dyn_loss
         res_dict["dyn_loss"] += mse_dyn_loss
         return total_loss, res_dict
-
-    def tree_flatten(self):
-        children = (
-            self.initial_condition_dict,
-            self._loss_weights,
-        )
-        aux_data = {
-            "u_dict": self.u_dict,
-            "dynamic_loss_dict": self.dynamic_loss_dict,
-            "derivative_keys_dict": self.derivative_keys_dict,
-            "obs_slice_dict": self.obs_slice_dict,
-        }
-        return (children, aux_data)
-
-    @classmethod
-    def tree_unflatten(cls, aux_data, children):
-        (initial_condition_dict, loss_weights) = children
-        loss_ode = cls(
-            loss_weights=loss_weights,
-            initial_condition_dict=initial_condition_dict,
-            **aux_data,
-        )
-
-        return loss_ode
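The `set_loss_weights` method shown above broadcasts scalar weights to every equation of the system: a scalar (or 0-d array) is replicated over `dynamic_loss_dict` keys for `dyn_loss` and over `u_dict` keys for the other terms, an explicit dict must match those keys exactly, and `None` becomes a dict of zeros. A usage sketch of the new 1.0.0 interface, with `u1`, `u2`, `eq1`, `eq2`, `t0`, `u1_0`, `u2_0` and `Tmax` as hypothetical placeholders (only the `LossWeightsODEDict(dyn_loss=1., initial_condition=Tmax)` pattern comes from the docstring above):

    from jinns.loss._loss_weights import LossWeightsODEDict

    system_loss = SystemLossODE(
        u_dict={"u1": u1, "u2": u2},
        dynamic_loss_dict={"eq1": eq1, "eq2": eq2},
        initial_condition_dict={"u1": (t0, u1_0), "u2": (t0, u2_0)},
        loss_weights=LossWeightsODEDict(
            dyn_loss=1.0,            # broadcast to {"eq1": 1.0, "eq2": 1.0}
            initial_condition=Tmax,  # broadcast to {"u1": Tmax, "u2": Tmax}
            observations=None,       # becomes {"u1": 0, "u2": 0}
        ),
    )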