jinns-1.3.0-py3-none-any.whl → jinns-1.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +74 -69
  14. jinns/loss/_LossODE.py +132 -348
  15. jinns/loss/_LossPDE.py +262 -549
  16. jinns/loss/__init__.py +32 -6
  17. jinns/loss/_abstract_loss.py +128 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_components.py +43 -0
  20. jinns/loss/_loss_utils.py +85 -179
  21. jinns/loss/_loss_weight_updates.py +202 -0
  22. jinns/loss/_loss_weights.py +64 -40
  23. jinns/loss/_operators.py +84 -74
  24. jinns/nn/__init__.py +15 -0
  25. jinns/nn/_abstract_pinn.py +22 -0
  26. jinns/nn/_hyperpinn.py +94 -57
  27. jinns/nn/_mlp.py +50 -25
  28. jinns/nn/_pinn.py +33 -19
  29. jinns/nn/_ppinn.py +70 -34
  30. jinns/nn/_save_load.py +21 -51
  31. jinns/nn/_spinn.py +33 -16
  32. jinns/nn/_spinn_mlp.py +28 -22
  33. jinns/nn/_utils.py +38 -0
  34. jinns/parameters/__init__.py +8 -1
  35. jinns/parameters/_derivative_keys.py +116 -177
  36. jinns/parameters/_params.py +18 -46
  37. jinns/plot/__init__.py +2 -0
  38. jinns/plot/_plot.py +35 -34
  39. jinns/solver/_rar.py +80 -63
  40. jinns/solver/_solve.py +207 -92
  41. jinns/solver/_utils.py +4 -6
  42. jinns/utils/__init__.py +2 -0
  43. jinns/utils/_containers.py +16 -10
  44. jinns/utils/_types.py +20 -54
  45. jinns/utils/_utils.py +4 -11
  46. jinns/validation/__init__.py +2 -0
  47. jinns/validation/_validation.py +20 -19
  48. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
  49. jinns-1.5.0.dist-info/RECORD +55 -0
  50. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
  51. jinns/data/_DataGenerators.py +0 -1634
  52. jinns-1.3.0.dist-info/RECORD +0 -44
  53. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
  54. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
  55. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py CHANGED
@@ -1,14 +1,15 @@
-# pylint: disable=unsubscriptable-object, no-member
 """
 Main module to implement a PDE loss in jinns
 """
+
 from __future__ import (
     annotations,
 )  # https://docs.python.org/3/library/typing.html#constant
 
 import abc
-from dataclasses import InitVar, fields
-from typing import TYPE_CHECKING, Dict, Callable
+from dataclasses import InitVar
+from typing import TYPE_CHECKING, Callable, TypedDict
+from types import EllipsisType
 import warnings
 import jax
 import jax.numpy as jnp
@@ -20,9 +21,7 @@ from jinns.loss._loss_utils import (
     normalization_loss_apply,
     observations_loss_apply,
     initial_condition_apply,
-    constraints_system_loss_apply,
 )
-from jinns.data._DataGenerators import append_obs_batch
 from jinns.parameters._params import (
     _get_vmap_in_axes_params,
     _update_eq_params_dict,
@@ -32,19 +31,31 @@ from jinns.parameters._derivative_keys import (
     DerivativeKeysPDEStatio,
     DerivativeKeysPDENonStatio,
 )
+from jinns.loss._abstract_loss import AbstractLoss
+from jinns.loss._loss_components import PDEStatioComponents, PDENonStatioComponents
 from jinns.loss._loss_weights import (
     LossWeightsPDEStatio,
     LossWeightsPDENonStatio,
-    LossWeightsPDEDict,
 )
-from jinns.loss._DynamicLossAbstract import PDEStatio, PDENonStatio
-from jinns.nn._pinn import PINN
-from jinns.nn._spinn import SPINN
 from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
+from jinns.parameters._params import Params
 
 
 if TYPE_CHECKING:
-    from jinns.utils._types import *
+    # imports for type hints only
+    from jinns.nn._abstract_pinn import AbstractPINN
+    from jinns.loss import PDENonStatio, PDEStatio
+    from jinns.utils._types import BoundaryConditionFun
+
+    class LossDictPDEStatio(TypedDict):
+        dyn_loss: Float[Array, " "]
+        norm_loss: Float[Array, " "]
+        boundary_loss: Float[Array, " "]
+        observations: Float[Array, " "]
+
+    class LossDictPDENonStatio(LossDictPDEStatio):
+        initial_condition: Float[Array, " "]
+
 
 _IMPLEMENTED_BOUNDARY_CONDITIONS = [
     "dirichlet",
@@ -53,8 +64,8 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
 ]
 
 
-class _LossPDEAbstract(eqx.Module):
-    """
+class _LossPDEAbstract(AbstractLoss):
+    r"""
     Parameters
     ----------
 
@@ -62,18 +73,21 @@ class _LossPDEAbstract(eqx.Module):
         The loss weights for the differents term : dynamic loss,
         initial condition (if LossWeightsPDENonStatio), boundary conditions if
         any, normalization loss if any and observations if any.
-        All fields are set to 1.0 by default.
+        Can be updated according to a specific algorithm. See
+        `update_weight_method`
+    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
+        Default is None meaning no update for loss weights. Otherwise a string
     derivative_keys : DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio, default=None
         Specify which field of `params` should be differentiated for each
         composant of the total loss. Particularily useful for inverse problems.
         Fields can be "nn_params", "eq_params" or "both". Those that should not
         be updated will have a `jax.lax.stop_gradient` called on them. Default
         is `"nn_params"` for each composant of the loss.
-    omega_boundary_fun : Callable | Dict[str, Callable], default=None
+    omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
         The function to be matched in the border condition (can be None) or a
         dictionary of such functions as values and keys as described
         in `omega_boundary_condition`.
-    omega_boundary_condition : str | Dict[str, str], default=None
+    omega_boundary_condition : str | dict[str, str], default=None
         Either None (no condition, by default), or a string defining
         the boundary condition (Dirichlet or Von Neumann),
         or a dictionary with such strings as values. In this case,
@@ -84,28 +98,29 @@ class _LossPDEAbstract(eqx.Module):
         a particular boundary condition on this facet.
         The facet called “xmin”, resp. “xmax” etc., in 2D,
         refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
-    omega_boundary_dim : slice | Dict[str, slice], default=None
+    omega_boundary_dim : slice | dict[str, slice], default=None
         Either None, or a slice object or a dictionary of slice objects as
         values and keys as described in `omega_boundary_condition`.
         `omega_boundary_dim` indicates which dimension(s) of the PINN
         will be forced to match the boundary condition.
         Note that it must be a slice and not an integer
         (but a preprocessing of the user provided argument takes care of it)
-    norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
+    norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
         Monte-Carlo sample points for computing the
         normalization constant. Default is None.
-    norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
+    norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
         The importance sampling weights for Monte-Carlo integration of the
         normalization constant. Must be provided if `norm_samples` is provided.
-        `norm_weights` should have the same leading dimension as
+        `norm_weights` should be broadcastble to
         `norm_samples`.
-        Alternatively, the user can pass a float or an integer.
+        Alternatively, the user can pass a float or an integer that will be
+        made broadcastable to `norm_samples`.
         These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
         $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
-    obs_slice : slice, default=None
+    obs_slice : EllipsisType | slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
-    params : InitVar[Params], default=None
+    params : InitVar[Params[Array]], default=None
         The main Params object of the problem needed to instanciate the
         DerivativeKeysODE if the latter is not specified.
     """
@@ -119,26 +134,28 @@ class _LossPDEAbstract(eqx.Module):
    loss_weights: LossWeightsPDEStatio | LossWeightsPDENonStatio | None = eqx.field(
        kw_only=True, default=None
    )
-    omega_boundary_fun: Callable | Dict[str, Callable] | None = eqx.field(
-        kw_only=True, default=None, static=True
-    )
-    omega_boundary_condition: str | Dict[str, str] | None = eqx.field(
+    omega_boundary_fun: (
+        BoundaryConditionFun | dict[str, BoundaryConditionFun] | None
+    ) = eqx.field(kw_only=True, default=None, static=True)
+    omega_boundary_condition: str | dict[str, str] | None = eqx.field(
        kw_only=True, default=None, static=True
    )
-    omega_boundary_dim: slice | Dict[str, slice] | None = eqx.field(
+    omega_boundary_dim: slice | dict[str, slice] | None = eqx.field(
        kw_only=True, default=None, static=True
    )
-    norm_samples: Float[Array, "nb_norm_samples dimension"] | None = eqx.field(
+    norm_samples: Float[Array, " nb_norm_samples dimension"] | None = eqx.field(
        kw_only=True, default=None
    )
-    norm_weights: Float[Array, "nb_norm_samples"] | float | int | None = eqx.field(
+    norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = eqx.field(
        kw_only=True, default=None
    )
-    obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
+    obs_slice: EllipsisType | slice | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )
 
-    params: InitVar[Params] = eqx.field(kw_only=True, default=None)
+    params: InitVar[Params[Array]] = eqx.field(kw_only=True, default=None)
 
-    def __post_init__(self, params=None):
+    def __post_init__(self, params: Params[Array] | None = None):
        """
        Note that neither __init__ or __post_init__ are called when udating a
        Module with eqx.tree_at
@@ -228,6 +245,11 @@ class _LossPDEAbstract(eqx.Module):
            )
 
        if isinstance(self.omega_boundary_fun, dict):
+            if not isinstance(self.omega_boundary_dim, dict):
+                raise ValueError(
+                    "If omega_boundary_fun is a dict then"
+                    " omega_boundary_dim should also be a dict"
+                )
            if self.omega_boundary_dim is None:
                self.omega_boundary_dim = {
                    k: jnp.s_[::] for k in self.omega_boundary_fun.keys()
@@ -262,27 +284,29 @@ class _LossPDEAbstract(eqx.Module):
                raise ValueError(
                    "`norm_weights` must be provided when `norm_samples` is used!"
                )
-            try:
-                assert self.norm_weights.shape[0] == self.norm_samples.shape[0]
-            except (AssertionError, AttributeError):
-                if isinstance(self.norm_weights, (int, float)):
-                    self.norm_weights = jnp.array(
-                        [self.norm_weights], dtype=jax.dtypes.canonicalize_dtype(float)
-                    )
-                else:
+            if isinstance(self.norm_weights, (int, float)):
+                self.norm_weights = self.norm_weights * jnp.ones(
+                    (self.norm_samples.shape[0],)
+                )
+            if isinstance(self.norm_weights, Array):
+                if not (self.norm_weights.shape[0] == self.norm_samples.shape[0]):
                    raise ValueError(
-                        "`norm_weights` should have the same leading dimension"
-                        " as `norm_samples`,"
-                        f" got shape {self.norm_weights.shape} and"
-                        f" shape {self.norm_samples.shape}."
+                        "self.norm_weights and "
+                        "self.norm_samples must have the same leading dimension"
                    )
+            else:
+                raise ValueError("Wrong type for self.norm_weights")
+
+    @abc.abstractmethod
+    def __call__(self, *_, **__):
+        pass
 
    @abc.abstractmethod
    def evaluate(
        self: eqx.Module,
-        params: Params,
+        params: Params[Array],
        batch: PDEStatioBatch | PDENonStatioBatch,
-    ) -> tuple[Float, dict]:
+    ) -> tuple[Float[Array, " "], LossDictPDEStatio | LossDictPDENonStatio]:
        raise NotImplementedError
 
 
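The `norm_weights` handling above replaces the old try/except with explicit type checks, and a scalar is now expanded to one weight per Monte-Carlo sample instead of being stored as a length-1 array. A minimal sketch of the new behaviour in plain JAX:

```python
import jax.numpy as jnp

norm_samples = jnp.ones((100, 2))  # 100 Monte-Carlo points in a 2D domain
norm_weights = 0.5                 # scalar importance-sampling weight w_k

# 1.5.0 behaviour: the scalar is tiled to match the leading dimension
norm_weights = norm_weights * jnp.ones((norm_samples.shape[0],))
assert norm_weights.shape[0] == norm_samples.shape[0]
```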
@@ -299,9 +323,9 @@ class LossPDEStatio(_LossPDEAbstract):
 
    Parameters
    ----------
-    u : eqx.Module
+    u : AbstractPINN
        the PINN
-    dynamic_loss : DynamicLoss
+    dynamic_loss : PDEStatio
        the stationary PDE dynamic part of the loss, basically the differential
        operator $\mathcal{N}[u](x)$. Should implement a method
        `dynamic_loss.evaluate(x, u, params)`.
@@ -317,18 +341,21 @@
        The loss weights for the differents term : dynamic loss,
        boundary conditions if any, normalization loss if any and
        observations if any.
-        All fields are set to 1.0 by default.
+        Can be updated according to a specific algorithm. See
+        `update_weight_method`
+    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
+        Default is None meaning no update for loss weights. Otherwise a string
    derivative_keys : DerivativeKeysPDEStatio, default=None
        Specify which field of `params` should be differentiated for each
        composant of the total loss. Particularily useful for inverse problems.
        Fields can be "nn_params", "eq_params" or "both". Those that should not
        be updated will have a `jax.lax.stop_gradient` called on them. Default
        is `"nn_params"` for each composant of the loss.
-    omega_boundary_fun : Callable | Dict[str, Callable], default=None
+    omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
        The function to be matched in the border condition (can be None) or a
        dictionary of such functions as values and keys as described
        in `omega_boundary_condition`.
-    omega_boundary_condition : str | Dict[str, str], default=None
+    omega_boundary_condition : str | dict[str, str], default=None
        Either None (no condition, by default), or a string defining
        the boundary condition (Dirichlet or Von Neumann),
        or a dictionary with such strings as values. In this case,
@@ -339,17 +366,17 @@
        a particular boundary condition on this facet.
        The facet called “xmin”, resp. “xmax” etc., in 2D,
        refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
-    omega_boundary_dim : slice | Dict[str, slice], default=None
+    omega_boundary_dim : slice | dict[str, slice], default=None
        Either None, or a slice object or a dictionary of slice objects as
        values and keys as described in `omega_boundary_condition`.
        `omega_boundary_dim` indicates which dimension(s) of the PINN
        will be forced to match the boundary condition.
        Note that it must be a slice and not an integer
        (but a preprocessing of the user provided argument takes care of it)
-    norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
+    norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
        Monte-Carlo sample points for computing the
        normalization constant. Default is None.
-    norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
+    norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
        The importance sampling weights for Monte-Carlo integration of the
        normalization constant. Must be provided if `norm_samples` is provided.
        `norm_weights` should have the same leading dimension as
@@ -360,7 +387,7 @@
    obs_slice : slice, default=None
        slice object specifying the begininning/ending of the PINN output
        that is observed (this is then useful for multidim PINN). Default is None.
-    params : InitVar[Params], default=None
+    params : InitVar[Params[Array]], default=None
        The main Params object of the problem needed to instanciate the
        DerivativeKeysODE if the latter is not specified.
 
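The `update_weight_method` parameter documented above (implemented in the new `jinns/loss/_loss_weight_updates.py`, +202 lines in this release) selects an adaptive loss-weighting rule. As a rough, illustrative sketch of the family of ideas behind 'soft_adapt'-style updates, not jinns' actual code: the weights follow a softmax over how little each loss term has improved recently, so stalled terms get more weight.

```python
import jax
import jax.numpy as jnp

prev_losses = jnp.array([0.9, 0.1, 0.5])  # per-term losses at step k-1
curr_losses = jnp.array([0.8, 0.1, 0.6])  # per-term losses at step k

# Softmax over the relative change of each term (illustrative formula only).
weights = jax.nn.softmax(curr_losses / (prev_losses + 1e-12))
```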
@@ -375,13 +402,13 @@
    # NOTE static=True only for leaf attributes that are not valid JAX types
    # (ie. jax.Array cannot be static) and that we do not expect to change
 
-    u: eqx.Module
-    dynamic_loss: DynamicLoss | None
+    u: AbstractPINN
+    dynamic_loss: PDEStatio | None
    key: Key | None = eqx.field(kw_only=True, default=None)
 
    vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
 
-    def __post_init__(self, params=None):
+    def __post_init__(self, params: Params[Array] | None = None):
        """
        Note that neither __init__ or __post_init__ are called when udating a
        Module with eqx.tree_at!
@@ -395,28 +422,31 @@
 
    def _get_dynamic_loss_batch(
        self, batch: PDEStatioBatch
-    ) -> Float[Array, "batch_size dimension"]:
+    ) -> Float[Array, " batch_size dimension"]:
        return batch.domain_batch
 
    def _get_normalization_loss_batch(
        self, _
-    ) -> Float[Array, "nb_norm_samples dimension"]:
-        return (self.norm_samples,)
+    ) -> tuple[Float[Array, " nb_norm_samples dimension"]]:
+        return (self.norm_samples,)  # type: ignore -> cannot narrow a class attr
+
+    # we could have used typing.cast though
 
    def _get_observations_loss_batch(
        self, batch: PDEStatioBatch
-    ) -> Float[Array, "batch_size obs_dim"]:
-        return (batch.obs_batch_dict["pinn_in"],)
+    ) -> Float[Array, " batch_size obs_dim"]:
+        return batch.obs_batch_dict["pinn_in"]
 
    def __call__(self, *args, **kwargs):
        return self.evaluate(*args, **kwargs)
 
-    def evaluate(
-        self, params: Params, batch: PDEStatioBatch
-    ) -> tuple[Float[Array, "1"], dict[str, float]]:
+    def evaluate_by_terms(
+        self, params: Params[Array], batch: PDEStatioBatch
+    ) -> tuple[PDEStatioComponents[Array | None], PDEStatioComponents[Array | None]]:
        """
        Evaluate the loss function at a batch of points for given parameters.
 
+        We retrieve two PyTrees with loss values and gradients for each term
 
        Parameters
        ---------
@@ -440,75 +470,112 @@
 
        # dynamic part
        if self.dynamic_loss is not None:
-            mse_dyn_loss = dynamic_loss_apply(
-                self.dynamic_loss.evaluate,
+            dyn_loss_fun = lambda p: dynamic_loss_apply(
+                self.dynamic_loss.evaluate,  # type: ignore
                self.u,
                self._get_dynamic_loss_batch(batch),
-                _set_derivatives(params, self.derivative_keys.dyn_loss),
+                _set_derivatives(p, self.derivative_keys.dyn_loss),  # type: ignore
                self.vmap_in_axes + vmap_in_axes_params,
-                self.loss_weights.dyn_loss,
            )
        else:
-            mse_dyn_loss = jnp.array(0.0)
+            dyn_loss_fun = None
 
        # normalization part
        if self.norm_samples is not None:
-            mse_norm_loss = normalization_loss_apply(
+            norm_loss_fun = lambda p: normalization_loss_apply(
                self.u,
                self._get_normalization_loss_batch(batch),
-                _set_derivatives(params, self.derivative_keys.norm_loss),
+                _set_derivatives(p, self.derivative_keys.norm_loss),  # type: ignore
                vmap_in_axes_params,
-                self.norm_weights,
-                self.loss_weights.norm_loss,
+                self.norm_weights,  # type: ignore -> can't get the __post_init__ narrowing here
            )
        else:
-            mse_norm_loss = jnp.array(0.0)
+            norm_loss_fun = None
 
        # boundary part
-        if self.omega_boundary_condition is not None:
-            mse_boundary_loss = boundary_condition_apply(
+        if (
+            self.omega_boundary_condition is not None
+            and self.omega_boundary_dim is not None
+            and self.omega_boundary_fun is not None
+        ):  # pyright cannot narrow down the three None otherwise as it is class attribute
+            boundary_loss_fun = lambda p: boundary_condition_apply(
                self.u,
                batch,
-                _set_derivatives(params, self.derivative_keys.boundary_loss),
-                self.omega_boundary_fun,
-                self.omega_boundary_condition,
-                self.omega_boundary_dim,
-                self.loss_weights.boundary_loss,
+                _set_derivatives(p, self.derivative_keys.boundary_loss),  # type: ignore
+                self.omega_boundary_fun,  # type: ignore
+                self.omega_boundary_condition,  # type: ignore
+                self.omega_boundary_dim,  # type: ignore
            )
        else:
-            mse_boundary_loss = jnp.array(0.0)
+            boundary_loss_fun = None
 
        # Observation mse
        if batch.obs_batch_dict is not None:
            # update params with the batches of observed params
-            params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])
+            params_obs = _update_eq_params_dict(
+                params, batch.obs_batch_dict["eq_params"]
+            )
 
-            mse_observation_loss = observations_loss_apply(
+            obs_loss_fun = lambda po: observations_loss_apply(
                self.u,
                self._get_observations_loss_batch(batch),
-                _set_derivatives(params, self.derivative_keys.observations),
+                _set_derivatives(po, self.derivative_keys.observations),  # type: ignore
                self.vmap_in_axes + vmap_in_axes_params,
                batch.obs_batch_dict["val"],
-                self.loss_weights.observations,
                self.obs_slice,
            )
        else:
-            mse_observation_loss = jnp.array(0.0)
+            params_obs = None
+            obs_loss_fun = None
 
-        # total loss
-        total_loss = (
-            mse_dyn_loss + mse_norm_loss + mse_boundary_loss + mse_observation_loss
+        # get the unweighted mses for each loss term as well as the gradients
+        all_funs: PDEStatioComponents[Callable[[Params[Array]], Array] | None] = (
+            PDEStatioComponents(
+                dyn_loss_fun, norm_loss_fun, boundary_loss_fun, obs_loss_fun
+            )
+        )
+        all_params: PDEStatioComponents[Params[Array] | None] = PDEStatioComponents(
+            params, params, params, params_obs
+        )
+        mses_grads = jax.tree.map(
+            lambda fun, params: self.get_gradients(fun, params),
+            all_funs,
+            all_params,
+            is_leaf=lambda x: x is None,
        )
-        return total_loss, (
-            {
-                "dyn_loss": mse_dyn_loss,
-                "norm_loss": mse_norm_loss,
-                "boundary_loss": mse_boundary_loss,
-                "observations": mse_observation_loss,
-                "initial_condition": jnp.array(0.0),  # for compatibility in the
-                # tree_map of SystemLoss
-            }
+        mses = jax.tree.map(
+            lambda leaf: leaf[0], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
        )
+        grads = jax.tree.map(
+            lambda leaf: leaf[1], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
+        )
+
+        return mses, grads
+
+    def evaluate(
+        self, params: Params[Array], batch: PDEStatioBatch
+    ) -> tuple[Float[Array, " "], PDEStatioComponents[Float[Array, " "] | None]]:
+        """
+        Evaluate the loss function at a batch of points for given parameters.
+
+        We retrieve the total value itself and a PyTree with loss values for each term
+
+        Parameters
+        ---------
+        params
+            Parameters at which the loss is evaluated
+        batch
+            Composed of a batch of points in the
+            domain, a batch of points in the domain
+            border and an optional additional batch of parameters (eg. for
+            metamodeling) and an optional additional batch of observed
+            inputs/outputs/parameters
+        """
+        loss_terms, _ = self.evaluate_by_terms(params, batch)
+
+        loss_val = self.ponderate_and_sum_loss(loss_terms)
+
+        return loss_val, loss_terms
 
 
 class LossPDENonStatio(LossPDEStatio):
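The new `evaluate_by_terms` builds one closure per active loss term, differentiates each one, then splits the `(value, grad)` pairs into two PyTrees. Below is a self-contained reduction of that pattern, where a plain dict stands in for `PDEStatioComponents` and `jax.value_and_grad` stands in for `self.get_gradients` (both assumptions of this sketch):

```python
import jax
import jax.numpy as jnp

params = jnp.arange(3.0)
# One closure per active loss term, None for inactive terms.
funs = {"dyn_loss": lambda p: jnp.sum(p**2), "norm_loss": None}

# Differentiate each term; None leaves pass through untouched.
pairs = jax.tree.map(
    lambda f: jax.value_and_grad(f)(params) if f is not None else None,
    funs,
    is_leaf=lambda x: x is None,
)
# Split the (value, grad) pairs into two PyTrees, as evaluate_by_terms does.
mses = jax.tree.map(lambda leaf: leaf[0], pairs, is_leaf=lambda x: isinstance(x, tuple))
grads = jax.tree.map(lambda leaf: leaf[1], pairs, is_leaf=lambda x: isinstance(x, tuple))
```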
@@ -527,9 +594,9 @@ class LossPDENonStatio(LossPDEStatio):
 
    Parameters
    ----------
-    u : eqx.Module
+    u : AbstractPINN
        the PINN
-    dynamic_loss : DynamicLoss
+    dynamic_loss : PDENonStatio
        the non stationary PDE dynamic part of the loss, basically the differential
        operator $\mathcal{N}[u](t, x)$. Should implement a method
        `dynamic_loss.evaluate(t, x, u, params)`.
@@ -546,18 +613,21 @@
        The loss weights for the differents term : dynamic loss,
        boundary conditions if any, initial condition, normalization loss if any and
        observations if any.
-        All fields are set to 1.0 by default.
+        Can be updated according to a specific algorithm. See
+        `update_weight_method`
+    update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
+        Default is None meaning no update for loss weights. Otherwise a string
    derivative_keys : DerivativeKeysPDENonStatio, default=None
        Specify which field of `params` should be differentiated for each
        composant of the total loss. Particularily useful for inverse problems.
        Fields can be "nn_params", "eq_params" or "both". Those that should not
        be updated will have a `jax.lax.stop_gradient` called on them. Default
        is `"nn_params"` for each composant of the loss.
-    omega_boundary_fun : Callable | Dict[str, Callable], default=None
+    omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
        The function to be matched in the border condition (can be None) or a
        dictionary of such functions as values and keys as described
        in `omega_boundary_condition`.
-    omega_boundary_condition : str | Dict[str, str], default=None
+    omega_boundary_condition : str | dict[str, str], default=None
        Either None (no condition, by default), or a string defining
        the boundary condition (Dirichlet or Von Neumann),
        or a dictionary with such strings as values. In this case,
@@ -568,17 +638,17 @@
        a particular boundary condition on this facet.
        The facet called “xmin”, resp. “xmax” etc., in 2D,
        refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
-    omega_boundary_dim : slice | Dict[str, slice], default=None
+    omega_boundary_dim : slice | dict[str, slice], default=None
        Either None, or a slice object or a dictionary of slice objects as
        values and keys as described in `omega_boundary_condition`.
        `omega_boundary_dim` indicates which dimension(s) of the PINN
        will be forced to match the boundary condition.
        Note that it must be a slice and not an integer
        (but a preprocessing of the user provided argument takes care of it)
-    norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
+    norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
        Monte-Carlo sample points for computing the
        normalization constant. Default is None.
-    norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
+    norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
        The importance sampling weights for Monte-Carlo integration of the
        normalization constant. Must be provided if `norm_samples` is provided.
        `norm_weights` should have the same leading dimension as
@@ -589,20 +659,25 @@
    obs_slice : slice, default=None
        slice object specifying the begininning/ending of the PINN output
        that is observed (this is then useful for multidim PINN). Default is None.
+    t0 : float | Float[Array, " 1"], default=None
+        The time at which to apply the initial condition. If None, the time
+        is set to `0` by default.
    initial_condition_fun : Callable, default=None
-        A function representing the temporal initial condition. If None
-        (default) then no initial condition is applied
-    params : InitVar[Params], default=None
-        The main Params object of the problem needed to instanciate the
-        DerivativeKeysODE if the latter is not specified.
+        A function representing the initial condition at `t0`. If None
+        (default) then no initial condition is applied.
+    params : InitVar[Params[Array]], default=None
+        The main `Params` object of the problem needed to instanciate the
+        `DerivativeKeysODE` if the latter is not specified.
 
    """
 
+    dynamic_loss: PDENonStatio | None
    # NOTE static=True only for leaf attributes that are not valid JAX types
    # (ie. jax.Array cannot be static) and that we do not expect to change
    initial_condition_fun: Callable | None = eqx.field(
        kw_only=True, default=None, static=True
    )
+    t0: float | Float[Array, " 1"] | None = eqx.field(kw_only=True, default=None)
 
    _max_norm_samples_omega: Int = eqx.field(init=False, static=True)
    _max_norm_time_slices: Int = eqx.field(init=False, static=True)
@@ -624,6 +699,23 @@
                "Initial condition wasn't provided. Be sure to cover for that"
                "case (e.g by. hardcoding it into the PINN output)."
            )
+        # some checks for t0
+        if isinstance(self.t0, Array):
+            if not self.t0.shape:  # e.g. user input: jnp.array(0.)
+                self.t0 = jnp.array([self.t0])
+            elif self.t0.shape != (1,):
+                raise ValueError(
+                    f"Wrong self.t0 input. It should be"
+                    f"a float or an array of shape (1,). Got shape: {self.t0.shape}"
+                )
+        elif isinstance(self.t0, float):  # e.g. user input: 0.
+            self.t0 = jnp.array([self.t0])
+        elif isinstance(self.t0, int):  # e.g. user input: 0
+            self.t0 = jnp.array([float(self.t0)])
+        elif self.t0 is None:
+            self.t0 = jnp.array([0])
+        else:
+            raise ValueError("Wrong value for t0")
 
        # witht the variables below we avoid memory overflow since a cartesian
        # product is taken
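The `t0` checks above coerce every accepted input (int, float, 0-d array, shape-`(1,)` array, or None) to a shape-`(1,)` array. A standalone sketch of that normalization, written as a hypothetical helper rather than a jinns function:

```python
import jax.numpy as jnp
from jax import Array

def coerce_t0(t0):
    # Mirrors the __post_init__ checks above: always end up with shape (1,).
    if isinstance(t0, Array):
        if not t0.shape:              # e.g. user input: jnp.array(0.)
            return jnp.array([t0])
        if t0.shape != (1,):
            raise ValueError("t0 must be a float or an array of shape (1,)")
        return t0
    if isinstance(t0, (int, float)):  # e.g. user input: 0 or 0.
        return jnp.array([float(t0)])
    if t0 is None:                    # default: initial condition at t = 0
        return jnp.array([0.0])
    raise ValueError("Wrong value for t0")

assert coerce_t0(None).shape == coerce_t0(jnp.array(0.0)).shape == (1,)
```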
@@ -632,44 +724,50 @@
 
    def _get_dynamic_loss_batch(
        self, batch: PDENonStatioBatch
-    ) -> Float[Array, "batch_size 1+dimension"]:
+    ) -> Float[Array, " batch_size 1+dimension"]:
        return batch.domain_batch
 
    def _get_normalization_loss_batch(
        self, batch: PDENonStatioBatch
-    ) -> Float[Array, "nb_norm_time_slices nb_norm_samples dimension"]:
+    ) -> tuple[
+        Float[Array, " nb_norm_time_slices 1"], Float[Array, " nb_norm_samples dim"]
+    ]:
        return (
            batch.domain_batch[: self._max_norm_time_slices, 0:1],
-            self.norm_samples[: self._max_norm_samples_omega],
+            self.norm_samples[: self._max_norm_samples_omega],  # type: ignore -> cannot narrow a class attr
        )
 
    def _get_observations_loss_batch(
        self, batch: PDENonStatioBatch
-    ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
-        return (batch.obs_batch_dict["pinn_in"],)
+    ) -> Float[Array, " batch_size 1+dim"]:
+        return batch.obs_batch_dict["pinn_in"]
 
    def __call__(self, *args, **kwargs):
        return self.evaluate(*args, **kwargs)
 
-    def evaluate(
-        self, params: Params, batch: PDENonStatioBatch
-    ) -> tuple[Float[Array, "1"], dict[str, float]]:
+    def evaluate_by_terms(
+        self, params: Params[Array], batch: PDENonStatioBatch
+    ) -> tuple[
+        PDENonStatioComponents[Array | None], PDENonStatioComponents[Array | None]
+    ]:
        """
        Evaluate the loss function at a batch of points for given parameters.
 
+        We retrieve two PyTrees with loss values and gradients for each term
 
        Parameters
        ---------
        params
            Parameters at which the loss is evaluated
        batch
-            Composed of a batch of points in
-            the domain, a batch of points in the domain
-            border, a batch of time points and an optional additional batch
-            of parameters (eg. for metamodeling) and an optional additional batch of observed
+            Composed of a batch of points in the
+            domain, a batch of points in the domain
+            border and an optional additional batch of parameters (eg. for
+            metamodeling) and an optional additional batch of observed
            inputs/outputs/parameters
        """
        omega_batch = batch.initial_batch
+        assert omega_batch is not None
 
        # Retrieve the optional eq_params_batch
        # and update eq_params with the latter
@@ -682,447 +780,62 @@
 
        # For mse_dyn_loss, mse_norm_loss, mse_boundary_loss,
        # mse_observation_loss we use the evaluate from parent class
-        partial_mse, partial_mse_terms = super().evaluate(params, batch)
+        # As well as for their gradients
+        partial_mses, partial_grads = super().evaluate_by_terms(params, batch)  # type: ignore
+        # ignore because batch is not PDEStatioBatch. We could use typing.cast though
 
        # initial condition
        if self.initial_condition_fun is not None:
-            mse_initial_condition = initial_condition_apply(
+            mse_initial_condition_fun = lambda p: initial_condition_apply(
                self.u,
                omega_batch,
-                _set_derivatives(params, self.derivative_keys.initial_condition),
+                _set_derivatives(p, self.derivative_keys.initial_condition),  # type: ignore
                (0,) + vmap_in_axes_params,
-                self.initial_condition_fun,
-                self.loss_weights.initial_condition,
+                self.initial_condition_fun,  # type: ignore
+                self.t0,  # type: ignore can't get the narrowing in __post_init__
            )
-        else:
-            mse_initial_condition = jnp.array(0.0)
-
-        # total loss
-        total_loss = partial_mse + mse_initial_condition
-
-        return total_loss, {
-            **partial_mse_terms,
-            "initial_condition": mse_initial_condition,
-        }
-
-
-class SystemLossPDE(eqx.Module):
-    r"""
-    Class to implement a system of PDEs.
-    The goal is to give maximum freedom to the user. The class is created with
-    a dict of dynamic loss, and dictionaries of all the objects that are used
-    in LossPDENonStatio and LossPDEStatio. When then iterate
-    over the dynamic losses that compose the system. All the PINNs with all the
-    parameter dictionaries are passed as arguments to each dynamic loss
-    evaluate functions; it is inside the dynamic loss that specification are
-    performed.
-
-    **Note:** All the dictionaries (except `dynamic_loss_dict`) must have the same keys.
-    Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
-    solution.
-
-    Parameters
-    ----------
-    u_dict : Dict[str, eqx.Module]
-        dict of PINNs
-    loss_weights : LossWeightsPDEDict
-        A dictionary of LossWeightsODE
-    derivative_keys_dict : Dict[str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio], default=None
-        A dictionnary of DerivativeKeysPDEStatio or DerivativeKeysPDENonStatio
-        specifying what field of `params`
-        should be used during gradient computations for each of the terms of
-        the total loss, for each of the loss in the system. Default is
-        `"nn_params`" everywhere.
-    dynamic_loss_dict : Dict[str, PDEStatio | PDENonStatio]
-        A dict of dynamic part of the loss, basically the differential
-        operator $\mathcal{N}[u](t, x)$ or $\mathcal{N}[u](x)$.
-    key_dict : Dict[str, Key], default=None
-        A dictionary of JAX PRNG keys. The dictionary keys of key_dict must
-        match that of u_dict. See LossPDEStatio or LossPDENonStatio for
-        more details.
-    omega_boundary_fun_dict : Dict[str, Callable | Dict[str, Callable] | None], default=None
-        A dict of of function or of dict of functions or of None
-        (see doc for `omega_boundary_fun` in
-        LossPDEStatio or LossPDENonStatio). Default is None.
-        Must share the keys of `u_dict`.
-    omega_boundary_condition_dict : Dict[str, str | Dict[str, str] | None], default=None
-        A dict of strings or of dict of strings or of None
-        (see doc for `omega_boundary_condition_dict` in
-        LossPDEStatio or LossPDENonStatio). Default is None.
-        Must share the keys of `u_dict`
-    omega_boundary_dim_dict : Dict[str, slice | Dict[str, slice] | None], default=None
-        A dict of slices or of dict of slices or of None
-        (see doc for `omega_boundary_dim` in
-        LossPDEStatio or LossPDENonStatio). Default is None.
-        Must share the keys of `u_dict`
-    initial_condition_fun_dict : Dict[str, Callable | None], default=None
-        A dict of functions representing the temporal initial condition (None
-        value is possible). If None
-        (default) then no temporal boundary condition is applied
-        Must share the keys of `u_dict`
-    norm_samples_dict : Dict[str, Float[Array, "nb_norm_samples dimension"] | None, default=None
-        A dict of Monte-Carlo sample points for computing the
-        normalization constant. Default is None.
-        Must share the keys of `u_dict`
-    norm_weights_dict : Dict[str, Array[Float, "nb_norm_samples"] | float | int | None] | None, default=None
-        A dict of jnp.array with the same keys as `u_dict`. The importance
-        sampling weights for Monte-Carlo integration of the
-        normalization constant for each element of u_dict. Must be provided if
-        `norm_samples_dict` is provided.
-        `norm_weights_dict[key]` should have the same leading dimension as
-        `norm_samples_dict[key]` for each `key`.
-        Alternatively, the user can pass a float or an integer.
-        For each key, an array of similar shape to `norm_samples_dict[key]`
-        or shape `(1,)` is expected. These corresponds to the weights $w_k =
-        \frac{1}{q(x_k)}$ where $q(\cdot)$ is the proposal p.d.f. and $x_k$ are
-        the Monte-Carlo samples. Default is None
-    obs_slice_dict : Dict[str, slice | None] | None, default=None
-        dict of obs_slice, with keys from `u_dict` to designate the
-        output(s) channels that are forced to observed values, for each
-        PINNs. Default is None. But if a value is given, all the entries of
-        `u_dict` must be represented here with default value `jnp.s_[...]`
-        if no particular slice is to be given
-    params : InitVar[ParamsDict], default=None
-        The main Params object of the problem needed to instanciate the
-        DerivativeKeysODE if the latter is not specified.
-
-    """
-
-    # NOTE static=True only for leaf attributes that are not valid JAX types
-    # (ie. jax.Array cannot be static) and that we do not expect to change
-    u_dict: Dict[str, eqx.Module]
-    dynamic_loss_dict: Dict[str, PDEStatio | PDENonStatio]
-    key_dict: Dict[str, Key] | None = eqx.field(kw_only=True, default=None)
-    derivative_keys_dict: Dict[
-        str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None
-    ] = eqx.field(kw_only=True, default=None)
-    omega_boundary_fun_dict: Dict[str, Callable | Dict[str, Callable] | None] | None = (
-        eqx.field(kw_only=True, default=None, static=True)
-    )
-    omega_boundary_condition_dict: Dict[str, str | Dict[str, str] | None] | None = (
-        eqx.field(kw_only=True, default=None, static=True)
-    )
-    omega_boundary_dim_dict: Dict[str, slice | Dict[str, slice] | None] | None = (
-        eqx.field(kw_only=True, default=None, static=True)
-    )
-    initial_condition_fun_dict: Dict[str, Callable | None] | None = eqx.field(
-        kw_only=True, default=None, static=True
-    )
-    norm_samples_dict: Dict[str, Float[Array, "nb_norm_samples dimension"]] | None = (
-        eqx.field(kw_only=True, default=None)
-    )
-    norm_weights_dict: (
-        Dict[str, Float[Array, "nb_norm_samples dimension"] | float | int | None] | None
-    ) = eqx.field(kw_only=True, default=None)
-    obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
-        kw_only=True, default=None, static=True
-    )
-
-    # For the user loss_weights are passed as a LossWeightsPDEDict (with internal
-    # dictionary having keys in u_dict and / or dynamic_loss_dict)
-    loss_weights: InitVar[LossWeightsPDEDict | None] = eqx.field(
-        kw_only=True, default=None
-    )
-    params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
-
-    # following have init=False and are set in the __post_init__
-    u_constraints_dict: Dict[str, LossPDEStatio | LossPDENonStatio] = eqx.field(
-        init=False
-    )
-    derivative_keys_dyn_loss: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio = (
-        eqx.field(init=False)
-    )
-    u_dict_with_none: Dict[str, None] = eqx.field(init=False)
-    # internally the loss weights are handled with a dictionary
-    _loss_weights: Dict[str, dict] = eqx.field(init=False)
-
-    def __post_init__(self, loss_weights=None, params_dict=None):
-        # a dictionary that will be useful at different places
-        self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
-        # First, for all the optional dict,
-        # if the user did not provide at all this optional argument,
-        # we make sure there is a null ponderating loss_weight and we
-        # create a dummy dict with the required keys and all the values to
-        # None
-        if self.key_dict is None:
-            self.key_dict = self.u_dict_with_none
-        if self.omega_boundary_fun_dict is None:
-            self.omega_boundary_fun_dict = self.u_dict_with_none
-        if self.omega_boundary_condition_dict is None:
-            self.omega_boundary_condition_dict = self.u_dict_with_none
-        if self.omega_boundary_dim_dict is None:
-            self.omega_boundary_dim_dict = self.u_dict_with_none
-        if self.initial_condition_fun_dict is None:
-            self.initial_condition_fun_dict = self.u_dict_with_none
-        if self.norm_samples_dict is None:
-            self.norm_samples_dict = self.u_dict_with_none
-        if self.norm_weights_dict is None:
-            self.norm_weights_dict = self.u_dict_with_none
-        if self.obs_slice_dict is None:
-            self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
-        if self.u_dict.keys() != self.obs_slice_dict.keys():
-            raise ValueError("obs_slice_dict should have same keys as u_dict")
-        if self.derivative_keys_dict is None:
-            self.derivative_keys_dict = {
-                k: None
-                for k in set(
-                    list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
-                )
-            }
-            # set() because we can have duplicate entries and in this case we
-            # say it corresponds to the same derivative_keys_dict entry
-            # we need both because the constraints (all but dyn_loss) will be
-            # done by iterating on u_dict while the dyn_loss will be by
-            # iterating on dynamic_loss_dict. So each time we will require dome
-            # derivative_keys_dict
-
-        # derivative keys for the u_constraints. Note that we create missing
-        # DerivativeKeysODE around a Params object and not ParamsDict
-        # this works because u_dict.keys == params_dict.nn_params.keys()
-        for k in self.u_dict.keys():
-            if self.derivative_keys_dict[k] is None:
-                if self.u_dict[k].eq_type == "statio_PDE":
-                    self.derivative_keys_dict[k] = DerivativeKeysPDEStatio(
-                        params=params_dict.extract_params(k)
-                    )
-                else:
-                    self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio(
-                        params=params_dict.extract_params(k)
-                    )
-
-        # Second we make sure that all the dicts (except dynamic_loss_dict) have the same keys
-        if (
-            self.u_dict.keys() != self.key_dict.keys()
-            or self.u_dict.keys() != self.omega_boundary_fun_dict.keys()
-            or self.u_dict.keys() != self.omega_boundary_condition_dict.keys()
-            or self.u_dict.keys() != self.omega_boundary_dim_dict.keys()
-            or self.u_dict.keys() != self.initial_condition_fun_dict.keys()
-            or self.u_dict.keys() != self.norm_samples_dict.keys()
-            or self.u_dict.keys() != self.norm_weights_dict.keys()
-        ):
-            raise ValueError("All the dicts concerning the PINNs should have same keys")
-
-        self._loss_weights = self.set_loss_weights(loss_weights)
-
-        # Third, in order not to benefit from LossPDEStatio and
-        # LossPDENonStatio and in order to factorize code, we create internally
-        # some losses object to implement the constraints on the solutions.
-        # We will not use the dynamic loss term
-        self.u_constraints_dict = {}
-        for i in self.u_dict.keys():
-            if self.u_dict[i].eq_type == "statio_PDE":
-                self.u_constraints_dict[i] = LossPDEStatio(
-                    u=self.u_dict[i],
-                    loss_weights=LossWeightsPDENonStatio(
-                        dyn_loss=0.0,
-                        norm_loss=1.0,
-                        boundary_loss=1.0,
-                        observations=1.0,
-                        initial_condition=1.0,
-                    ),
-                    dynamic_loss=None,
-                    key=self.key_dict[i],
-                    derivative_keys=self.derivative_keys_dict[i],
-                    omega_boundary_fun=self.omega_boundary_fun_dict[i],
-                    omega_boundary_condition=self.omega_boundary_condition_dict[i],
-                    omega_boundary_dim=self.omega_boundary_dim_dict[i],
-                    norm_samples=self.norm_samples_dict[i],
-                    norm_weights=self.norm_weights_dict[i],
-                    obs_slice=self.obs_slice_dict[i],
-                )
-            elif self.u_dict[i].eq_type == "nonstatio_PDE":
-                self.u_constraints_dict[i] = LossPDENonStatio(
-                    u=self.u_dict[i],
-                    loss_weights=LossWeightsPDENonStatio(
-                        dyn_loss=0.0,
-                        norm_loss=1.0,
-                        boundary_loss=1.0,
-                        observations=1.0,
-                        initial_condition=1.0,
-                    ),
-                    dynamic_loss=None,
-                    key=self.key_dict[i],
-                    derivative_keys=self.derivative_keys_dict[i],
-                    omega_boundary_fun=self.omega_boundary_fun_dict[i],
-                    omega_boundary_condition=self.omega_boundary_condition_dict[i],
-                    omega_boundary_dim=self.omega_boundary_dim_dict[i],
-                    initial_condition_fun=self.initial_condition_fun_dict[i],
-                    norm_samples=self.norm_samples_dict[i],
-                    norm_weights=self.norm_weights_dict[i],
-                    obs_slice=self.obs_slice_dict[i],
-                )
-            else:
-                raise ValueError(
-                    "Wrong value for self.u_dict[i].eq_type[i], "
-                    f"got {self.u_dict[i].eq_type[i]}"
-                )
-
-        # derivative keys for the dynamic loss. Note that we create a
-        # DerivativeKeysODE around a ParamsDict object because a whole
-        # params_dict is feed to DynamicLoss.evaluate functions (extract_params
-        # happen inside it)
-        self.derivative_keys_dyn_loss = DerivativeKeysPDENonStatio(params=params_dict)
-
-        # also make sure we only have PINNs or SPINNs
-        if not (
-            all(isinstance(value, PINN) for value in self.u_dict.values())
-            or all(isinstance(value, SPINN) for value in self.u_dict.values())
-        ):
-            raise ValueError(
-                "We only accept dictionary of PINNs or dictionary of SPINNs"
+            mse_initial_condition, grad_initial_condition = self.get_gradients(
+                mse_initial_condition_fun, params
            )
+        else:
+            mse_initial_condition = None
+            grad_initial_condition = None
+
+        mses = PDENonStatioComponents(
+            partial_mses.dyn_loss,
+            partial_mses.norm_loss,
+            partial_mses.boundary_loss,
+            partial_mses.observations,
+            mse_initial_condition,
+        )
 
-    def set_loss_weights(
-        self, loss_weights_init: LossWeightsPDEDict
-    ) -> dict[str, dict]:
-        """
-        This rather complex function enables the user to specify a simple
-        loss_weights=LossWeightsPDEDict(dyn_loss=1., initial_condition=Tmax)
-        for ponderating values being applied to all the equations of the
-        system... So all the transformations are handled here
-        """
-        _loss_weights = {}
-        for k in fields(loss_weights_init):
-            v = getattr(loss_weights_init, k.name)
-            if isinstance(v, dict):
-                for vv in v.keys():
-                    if not isinstance(vv, (int, float)) and not (
-                        isinstance(vv, Array)
-                        and ((vv.shape == (1,) or len(vv.shape) == 0))
-                    ):
-                        # TODO improve that
-                        raise ValueError(
-                            f"loss values cannot be vectorial here, got {vv}"
-                        )
-                if k.name == "dyn_loss":
-                    if v.keys() == self.dynamic_loss_dict.keys():
-                        _loss_weights[k.name] = v
-                    else:
-                        raise ValueError(
-                            "Keys in nested dictionary of loss_weights"
-                            " do not match dynamic_loss_dict keys"
-                        )
-                else:
-                    if v.keys() == self.u_dict.keys():
-                        _loss_weights[k.name] = v
-                    else:
-                        raise ValueError(
-                            "Keys in nested dictionary of loss_weights"
-                            " do not match u_dict keys"
-                        )
-            if v is None:
-                _loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
-            else:
-                if not isinstance(v, (int, float)) and not (
-                    isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
-                ):
-                    # TODO improve that
-                    raise ValueError(f"loss values cannot be vectorial here, got {v}")
-                if k.name == "dyn_loss":
-                    _loss_weights[k.name] = {
-                        kk: v for kk in self.dynamic_loss_dict.keys()
-                    }
-                else:
-                    _loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
-        return _loss_weights
+        grads = PDENonStatioComponents(
+            partial_grads.dyn_loss,
+            partial_grads.norm_loss,
+            partial_grads.boundary_loss,
+            partial_grads.observations,
+            grad_initial_condition,
+        )
 
-    def __call__(self, *args, **kwargs):
-        return self.evaluate(*args, **kwargs)
+        return mses, grads
 
    def evaluate(
-        self,
-        params_dict: ParamsDict,
-        batch: PDEStatioBatch | PDENonStatioBatch,
-    ) -> tuple[Float[Array, "1"], dict[str, float]]:
+        self, params: Params[Array], batch: PDENonStatioBatch
+    ) -> tuple[Float[Array, " "], PDENonStatioComponents[Float[Array, " "] | None]]:
        """
        Evaluate the loss function at a batch of points for given parameters.
+        We retrieve the total value itself and a PyTree with loss values for each term
 
 
        Parameters
        ---------
-        params_dict
-            Parameters at which the losses of the system are evaluated
+        params
+            Parameters at which the loss is evaluated
        batch
-            Such named tuples are composed of batch of points in the
-            domain, a batch of points in the domain
-            border, (a batch of time points a for PDENonStatioBatch) and an
-            optional additional batch of parameters (eg. for metamodeling)
-            and an optional additional batch of observed
+            Composed of a batch of points in
+            the domain, a batch of points in the domain
+            border, a batch of time points and an optional additional batch
+            of parameters (eg. for metamodeling) and an optional additional batch of observed
            inputs/outputs/parameters
        """
-        if self.u_dict.keys() != params_dict.nn_params.keys():
-            raise ValueError("u_dict and params_dict[nn_params] should have same keys ")
-
-        vmap_in_axes = (0,)
-
-        # Retrieve the optional eq_params_batch
-        # and update eq_params with the latter
-        # and update vmap_in_axes
-        if batch.param_batch_dict is not None:
-            eq_params_batch_dict = batch.param_batch_dict
-
-            # feed the eq_params with the batch
-            for k in eq_params_batch_dict.keys():
-                params_dict.eq_params[k] = eq_params_batch_dict[k]
-
-        vmap_in_axes_params = _get_vmap_in_axes_params(
-            batch.param_batch_dict, params_dict
-        )
-
-        def dyn_loss_for_one_key(dyn_loss, loss_weight):
-            """The function used in tree_map"""
-            return dynamic_loss_apply(
-                dyn_loss.evaluate,
-                self.u_dict,
-                (
-                    batch.domain_batch
-                    if isinstance(batch, PDEStatioBatch)
-                    else batch.domain_batch
-                ),
-                _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
-                vmap_in_axes + vmap_in_axes_params,
-                loss_weight,
-                u_type=list(self.u_dict.values())[0].__class__.__base__,
-            )
-
-        dyn_loss_mse_dict = jax.tree_util.tree_map(
-            dyn_loss_for_one_key,
-            self.dynamic_loss_dict,
-            self._loss_weights["dyn_loss"],
-            is_leaf=lambda x: isinstance(
-                x, (PDEStatio, PDENonStatio)
-            ),  # before when dynamic losses
-            # where plain (unregister pytree) node classes, we could not traverse
-            # this level. Now that dynamic losses are eqx.Module they can be
-            # traversed by tree map recursion. Hence we need to specify to that
-            # we want to stop at this level
-        )
-        mse_dyn_loss = jax.tree_util.tree_reduce(
-            lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
-        )
-
-        # boundary conditions, normalization conditions, observation_loss,
-        # initial condition... loss this is done via the internal
-        # LossPDEStatio and NonStatio
-        loss_weight_struct = {
-            "dyn_loss": "*",
-            "norm_loss": "*",
-            "boundary_loss": "*",
-            "observations": "*",
-            "initial_condition": "*",
-        }
-        # we need to do the following for the tree_mapping to work
-        if batch.obs_batch_dict is None:
-            batch = append_obs_batch(batch, self.u_dict_with_none)
-        total_loss, res_dict = constraints_system_loss_apply(
-            self.u_constraints_dict,
-            batch,
-            params_dict,
-            self._loss_weights,
-            loss_weight_struct,
-        )
-
-        # Add the mse_dyn_loss from the previous computations
-        total_loss += mse_dyn_loss
-        res_dict["dyn_loss"] += mse_dyn_loss
-        return total_loss, res_dict
+        return super().evaluate(params, batch)  # type: ignore
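With the per-term weights removed from the `*_apply` calls, weighting now happens once in `evaluate` via `ponderate_and_sum_loss`. A minimal sketch of that final step, with dicts standing in for the `*Components` PyTrees and `None` marking absent terms (illustrative, not the jinns implementation):

```python
import jax.numpy as jnp

terms = {"dyn_loss": jnp.array(0.3), "norm_loss": None, "boundary_loss": jnp.array(0.1)}
weights = {"dyn_loss": 1.0, "norm_loss": 1.0, "boundary_loss": 10.0}

# Weighted sum over the terms that are actually present.
total = sum(weights[k] * v for k, v in terms.items() if v is not None)
assert jnp.isclose(total, 1.3)
```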