jinns 1.3.0-py3-none-any.whl → 1.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +146 -520
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_utils.py +78 -159
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -74
  22. jinns/nn/__init__.py +15 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +94 -57
  25. jinns/nn/_mlp.py +50 -25
  26. jinns/nn/_pinn.py +33 -19
  27. jinns/nn/_ppinn.py +70 -34
  28. jinns/nn/_save_load.py +21 -51
  29. jinns/nn/_spinn.py +33 -16
  30. jinns/nn/_spinn_mlp.py +28 -22
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +35 -34
  37. jinns/solver/_rar.py +80 -63
  38. jinns/solver/_solve.py +89 -63
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -0
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns-1.3.0.dist-info/RECORD +0 -44
  51. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  52. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  53. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossODE.py CHANGED
@@ -1,24 +1,23 @@
-# pylint: disable=unsubscriptable-object, no-member
 """
 Main module to implement a ODE loss in jinns
 """
+
 from __future__ import (
     annotations,
 )  # https://docs.python.org/3/library/typing.html#constant
 
-from dataclasses import InitVar, fields
-from typing import TYPE_CHECKING, Dict
+from dataclasses import InitVar
+from typing import TYPE_CHECKING, TypedDict
+from types import EllipsisType
 import abc
 import warnings
 import jax
 import jax.numpy as jnp
 from jax import vmap
 import equinox as eqx
-from jaxtyping import Float, Array, Int
-from jinns.data._DataGenerators import append_obs_batch
+from jaxtyping import Float, Array
 from jinns.loss._loss_utils import (
     dynamic_loss_apply,
-    constraints_system_loss_apply,
     observations_loss_apply,
 )
 from jinns.parameters._params import (
@@ -26,15 +25,23 @@ from jinns.parameters._params import (
     _update_eq_params_dict,
 )
 from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
-from jinns.loss._loss_weights import LossWeightsODE, LossWeightsODEDict
-from jinns.loss._DynamicLossAbstract import ODE
-from jinns.nn._pinn import PINN
+from jinns.loss._loss_weights import LossWeightsODE
+from jinns.loss._abstract_loss import AbstractLoss
 
 if TYPE_CHECKING:
-    from jinns.utils._types import *
+    # imports only used in type hints
+    from jinns.parameters._params import Params
+    from jinns.data._Batchs import ODEBatch
+    from jinns.nn._abstract_pinn import AbstractPINN
+    from jinns.loss import ODE
+
+    class LossDictODE(TypedDict):
+        dyn_loss: Float[Array, " "]
+        initial_condition: Float[Array, " "]
+        observations: Float[Array, " "]
 
 
-class _LossODEAbstract(eqx.Module):
+class _LossODEAbstract(AbstractLoss):
     """
     Parameters
    ----------
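
The new `LossDictODE` TypedDict lives inside the `if TYPE_CHECKING:` block, so it exists only for static type checkers; at runtime the annotations stay unevaluated strings thanks to `from __future__ import annotations`. A minimal sketch of this pattern, with hypothetical names (`LossTerms`, `split_loss`) that are not part of jinns:

```python
from __future__ import annotations  # annotations are not evaluated at runtime

from typing import TYPE_CHECKING, TypedDict

if TYPE_CHECKING:
    # Visible to type checkers only; never constructed at runtime
    class LossTerms(TypedDict):
        dyn_loss: float
        observations: float


def split_loss(total: float) -> LossTerms:
    # The return annotation is a plain string at runtime, so referencing
    # LossTerms here is safe even though the class was never defined
    return {"dyn_loss": 0.5 * total, "observations": 0.5 * total}


print(split_loss(2.0))  # {'dyn_loss': 1.0, 'observations': 1.0}
```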
@@ -49,14 +56,14 @@ class _LossODEAbstract(eqx.Module):
         Fields can be "nn_params", "eq_params" or "both". Those that should not
         be updated will have a `jax.lax.stop_gradient` called on them. Default
         is `"nn_params"` for each composant of the loss.
-    initial_condition : tuple, default=None
+    initial_condition : tuple[float | Float[Array, " 1"], Float[Array, " dim"]], default=None
         tuple of length 2 with initial condition $(t_0, u_0)$.
-    obs_slice : Slice, default=None
+    obs_slice : EllipsisType | slice | None, default=None
         Slice object specifying the begininning/ending
         slice of u output(s) that is observed. This is useful for
         multidimensional PINN, with partially observed outputs.
         Default is None (whole output is observed).
-    params : InitVar[Params], default=None
+    params : InitVar[Params[Array]], default=None
         The main Params object of the problem needed to instanciate the
         DerivativeKeysODE if the latter is not specified.
     """
@@ -66,24 +73,27 @@ class _LossODEAbstract(eqx.Module):
     # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
     derivative_keys: DerivativeKeysODE | None = eqx.field(kw_only=True, default=None)
     loss_weights: LossWeightsODE | None = eqx.field(kw_only=True, default=None)
-    initial_condition: tuple | None = eqx.field(kw_only=True, default=None)
-    obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
+    initial_condition: (
+        tuple[float | Float[Array, " 1"], Float[Array, " dim"]] | None
+    ) = eqx.field(kw_only=True, default=None)
+    obs_slice: EllipsisType | slice | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )
 
-    params: InitVar[Params] = eqx.field(default=None, kw_only=True)
+    params: InitVar[Params[Array]] = eqx.field(default=None, kw_only=True)
 
-    def __post_init__(self, params=None):
+    def __post_init__(self, params: Params[Array] | None = None):
         if self.loss_weights is None:
             self.loss_weights = LossWeightsODE()
 
         if self.derivative_keys is None:
-            try:
-                # be default we only take gradient wrt nn_params
-                self.derivative_keys = DerivativeKeysODE(params=params)
-            except ValueError as exc:
+            # by default we only take gradient wrt nn_params
+            if params is None:
                 raise ValueError(
                     "Problem at self.derivative_keys initialization "
                     f"received {self.derivative_keys=} and {params=}"
-                ) from exc
+                )
+            self.derivative_keys = DerivativeKeysODE(params=params)
         if self.initial_condition is None:
             warnings.warn(
                 "Initial condition wasn't provided. Be sure to cover for that"
@@ -98,6 +108,19 @@ class _LossODEAbstract(eqx.Module):
                 "Initial condition should be a tuple of len 2 with (t0, u0), "
                 f"{self.initial_condition} was passed."
             )
+        # some checks/reshaping for t0
+        t0, u0 = self.initial_condition
+        if isinstance(t0, Array):
+            if not t0.shape:  # e.g. user input: jnp.array(0.)
+                t0 = jnp.array([t0])
+            elif t0.shape != (1,):
+                raise ValueError(
+                    f"Wrong t0 input (self.initial_condition[0]) It should be"
+                    f"a float or an array of shape (1,). Got shape: {t0.shape}"
+                )
+        if isinstance(t0, float):  # e.g. user input: 0
+            t0 = jnp.array([t0])
+        self.initial_condition = (t0, u0)
 
         if self.obs_slice is None:
             self.obs_slice = jnp.s_[...]
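
The block added above normalizes `t0` to shape `(1,)` once at construction time, so `evaluate` no longer needs the `t0 = jnp.array([t0])` reshape it used to do (see the `t0, u0 = self.initial_condition` hunk further down, where the old reshape is deleted). A standalone sketch of the same normalization, under a hypothetical helper name `_normalize_t0`:

```python
from __future__ import annotations

import jax.numpy as jnp
from jax import Array


def _normalize_t0(t0: float | Array) -> Array:
    """Coerce t0 to a shape-(1,) array, mirroring the checks above."""
    if isinstance(t0, Array):
        if not t0.shape:  # 0-d array, e.g. jnp.array(0.0)
            return jnp.array([t0])
        if t0.shape != (1,):
            raise ValueError(f"t0 must be a float or a shape-(1,) array, got {t0.shape}")
        return t0
    return jnp.array([t0])  # Python scalar, e.g. 0.0


assert _normalize_t0(0.0).shape == (1,)
assert _normalize_t0(jnp.array(0.0)).shape == (1,)
assert _normalize_t0(jnp.array([0.0])).shape == (1,)
```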
@@ -105,10 +128,14 @@ class _LossODEAbstract(eqx.Module):
         if self.loss_weights is None:
             self.loss_weights = LossWeightsODE()
 
+    @abc.abstractmethod
+    def __call__(self, *_, **__):
+        pass
+
     @abc.abstractmethod
     def evaluate(
-        self: eqx.Module, params: Params, batch: ODEBatch
-    ) -> tuple[Float, dict]:
+        self: eqx.Module, params: Params[Array], batch: ODEBatch
+    ) -> tuple[Float[Array, " "], LossDictODE]:
         raise NotImplementedError
 
 
@@ -135,19 +162,19 @@ class LossODE(_LossODEAbstract):
         Fields can be "nn_params", "eq_params" or "both". Those that should not
         be updated will have a `jax.lax.stop_gradient` called on them. Default
         is `"nn_params"` for each composant of the loss.
-    initial_condition : tuple, default=None
+    initial_condition : tuple[float | Float[Array, " 1"]], default=None
         tuple of length 2 with initial condition $(t_0, u_0)$.
-    obs_slice Slice, default=None
+    obs_slice : EllipsisType | slice | None, default=None
         Slice object specifying the begininning/ending
         slice of u output(s) that is observed. This is useful for
         multidimensional PINN, with partially observed outputs.
         Default is None (whole output is observed).
-    params : InitVar[Params], default=None
+    params : InitVar[Params[Array]], default=None
         The main Params object of the problem needed to instanciate the
         DerivativeKeysODE if the latter is not specified.
     u : eqx.Module
         the PINN
-    dynamic_loss : DynamicLoss
+    dynamic_loss : ODE
         the ODE dynamic part of the loss, basically the differential
         operator $\mathcal{N}[u](t)$. Should implement a method
         `dynamic_loss.evaluate(t, u, params)`.
@@ -161,12 +188,12 @@ class LossODE(_LossODEAbstract):
 
     # NOTE static=True only for leaf attributes that are not valid JAX types
     # (ie. jax.Array cannot be static) and that we do not expect to change
-    u: eqx.Module
-    dynamic_loss: DynamicLoss | None
+    u: AbstractPINN
+    dynamic_loss: ODE | None
 
-    vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
+    vmap_in_axes: tuple[int] = eqx.field(init=False, static=True)
 
-    def __post_init__(self, params=None):
+    def __post_init__(self, params: Params[Array] | None = None):
         super().__post_init__(
             params=params
         )  # because __init__ or __post_init__ of Base
@@ -178,8 +205,8 @@ class LossODE(_LossODEAbstract):
         return self.evaluate(*args, **kwargs)
 
     def evaluate(
-        self, params: Params, batch: ODEBatch
-    ) -> tuple[Float[Array, "1"], dict[str, float]]:
+        self, params: Params[Array], batch: ODEBatch
+    ) -> tuple[Float[Array, " "], LossDictODE]:
         """
         Evaluate the loss function at a batch of points for given parameters.
 
@@ -210,9 +237,9 @@ class LossODE(_LossODEAbstract):
                 self.dynamic_loss.evaluate,
                 self.u,
                 temporal_batch,
-                _set_derivatives(params, self.derivative_keys.dyn_loss),
+                _set_derivatives(params, self.derivative_keys.dyn_loss),  # type: ignore
                 self.vmap_in_axes + vmap_in_axes_params,
-                self.loss_weights.dyn_loss,
+                self.loss_weights.dyn_loss,  # type: ignore
             )
         else:
             mse_dyn_loss = jnp.array(0.0)
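
In the call above, `self.vmap_in_axes` maps axis 0 of the time batch and `vmap_in_axes_params` extends the mapping with an optional batch of equation parameters. A minimal sketch of how such in_axes tuples compose under `jax.vmap`, with dummy names (`residual`, `theta`) that are not taken from jinns:

```python
import jax
import jax.numpy as jnp


def residual(t, theta):
    # Stand-in for a dynamic loss evaluated at one time point
    return jnp.sum(theta) * t


vmap_in_axes = (0,)            # map over the time batch axis
vmap_in_axes_params = (None,)  # parameters broadcast across the batch

batched_residual = jax.vmap(residual, vmap_in_axes + vmap_in_axes_params)

t_batch = jnp.linspace(0.0, 1.0, 8)
theta = jnp.ones((3,))
print(batched_residual(t_batch, theta).shape)  # (8,)
```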
@@ -226,17 +253,17 @@ class LossODE(_LossODEAbstract):
             v_u = self.u
         else:
             v_u = vmap(self.u, (None,) + vmap_in_axes_params)
-        t0, u0 = self.initial_condition  # pylint: disable=unpacking-non-sequence
-        t0 = jnp.array([t0])
+        t0, u0 = self.initial_condition
         u0 = jnp.array(u0)
         mse_initial_condition = jnp.mean(
-            self.loss_weights.initial_condition
+            self.loss_weights.initial_condition  # type: ignore
             * jnp.sum(
                 (
                     v_u(
                         t0,
                         _set_derivatives(
-                            params, self.derivative_keys.initial_condition
+                            params,
+                            self.derivative_keys.initial_condition,  # type: ignore
                         ),
                     )
                     - u0
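
The expression in this hunk computes the initial-condition penalty as `mean(w * sum((u(t0) - u0) ** 2))`: squared error summed over output dimensions, then averaged (over the optional parameter batch when `v_u` is vmapped). A self-contained sketch with a dummy network, where `u`, `w`, `t0`, `u0` are placeholder names:

```python
import jax.numpy as jnp


def u(t):
    # Dummy PINN: maps a shape-(1,) time to a 2-dimensional output
    return jnp.tanh(jnp.concatenate([t, 2.0 * t]))


t0 = jnp.array([0.0])       # normalized to shape (1,) in __post_init__
u0 = jnp.array([0.0, 1.0])  # target initial state
w = 1.0                     # loss_weights.initial_condition

# Squared error summed over output dims, then averaged; with a single
# (non-vmapped) u the mean is taken over a scalar and is a no-op
mse_initial_condition = jnp.mean(w * jnp.sum((u(t0) - u0) ** 2))
print(mse_initial_condition)  # 1.0
```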
@@ -255,11 +282,11 @@ class LossODE(_LossODEAbstract):
             # MSE loss wrt to an observed batch
             mse_observation_loss = observations_loss_apply(
                 self.u,
-                (batch.obs_batch_dict["pinn_in"],),
-                _set_derivatives(params, self.derivative_keys.observations),
+                batch.obs_batch_dict["pinn_in"],
+                _set_derivatives(params, self.derivative_keys.observations),  # type: ignore
                 self.vmap_in_axes + vmap_in_axes_params,
                 batch.obs_batch_dict["val"],
-                self.loss_weights.observations,
+                self.loss_weights.observations,  # type: ignore
                 self.obs_slice,
             )
         else:
@@ -274,295 +301,3 @@ class LossODE(_LossODEAbstract):
                 "observations": mse_observation_loss,
             }
         )
-
-
-class SystemLossODE(eqx.Module):
-    r"""
-    Class to implement a system of ODEs.
-    The goal is to give maximum freedom to the user. The class is created with
-    a dict of dynamic loss and a dict of initial conditions. Then, it iterates
-    over the dynamic losses that compose the system. All PINNs are passed as
-    arguments to each dynamic loss evaluate functions, along with all the
-    parameter dictionaries. All specification is left to the responsability
-    of the user, inside the dynamic loss.
-
-    **Note:** All the dictionaries (except `dynamic_loss_dict`) must have the same keys.
-    Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
-    solution.
-
-    Parameters
-    ----------
-    u_dict : Dict[str, eqx.Module]
-        dict of PINNs
-    loss_weights : LossWeightsODEDict
-        A dictionary of LossWeightsODE
-    derivative_keys_dict : Dict[str, DerivativeKeysODE], default=None
-        A dictionnary of DerivativeKeysODE specifying what field of `params`
-        should be used during gradient computations for each of the terms of
-        the total loss, for each of the loss in the system. Default is
-        `"nn_params`" everywhere.
-    initial_condition_dict : Dict[str, tuple], default=None
-        dict of tuple of length 2 with initial condition $(t_0, u_0)$
-        Must share the keys of `u_dict`. Default is None. No initial
-        condition is permitted when the initial condition is hardcoded in
-        the PINN architecture for example
-    dynamic_loss_dict : Dict[str, ODE]
-        dict of dynamic part of the loss, basically the differential
-        operator $\mathcal{N}[u](t)$. Should implement a method
-        `dynamic_loss.evaluate(t, u, params)`
-    obs_slice_dict : Dict[str, Slice]
-        dict of obs_slice, with keys from `u_dict` to designate the
-        output(s) channels that are observed, for each
-        PINNs. Default is None. But if a value is given, all the entries of
-        `u_dict` must be represented here with default value `jnp.s_[...]`
-        if no particular slice is to be given.
-    params_dict : InitVar[ParamsDict], default=None
-        The main Params object of the problem needed to instanciate the
-        DerivativeKeysODE if the latter is not specified.
-
-    Raises
-    ------
-    ValueError
-        if initial condition is not a dict of tuple.
-    ValueError
-        if the dictionaries that should share the keys of u_dict do not.
-    """
-
-    # NOTE static=True only for leaf attributes that are not valid JAX types
-    # (ie. jax.Array cannot be static) and that we do not expect to change
-    u_dict: Dict[str, eqx.Module]
-    dynamic_loss_dict: Dict[str, ODE]
-    derivative_keys_dict: Dict[str, DerivativeKeysODE | None] | None = eqx.field(
-        kw_only=True, default=None
-    )
-    initial_condition_dict: Dict[str, tuple] | None = eqx.field(
-        kw_only=True, default=None
-    )
-
-    obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
-        kw_only=True, default=None, static=True
-    )  # We are at an "leaf" attribute here (slice, not valid JAX type). Since
-    # we do not expect it to change with put a static=True here. But note that
-    # this is the only static for all the SystemLossODE attribute, since all
-    # other are composed of more complex structures ("non-leaf")
-
-    # For the user loss_weights are passed as a LossWeightsODEDict (with internal
-    # dictionary having keys in u_dict and / or dynamic_loss_dict)
-    loss_weights: InitVar[LossWeightsODEDict | None] = eqx.field(
-        kw_only=True, default=None
-    )
-    params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
-
-    u_constraints_dict: Dict[str, LossODE] = eqx.field(init=False)
-    derivative_keys_dyn_loss: DerivativeKeysODE = eqx.field(init=False)
-
-    u_dict_with_none: Dict[str, None] = eqx.field(init=False)
-    # internally the loss weights are handled with a dictionary
-    _loss_weights: Dict[str, dict] = eqx.field(init=False)
-
-    def __post_init__(self, loss_weights=None, params_dict=None):
-        # a dictionary that will be useful at different places
-        self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
-        if self.initial_condition_dict is None:
-            self.initial_condition_dict = self.u_dict_with_none
-        else:
-            if self.u_dict.keys() != self.initial_condition_dict.keys():
-                raise ValueError(
-                    "initial_condition_dict should have same keys as u_dict"
-                )
-        if self.obs_slice_dict is None:
-            self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
-        else:
-            if self.u_dict.keys() != self.obs_slice_dict.keys():
-                raise ValueError("obs_slice_dict should have same keys as u_dict")
-
-        if self.derivative_keys_dict is None:
-            self.derivative_keys_dict = {
-                k: None
-                for k in set(
-                    list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
-                )
-            }
-            # set() because we can have duplicate entries and in this case we
-            # say it corresponds to the same derivative_keys_dict entry
-            # we need both because the constraints (all but dyn_loss) will be
-            # done by iterating on u_dict while the dyn_loss will be by
-            # iterating on dynamic_loss_dict. So each time we will require dome
-            # derivative_keys_dict
-
-        # derivative keys for the u_constraints. Note that we create missing
-        # DerivativeKeysODE around a Params object and not ParamsDict
-        # this works because u_dict.keys == params_dict.nn_params.keys()
-        for k in self.u_dict.keys():
-            if self.derivative_keys_dict[k] is None:
-                self.derivative_keys_dict[k] = DerivativeKeysODE(
-                    params=params_dict.extract_params(k)
-                )
-
-        self._loss_weights = self.set_loss_weights(loss_weights)
-
-        # The constaints on the solutions will be implemented by reusing a
-        # LossODE class without dynamic loss term
-        self.u_constraints_dict = {}
-        for i in self.u_dict.keys():
-            self.u_constraints_dict[i] = LossODE(
-                u=self.u_dict[i],
-                loss_weights=LossWeightsODE(
-                    dyn_loss=0.0,
-                    initial_condition=1.0,
-                    observations=1.0,
-                ),
-                dynamic_loss=None,
-                derivative_keys=self.derivative_keys_dict[i],
-                initial_condition=self.initial_condition_dict[i],
-                obs_slice=self.obs_slice_dict[i],
-            )
-
-        # derivative keys for the dynamic loss. Note that we create a
-        # DerivativeKeysODE around a ParamsDict object because a whole
-        # params_dict is feed to DynamicLoss.evaluate functions (extract_params
-        # happen inside it)
-        self.derivative_keys_dyn_loss = DerivativeKeysODE(params=params_dict)
-
-    def set_loss_weights(self, loss_weights_init):
-        """
-        This rather complex function enables the user to specify a simple
-        loss_weights=LossWeightsODEDict(dyn_loss=1., initial_condition=Tmax)
-        for ponderating values being applied to all the equations of the
-        system... So all the transformations are handled here
-        """
-        _loss_weights = {}
-        for k in fields(loss_weights_init):
-            v = getattr(loss_weights_init, k.name)
-            if isinstance(v, dict):
-                for vv in v.values():
-                    if not isinstance(vv, (int, float)) and not (
-                        isinstance(vv, Array)
-                        and ((vv.shape == (1,) or len(vv.shape) == 0))
-                    ):
-                        # TODO improve that
-                        raise ValueError(
-                            f"loss values cannot be vectorial here, got {vv}"
-                        )
-                if k.name == "dyn_loss":
-                    if v.keys() == self.dynamic_loss_dict.keys():
-                        _loss_weights[k.name] = v
-                    else:
-                        raise ValueError(
-                            "Keys in nested dictionary of loss_weights"
-                            " do not match dynamic_loss_dict keys"
-                        )
-                else:
-                    if v.keys() == self.u_dict.keys():
-                        _loss_weights[k.name] = v
-                    else:
-                        raise ValueError(
-                            "Keys in nested dictionary of loss_weights"
-                            " do not match u_dict keys"
-                        )
-            elif v is None:
-                _loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
-            else:
-                if not isinstance(v, (int, float)) and not (
-                    isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
-                ):
-                    # TODO improve that
-                    raise ValueError(f"loss values cannot be vectorial here, got {v}")
-                if k.name == "dyn_loss":
-                    _loss_weights[k.name] = {
-                        kk: v for kk in self.dynamic_loss_dict.keys()
-                    }
-                else:
-                    _loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
-
-        return _loss_weights
-
-    def __call__(self, *args, **kwargs):
-        return self.evaluate(*args, **kwargs)
-
-    def evaluate(self, params_dict: ParamsDict, batch: ODEBatch) -> Float[Array, "1"]:
-        """
-        Evaluate the loss function at a batch of points for given parameters.
-
-
-        Parameters
-        ---------
-        params
-            A ParamsDict object
-        batch
-            A ODEBatch object.
-            Such a named tuple is composed of a batch of time points
-            at which to evaluate an optional additional batch of parameters (eg. for
-            metamodeling) and an optional additional batch of observed
-            inputs/outputs/parameters
-        """
-        if (
-            isinstance(params_dict.nn_params, dict)
-            and self.u_dict.keys() != params_dict.nn_params.keys()
-        ):
-            raise ValueError("u_dict and params_dict.nn_params should have same keys ")
-
-        temporal_batch = batch.temporal_batch
-
-        vmap_in_axes_t = (0,)
-
-        # Retrieve the optional eq_params_batch
-        # and update eq_params with the latter
-        # and update vmap_in_axes
-        if batch.param_batch_dict is not None:
-            # update params with the batches of generated params
-            params = _update_eq_params_dict(params, batch.param_batch_dict)
-
-            vmap_in_axes_params = _get_vmap_in_axes_params(
-                batch.param_batch_dict, params_dict
-            )
-
-        def dyn_loss_for_one_key(dyn_loss, loss_weight):
-            """This function is used in tree_map"""
-            return dynamic_loss_apply(
-                dyn_loss.evaluate,
-                self.u_dict,
-                temporal_batch,
-                _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
-                vmap_in_axes_t + vmap_in_axes_params,
-                loss_weight,
-                u_type=PINN,
-            )
-
-        dyn_loss_mse_dict = jax.tree_util.tree_map(
-            dyn_loss_for_one_key,
-            self.dynamic_loss_dict,
-            self._loss_weights["dyn_loss"],
-            is_leaf=lambda x: isinstance(x, ODE),  # before when dynamic losses
-            # where plain (unregister pytree) node classes, we could not traverse
-            # this level. Now that dynamic losses are eqx.Module they can be
-            # traversed by tree map recursion. Hence we need to specify to that
-            # we want to stop at this level
-        )
-        mse_dyn_loss = jax.tree_util.tree_reduce(
-            lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
-        )
-
-        # initial conditions and observation_loss via the internal LossODE
-        loss_weight_struct = {
-            "dyn_loss": "*",
-            "observations": "*",
-            "initial_condition": "*",
-        }
-
-        # we need to do the following for the tree_mapping to work
-        if batch.obs_batch_dict is None:
-            batch = append_obs_batch(batch, self.u_dict_with_none)
-
-        total_loss, res_dict = constraints_system_loss_apply(
-            self.u_constraints_dict,
-            batch,
-            params_dict,
-            self._loss_weights,
-            loss_weight_struct,
-        )
-
-        # Add the mse_dyn_loss from the previous computations
-        total_loss += mse_dyn_loss
-        res_dict["dyn_loss"] += mse_dyn_loss
-        return total_loss, res_dict
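
The deleted `evaluate` illustrates why `is_leaf` matters once dynamic losses are `eqx.Module`s: a module is itself a pytree, so without the predicate `tree_map` would recurse into its fields instead of handing each whole `ODE` object to the mapped function. A toy reproduction of that pattern, with a hypothetical `Toy` module standing in for the dynamic losses:

```python
import equinox as eqx
import jax


class Toy(eqx.Module):
    # eqx.Module is a registered pytree: without is_leaf below,
    # tree_map would descend into `scale` rather than stop at Toy
    scale: float


losses = {"eq1": Toy(1.0), "eq2": Toy(2.0)}
weights = {"eq1": 0.5, "eq2": 2.0}

weighted = jax.tree_util.tree_map(
    lambda module, w: module.scale * w,
    losses,
    weights,
    is_leaf=lambda x: isinstance(x, Toy),  # treat each Toy as a single leaf
)
total = jax.tree_util.tree_reduce(lambda a, b: a + b, weighted)
print(total)  # 0.5 + 4.0 = 4.5
```

This mirrors the `dyn_loss_mse_dict` / `mse_dyn_loss` reduction that `SystemLossODE` performed before its removal in 1.4.0.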