jinns-1.3.0-py3-none-any.whl → jinns-1.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +74 -69
  14. jinns/loss/_LossODE.py +132 -348
  15. jinns/loss/_LossPDE.py +262 -549
  16. jinns/loss/__init__.py +32 -6
  17. jinns/loss/_abstract_loss.py +128 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_components.py +43 -0
  20. jinns/loss/_loss_utils.py +85 -179
  21. jinns/loss/_loss_weight_updates.py +202 -0
  22. jinns/loss/_loss_weights.py +64 -40
  23. jinns/loss/_operators.py +84 -74
  24. jinns/nn/__init__.py +15 -0
  25. jinns/nn/_abstract_pinn.py +22 -0
  26. jinns/nn/_hyperpinn.py +94 -57
  27. jinns/nn/_mlp.py +50 -25
  28. jinns/nn/_pinn.py +33 -19
  29. jinns/nn/_ppinn.py +70 -34
  30. jinns/nn/_save_load.py +21 -51
  31. jinns/nn/_spinn.py +33 -16
  32. jinns/nn/_spinn_mlp.py +28 -22
  33. jinns/nn/_utils.py +38 -0
  34. jinns/parameters/__init__.py +8 -1
  35. jinns/parameters/_derivative_keys.py +116 -177
  36. jinns/parameters/_params.py +18 -46
  37. jinns/plot/__init__.py +2 -0
  38. jinns/plot/_plot.py +35 -34
  39. jinns/solver/_rar.py +80 -63
  40. jinns/solver/_solve.py +207 -92
  41. jinns/solver/_utils.py +4 -6
  42. jinns/utils/__init__.py +2 -0
  43. jinns/utils/_containers.py +16 -10
  44. jinns/utils/_types.py +20 -54
  45. jinns/utils/_utils.py +4 -11
  46. jinns/validation/__init__.py +2 -0
  47. jinns/validation/_validation.py +20 -19
  48. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
  49. jinns-1.5.0.dist-info/RECORD +55 -0
  50. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
  51. jinns/data/_DataGenerators.py +0 -1634
  52. jinns-1.3.0.dist-info/RECORD +0 -44
  53. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
  54. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
  55. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossODE.py CHANGED
@@ -1,24 +1,23 @@
-# pylint: disable=unsubscriptable-object, no-member
 """
 Main module to implement a ODE loss in jinns
 """
+
 from __future__ import (
     annotations,
 ) # https://docs.python.org/3/library/typing.html#constant
 
-from dataclasses import InitVar, fields
-from typing import TYPE_CHECKING, Dict
+from dataclasses import InitVar
+from typing import TYPE_CHECKING, TypedDict, Callable
+from types import EllipsisType
 import abc
 import warnings
 import jax
 import jax.numpy as jnp
 from jax import vmap
 import equinox as eqx
-from jaxtyping import Float, Array, Int
-from jinns.data._DataGenerators import append_obs_batch
+from jaxtyping import Float, Array
 from jinns.loss._loss_utils import (
     dynamic_loss_apply,
-    constraints_system_loss_apply,
     observations_loss_apply,
 )
 from jinns.parameters._params import (
@@ -26,37 +25,49 @@ from jinns.parameters._params import (
     _update_eq_params_dict,
 )
 from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
-from jinns.loss._loss_weights import LossWeightsODE, LossWeightsODEDict
-from jinns.loss._DynamicLossAbstract import ODE
-from jinns.nn._pinn import PINN
+from jinns.loss._loss_weights import LossWeightsODE
+from jinns.loss._abstract_loss import AbstractLoss
+from jinns.loss._loss_components import ODEComponents
+from jinns.parameters._params import Params
 
 if TYPE_CHECKING:
-    from jinns.utils._types import *
+    # imports only used in type hints
+    from jinns.data._Batchs import ODEBatch
+    from jinns.nn._abstract_pinn import AbstractPINN
+    from jinns.loss import ODE
+
+    class LossDictODE(TypedDict):
+        dyn_loss: Float[Array, " "]
+        initial_condition: Float[Array, " "]
+        observations: Float[Array, " "]
 
 
-class _LossODEAbstract(eqx.Module):
+class _LossODEAbstract(AbstractLoss):
     """
     Parameters
     ----------
 
    loss_weights : LossWeightsODE, default=None
        The loss weights for the differents term : dynamic loss,
-        initial condition and eventually observations if any. All fields are
-        set to 1.0 by default.
+        initial condition and eventually observations if any.
+        Can be updated according to a specific algorithm. See
+        `update_weight_method`
+    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
+        Default is None meaning no update for loss weights. Otherwise a string
    derivative_keys : DerivativeKeysODE, default=None
        Specify which field of `params` should be differentiated for each
        composant of the total loss. Particularily useful for inverse problems.
        Fields can be "nn_params", "eq_params" or "both". Those that should not
        be updated will have a `jax.lax.stop_gradient` called on them. Default
        is `"nn_params"` for each composant of the loss.
-    initial_condition : tuple, default=None
+    initial_condition : tuple[float | Float[Array, " 1"], Float[Array, " dim"]], default=None
        tuple of length 2 with initial condition $(t_0, u_0)$.
-    obs_slice : Slice, default=None
+    obs_slice : EllipsisType | slice | None, default=None
        Slice object specifying the begininning/ending
        slice of u output(s) that is observed. This is useful for
        multidimensional PINN, with partially observed outputs.
        Default is None (whole output is observed).
-    params : InitVar[Params], default=None
+    params : InitVar[Params[Array]], default=None
        The main Params object of the problem needed to instanciate the
        DerivativeKeysODE if the latter is not specified.
    """
@@ -66,24 +77,27 @@ class _LossODEAbstract(eqx.Module):
     # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
     derivative_keys: DerivativeKeysODE | None = eqx.field(kw_only=True, default=None)
     loss_weights: LossWeightsODE | None = eqx.field(kw_only=True, default=None)
-    initial_condition: tuple | None = eqx.field(kw_only=True, default=None)
-    obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
+    initial_condition: (
+        tuple[float | Float[Array, " 1"], Float[Array, " dim"]] | None
+    ) = eqx.field(kw_only=True, default=None)
+    obs_slice: EllipsisType | slice | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )
 
-    params: InitVar[Params] = eqx.field(default=None, kw_only=True)
+    params: InitVar[Params[Array]] = eqx.field(default=None, kw_only=True)
 
-    def __post_init__(self, params=None):
+    def __post_init__(self, params: Params[Array] | None = None):
         if self.loss_weights is None:
             self.loss_weights = LossWeightsODE()
 
         if self.derivative_keys is None:
-            try:
-                # be default we only take gradient wrt nn_params
-                self.derivative_keys = DerivativeKeysODE(params=params)
-            except ValueError as exc:
+            # by default we only take gradient wrt nn_params
+            if params is None:
                 raise ValueError(
                     "Problem at self.derivative_keys initialization "
                     f"received {self.derivative_keys=} and {params=}"
-                ) from exc
+                )
+            self.derivative_keys = DerivativeKeysODE(params=params)
         if self.initial_condition is None:
             warnings.warn(
                 "Initial condition wasn't provided. Be sure to cover for that"
@@ -98,6 +112,21 @@
                     "Initial condition should be a tuple of len 2 with (t0, u0), "
                     f"{self.initial_condition} was passed."
                 )
+            # some checks/reshaping for t0
+            t0, u0 = self.initial_condition
+            if isinstance(t0, Array):
+                if not t0.shape: # e.g. user input: jnp.array(0.)
+                    t0 = jnp.array([t0])
+                elif t0.shape != (1,):
+                    raise ValueError(
+                        f"Wrong t0 input (self.initial_condition[0]) It should be"
+                        f"a float or an array of shape (1,). Got shape: {t0.shape}"
+                    )
+            if isinstance(t0, float): # e.g. user input: 0.
+                t0 = jnp.array([t0])
+            if isinstance(t0, int): # e.g. user input: 0
+                t0 = jnp.array([float(t0)])
+            self.initial_condition = (t0, u0)
 
         if self.obs_slice is None:
             self.obs_slice = jnp.s_[...]
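The block added above normalizes `t0` so that downstream code can always rely on a `(1,)`-shaped array. A standalone sketch of the same normalization, written as a hypothetical `_normalize_t0` helper that is not part of jinns:

import jax.numpy as jnp
from jax import Array

def _normalize_t0(t0: float | int | Array) -> Array:
    # Promote Python scalars (0, 0.) and 0-d arrays (jnp.array(0.)) to shape (1,);
    # reject arrays of any other shape.
    if isinstance(t0, Array):
        if not t0.shape:
            return jnp.array([t0])
        if t0.shape != (1,):
            raise ValueError(f"t0 must be a float or an array of shape (1,), got {t0.shape}")
        return t0
    return jnp.array([float(t0)])

assert _normalize_t0(0).shape == (1,)
assert _normalize_t0(jnp.array(0.0)).shape == (1,)
assert _normalize_t0(jnp.array([0.0])).shape == (1,)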
@@ -105,10 +134,14 @@
         if self.loss_weights is None:
             self.loss_weights = LossWeightsODE()
 
+    @abc.abstractmethod
+    def __call__(self, *_, **__):
+        pass
+
     @abc.abstractmethod
     def evaluate(
-        self: eqx.Module, params: Params, batch: ODEBatch
-    ) -> tuple[Float, dict]:
+        self: eqx.Module, params: Params[Array], batch: ODEBatch
+    ) -> tuple[Float[Array, " "], LossDictODE]:
         raise NotImplementedError
 
 
@@ -127,27 +160,30 @@ class LossODE(_LossODEAbstract):
     ----------
    loss_weights : LossWeightsODE, default=None
        The loss weights for the differents term : dynamic loss,
-        initial condition and eventually observations if any. All fields are
-        set to 1.0 by default.
+        initial condition and eventually observations if any.
+        Can be updated according to a specific algorithm. See
+        `update_weight_method`
+    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
+        Default is None meaning no update for loss weights. Otherwise a string
    derivative_keys : DerivativeKeysODE, default=None
        Specify which field of `params` should be differentiated for each
        composant of the total loss. Particularily useful for inverse problems.
        Fields can be "nn_params", "eq_params" or "both". Those that should not
        be updated will have a `jax.lax.stop_gradient` called on them. Default
        is `"nn_params"` for each composant of the loss.
-    initial_condition : tuple, default=None
+    initial_condition : tuple[float | Float[Array, " 1"]], default=None
        tuple of length 2 with initial condition $(t_0, u_0)$.
-    obs_slice Slice, default=None
+    obs_slice : EllipsisType | slice | None, default=None
        Slice object specifying the begininning/ending
        slice of u output(s) that is observed. This is useful for
        multidimensional PINN, with partially observed outputs.
        Default is None (whole output is observed).
-    params : InitVar[Params], default=None
+    params : InitVar[Params[Array]], default=None
        The main Params object of the problem needed to instanciate the
        DerivativeKeysODE if the latter is not specified.
    u : eqx.Module
        the PINN
-    dynamic_loss : DynamicLoss
+    dynamic_loss : ODE
        the ODE dynamic part of the loss, basically the differential
        operator $\mathcal{N}[u](t)$. Should implement a method
        `dynamic_loss.evaluate(t, u, params)`.
@@ -161,12 +197,12 @@
 
     # NOTE static=True only for leaf attributes that are not valid JAX types
     # (ie. jax.Array cannot be static) and that we do not expect to change
-    u: eqx.Module
-    dynamic_loss: DynamicLoss | None
+    u: AbstractPINN
+    dynamic_loss: ODE | None
 
-    vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
+    vmap_in_axes: tuple[int] = eqx.field(init=False, static=True)
 
-    def __post_init__(self, params=None):
+    def __post_init__(self, params: Params[Array] | None = None):
         super().__post_init__(
             params=params
         ) # because __init__ or __post_init__ of Base
@@ -177,21 +213,26 @@
     def __call__(self, *args, **kwargs):
         return self.evaluate(*args, **kwargs)
 
-    def evaluate(
-        self, params: Params, batch: ODEBatch
-    ) -> tuple[Float[Array, "1"], dict[str, float]]:
+    def evaluate_by_terms(
+        self, params: Params[Array], batch: ODEBatch
+    ) -> tuple[
+        ODEComponents[Float[Array, " "] | None], ODEComponents[Float[Array, " "] | None]
+    ]:
         """
         Evaluate the loss function at a batch of points for given parameters.
 
+        We retrieve two PyTrees with loss values and gradients for each term
 
         Parameters
         ---------
         params
             Parameters at which the loss is evaluated
         batch
-            Composed of a batch of time points
-            at which to evaluate the differential operator. An optional additional batch of parameters (eg. for metamodeling) and an optional additional batch of observed inputs/outputs/parameters can
-            be supplied.
+            Composed of a batch of points in the
+            domain, a batch of points in the domain
+            border and an optional additional batch of parameters (eg. for
+            metamodeling) and an optional additional batch of observed
+            inputs/outputs/parameters
         """
         temporal_batch = batch.temporal_batch
 
@@ -206,16 +247,15 @@
 
         ## dynamic part
         if self.dynamic_loss is not None:
-            mse_dyn_loss = dynamic_loss_apply(
-                self.dynamic_loss.evaluate,
+            dyn_loss_fun = lambda p: dynamic_loss_apply(
+                self.dynamic_loss.evaluate, # type: ignore
                 self.u,
                 temporal_batch,
-                _set_derivatives(params, self.derivative_keys.dyn_loss),
+                _set_derivatives(p, self.derivative_keys.dyn_loss), # type: ignore
                 self.vmap_in_axes + vmap_in_axes_params,
-                self.loss_weights.dyn_loss,
             )
         else:
-            mse_dyn_loss = jnp.array(0.0)
+            dyn_loss_fun = None
 
         # initial condition
         if self.initial_condition is not None:
@@ -226,18 +266,14 @@
                 v_u = self.u
             else:
                 v_u = vmap(self.u, (None,) + vmap_in_axes_params)
-            t0, u0 = self.initial_condition # pylint: disable=unpacking-non-sequence
-            t0 = jnp.array([t0])
+            t0, u0 = self.initial_condition
             u0 = jnp.array(u0)
-            mse_initial_condition = jnp.mean(
-                self.loss_weights.initial_condition
-                * jnp.sum(
+            initial_condition_fun = lambda p: jnp.mean(
+                jnp.sum(
                     (
                         v_u(
                             t0,
-                            _set_derivatives(
-                                params, self.derivative_keys.initial_condition
-                            ),
+                            _set_derivatives(p, self.derivative_keys.initial_condition), # type: ignore
                         )
                         - u0
                     )
@@ -246,323 +282,71 @@
                 )
             )
         else:
-            mse_initial_condition = jnp.array(0.0)
+            initial_condition_fun = None
 
         if batch.obs_batch_dict is not None:
             # update params with the batches of observed params
-            params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])
+            params_obs = _update_eq_params_dict(
+                params, batch.obs_batch_dict["eq_params"]
+            )
 
             # MSE loss wrt to an observed batch
-            mse_observation_loss = observations_loss_apply(
+            obs_loss_fun = lambda po: observations_loss_apply(
                 self.u,
-                (batch.obs_batch_dict["pinn_in"],),
-                _set_derivatives(params, self.derivative_keys.observations),
+                batch.obs_batch_dict["pinn_in"],
+                _set_derivatives(po, self.derivative_keys.observations), # type: ignore
                 self.vmap_in_axes + vmap_in_axes_params,
                 batch.obs_batch_dict["val"],
-                self.loss_weights.observations,
                 self.obs_slice,
             )
         else:
-            mse_observation_loss = jnp.array(0.0)
-
-        # total loss
-        total_loss = mse_dyn_loss + mse_initial_condition + mse_observation_loss
-        return total_loss, (
-            {
-                "dyn_loss": mse_dyn_loss,
-                "initial_condition": mse_initial_condition,
-                "observations": mse_observation_loss,
-            }
-        )
-
-
-class SystemLossODE(eqx.Module):
-    r"""
-    Class to implement a system of ODEs.
-    The goal is to give maximum freedom to the user. The class is created with
-    a dict of dynamic loss and a dict of initial conditions. Then, it iterates
-    over the dynamic losses that compose the system. All PINNs are passed as
-    arguments to each dynamic loss evaluate functions, along with all the
-    parameter dictionaries. All specification is left to the responsability
-    of the user, inside the dynamic loss.
-
-    **Note:** All the dictionaries (except `dynamic_loss_dict`) must have the same keys.
-    Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
-    solution.
+            params_obs = None
+            obs_loss_fun = None
 
-    Parameters
-    ----------
-    u_dict : Dict[str, eqx.Module]
-        dict of PINNs
-    loss_weights : LossWeightsODEDict
-        A dictionary of LossWeightsODE
-    derivative_keys_dict : Dict[str, DerivativeKeysODE], default=None
-        A dictionnary of DerivativeKeysODE specifying what field of `params`
-        should be used during gradient computations for each of the terms of
-        the total loss, for each of the loss in the system. Default is
-        `"nn_params`" everywhere.
-    initial_condition_dict : Dict[str, tuple], default=None
-        dict of tuple of length 2 with initial condition $(t_0, u_0)$
-        Must share the keys of `u_dict`. Default is None. No initial
-        condition is permitted when the initial condition is hardcoded in
-        the PINN architecture for example
-    dynamic_loss_dict : Dict[str, ODE]
-        dict of dynamic part of the loss, basically the differential
-        operator $\mathcal{N}[u](t)$. Should implement a method
-        `dynamic_loss.evaluate(t, u, params)`
-    obs_slice_dict : Dict[str, Slice]
-        dict of obs_slice, with keys from `u_dict` to designate the
-        output(s) channels that are observed, for each
-        PINNs. Default is None. But if a value is given, all the entries of
-        `u_dict` must be represented here with default value `jnp.s_[...]`
-        if no particular slice is to be given.
-    params_dict : InitVar[ParamsDict], default=None
-        The main Params object of the problem needed to instanciate the
-        DerivativeKeysODE if the latter is not specified.
-
-    Raises
-    ------
-    ValueError
-        if initial condition is not a dict of tuple.
-    ValueError
-        if the dictionaries that should share the keys of u_dict do not.
-    """
-
-    # NOTE static=True only for leaf attributes that are not valid JAX types
-    # (ie. jax.Array cannot be static) and that we do not expect to change
-    u_dict: Dict[str, eqx.Module]
-    dynamic_loss_dict: Dict[str, ODE]
-    derivative_keys_dict: Dict[str, DerivativeKeysODE | None] | None = eqx.field(
-        kw_only=True, default=None
-    )
-    initial_condition_dict: Dict[str, tuple] | None = eqx.field(
-        kw_only=True, default=None
-    )
-
-    obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
-        kw_only=True, default=None, static=True
-    ) # We are at an "leaf" attribute here (slice, not valid JAX type). Since
-    # we do not expect it to change with put a static=True here. But note that
-    # this is the only static for all the SystemLossODE attribute, since all
-    # other are composed of more complex structures ("non-leaf")
-
-    # For the user loss_weights are passed as a LossWeightsODEDict (with internal
-    # dictionary having keys in u_dict and / or dynamic_loss_dict)
-    loss_weights: InitVar[LossWeightsODEDict | None] = eqx.field(
-        kw_only=True, default=None
-    )
-    params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
-
-    u_constraints_dict: Dict[str, LossODE] = eqx.field(init=False)
-    derivative_keys_dyn_loss: DerivativeKeysODE = eqx.field(init=False)
-
-    u_dict_with_none: Dict[str, None] = eqx.field(init=False)
-    # internally the loss weights are handled with a dictionary
-    _loss_weights: Dict[str, dict] = eqx.field(init=False)
-
-    def __post_init__(self, loss_weights=None, params_dict=None):
-        # a dictionary that will be useful at different places
-        self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
-        if self.initial_condition_dict is None:
-            self.initial_condition_dict = self.u_dict_with_none
-        else:
-            if self.u_dict.keys() != self.initial_condition_dict.keys():
-                raise ValueError(
-                    "initial_condition_dict should have same keys as u_dict"
-                )
-        if self.obs_slice_dict is None:
-            self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
-        else:
-            if self.u_dict.keys() != self.obs_slice_dict.keys():
-                raise ValueError("obs_slice_dict should have same keys as u_dict")
-
-        if self.derivative_keys_dict is None:
-            self.derivative_keys_dict = {
-                k: None
-                for k in set(
-                    list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
-                )
-            }
-            # set() because we can have duplicate entries and in this case we
-            # say it corresponds to the same derivative_keys_dict entry
-            # we need both because the constraints (all but dyn_loss) will be
-            # done by iterating on u_dict while the dyn_loss will be by
-            # iterating on dynamic_loss_dict. So each time we will require dome
-            # derivative_keys_dict
-
-        # derivative keys for the u_constraints. Note that we create missing
-        # DerivativeKeysODE around a Params object and not ParamsDict
-        # this works because u_dict.keys == params_dict.nn_params.keys()
-        for k in self.u_dict.keys():
-            if self.derivative_keys_dict[k] is None:
-                self.derivative_keys_dict[k] = DerivativeKeysODE(
-                    params=params_dict.extract_params(k)
-                )
-
-        self._loss_weights = self.set_loss_weights(loss_weights)
-
-        # The constaints on the solutions will be implemented by reusing a
-        # LossODE class without dynamic loss term
-        self.u_constraints_dict = {}
-        for i in self.u_dict.keys():
-            self.u_constraints_dict[i] = LossODE(
-                u=self.u_dict[i],
-                loss_weights=LossWeightsODE(
-                    dyn_loss=0.0,
-                    initial_condition=1.0,
-                    observations=1.0,
-                ),
-                dynamic_loss=None,
-                derivative_keys=self.derivative_keys_dict[i],
-                initial_condition=self.initial_condition_dict[i],
-                obs_slice=self.obs_slice_dict[i],
-            )
-
-        # derivative keys for the dynamic loss. Note that we create a
-        # DerivativeKeysODE around a ParamsDict object because a whole
-        # params_dict is feed to DynamicLoss.evaluate functions (extract_params
-        # happen inside it)
-        self.derivative_keys_dyn_loss = DerivativeKeysODE(params=params_dict)
+        # get the unweighted mses for each loss term as well as the gradients
+        all_funs: ODEComponents[Callable[[Params[Array]], Array] | None] = (
+            ODEComponents(dyn_loss_fun, initial_condition_fun, obs_loss_fun)
+        )
+        all_params: ODEComponents[Params[Array] | None] = ODEComponents(
+            params, params, params_obs
+        )
+        mses_grads = jax.tree.map(
+            lambda fun, params: self.get_gradients(fun, params),
+            all_funs,
+            all_params,
+            is_leaf=lambda x: x is None,
+        )
 
-    def set_loss_weights(self, loss_weights_init):
-        """
-        This rather complex function enables the user to specify a simple
-        loss_weights=LossWeightsODEDict(dyn_loss=1., initial_condition=Tmax)
-        for ponderating values being applied to all the equations of the
-        system... So all the transformations are handled here
-        """
-        _loss_weights = {}
-        for k in fields(loss_weights_init):
-            v = getattr(loss_weights_init, k.name)
-            if isinstance(v, dict):
-                for vv in v.values():
-                    if not isinstance(vv, (int, float)) and not (
-                        isinstance(vv, Array)
-                        and ((vv.shape == (1,) or len(vv.shape) == 0))
-                    ):
-                        # TODO improve that
-                        raise ValueError(
-                            f"loss values cannot be vectorial here, got {vv}"
-                        )
-                if k.name == "dyn_loss":
-                    if v.keys() == self.dynamic_loss_dict.keys():
-                        _loss_weights[k.name] = v
-                    else:
-                        raise ValueError(
-                            "Keys in nested dictionary of loss_weights"
-                            " do not match dynamic_loss_dict keys"
-                        )
-                else:
-                    if v.keys() == self.u_dict.keys():
-                        _loss_weights[k.name] = v
-                    else:
-                        raise ValueError(
-                            "Keys in nested dictionary of loss_weights"
-                            " do not match u_dict keys"
-                        )
-            elif v is None:
-                _loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
-            else:
-                if not isinstance(v, (int, float)) and not (
-                    isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
-                ):
-                    # TODO improve that
-                    raise ValueError(f"loss values cannot be vectorial here, got {v}")
-                if k.name == "dyn_loss":
-                    _loss_weights[k.name] = {
-                        kk: v for kk in self.dynamic_loss_dict.keys()
-                    }
-                else:
-                    _loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
-
-        return _loss_weights
+        mses = jax.tree.map(
+            lambda leaf: leaf[0], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
+        )
+        grads = jax.tree.map(
+            lambda leaf: leaf[1], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
+        )
 
-    def __call__(self, *args, **kwargs):
-        return self.evaluate(*args, **kwargs)
+        return mses, grads
 
-    def evaluate(self, params_dict: ParamsDict, batch: ODEBatch) -> Float[Array, "1"]:
+    def evaluate(
+        self, params: Params[Array], batch: ODEBatch
+    ) -> tuple[Float[Array, " "], ODEComponents[Float[Array, " "] | None]]:
         """
         Evaluate the loss function at a batch of points for given parameters.
 
+        We retrieve the total value itself and a PyTree with loss values for each term
 
         Parameters
         ---------
         params
-            A ParamsDict object
+            Parameters at which the loss is evaluated
         batch
-            A ODEBatch object.
-            Such a named tuple is composed of a batch of time points
-            at which to evaluate an optional additional batch of parameters (eg. for
+            Composed of a batch of points in the
+            domain, a batch of points in the domain
+            border and an optional additional batch of parameters (eg. for
             metamodeling) and an optional additional batch of observed
            inputs/outputs/parameters
        """
-        if (
-            isinstance(params_dict.nn_params, dict)
-            and self.u_dict.keys() != params_dict.nn_params.keys()
-        ):
-            raise ValueError("u_dict and params_dict.nn_params should have same keys ")
-
-        temporal_batch = batch.temporal_batch
-
-        vmap_in_axes_t = (0,)
-
-        # Retrieve the optional eq_params_batch
-        # and update eq_params with the latter
-        # and update vmap_in_axes
-        if batch.param_batch_dict is not None:
-            # update params with the batches of generated params
-            params = _update_eq_params_dict(params, batch.param_batch_dict)
-
-        vmap_in_axes_params = _get_vmap_in_axes_params(
-            batch.param_batch_dict, params_dict
-        )
-
-        def dyn_loss_for_one_key(dyn_loss, loss_weight):
-            """This function is used in tree_map"""
-            return dynamic_loss_apply(
-                dyn_loss.evaluate,
-                self.u_dict,
-                temporal_batch,
-                _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
-                vmap_in_axes_t + vmap_in_axes_params,
-                loss_weight,
-                u_type=PINN,
-            )
+        loss_terms, _ = self.evaluate_by_terms(params, batch)
 
-        dyn_loss_mse_dict = jax.tree_util.tree_map(
-            dyn_loss_for_one_key,
-            self.dynamic_loss_dict,
-            self._loss_weights["dyn_loss"],
-            is_leaf=lambda x: isinstance(x, ODE), # before when dynamic losses
-            # where plain (unregister pytree) node classes, we could not traverse
-            # this level. Now that dynamic losses are eqx.Module they can be
-            # traversed by tree map recursion. Hence we need to specify to that
-            # we want to stop at this level
-        )
-        mse_dyn_loss = jax.tree_util.tree_reduce(
-            lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
-        )
-
-        # initial conditions and observation_loss via the internal LossODE
-        loss_weight_struct = {
-            "dyn_loss": "*",
-            "observations": "*",
-            "initial_condition": "*",
-        }
-
-        # we need to do the following for the tree_mapping to work
-        if batch.obs_batch_dict is None:
-            batch = append_obs_batch(batch, self.u_dict_with_none)
-
-        total_loss, res_dict = constraints_system_loss_apply(
-            self.u_constraints_dict,
-            batch,
-            params_dict,
-            self._loss_weights,
-            loss_weight_struct,
-        )
+        loss_val = self.ponderate_and_sum_loss(loss_terms)
 
-        # Add the mse_dyn_loss from the previous computations
-        total_loss += mse_dyn_loss
-        res_dict["dyn_loss"] += mse_dyn_loss
-        return total_loss, res_dict
+        return loss_val, loss_terms
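Overall, the 1.3.0 `evaluate` (which returned the weighted total and a plain dict) and the whole `SystemLossODE` class are removed: per-term values and gradients now come from `evaluate_by_terms`, while `evaluate` weights and sums them via `ponderate_and_sum_loss`, presumably defined on `AbstractLoss` (not shown in this diff). A hedged usage sketch, assuming `LossODE` is exposed as `jinns.loss.LossODE`; `u`, `my_dynamic_loss`, `params`, `t0`, `u0` and `batch` are placeholders for objects built with the rest of the jinns 1.5.0 API, not objects defined in this diff:

import jinns

# Placeholders: a PINN `u`, an ODE dynamic loss `my_dynamic_loss`, a Params
# object `params`, an initial condition (t0, u0) and an ODEBatch `batch`.
loss = jinns.loss.LossODE(
    u=u,
    dynamic_loss=my_dynamic_loss,
    initial_condition=(t0, u0),
    params=params,
)

# Weighted total plus one unweighted value per term
# (dyn_loss, initial_condition, observations).
loss_val, loss_terms = loss.evaluate(params, batch)

# Per-term values and gradients, as consumed by the loss-weight
# update schemes (soft_adapt, lr_annealing, ReLoBRaLo).
mses, grads = loss.evaluate_by_terms(params, batch)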