jinns 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +146 -520
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_utils.py +78 -159
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -74
  22. jinns/nn/__init__.py +15 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +94 -57
  25. jinns/nn/_mlp.py +50 -25
  26. jinns/nn/_pinn.py +33 -19
  27. jinns/nn/_ppinn.py +70 -34
  28. jinns/nn/_save_load.py +21 -51
  29. jinns/nn/_spinn.py +33 -16
  30. jinns/nn/_spinn_mlp.py +28 -22
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +35 -34
  37. jinns/solver/_rar.py +80 -63
  38. jinns/solver/_solve.py +89 -63
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -0
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns-1.3.0.dist-info/RECORD +0 -44
  51. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  52. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  53. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py CHANGED
@@ -1,16 +1,16 @@
1
- # pylint: disable=unsubscriptable-object, no-member
2
1
  """
3
2
  Main module to implement a PDE loss in jinns
4
3
  """
4
+
5
5
  from __future__ import (
6
6
  annotations,
7
7
  ) # https://docs.python.org/3/library/typing.html#constant
8
8
 
9
9
  import abc
10
- from dataclasses import InitVar, fields
11
- from typing import TYPE_CHECKING, Dict, Callable
10
+ from dataclasses import InitVar
11
+ from typing import TYPE_CHECKING, Callable, TypedDict
12
+ from types import EllipsisType
12
13
  import warnings
13
- import jax
14
14
  import jax.numpy as jnp
15
15
  import equinox as eqx
16
16
  from jaxtyping import Float, Array, Key, Int
@@ -20,9 +20,7 @@ from jinns.loss._loss_utils import (
20
20
  normalization_loss_apply,
21
21
  observations_loss_apply,
22
22
  initial_condition_apply,
23
- constraints_system_loss_apply,
24
23
  )
25
- from jinns.data._DataGenerators import append_obs_batch
26
24
  from jinns.parameters._params import (
27
25
  _get_vmap_in_axes_params,
28
26
  _update_eq_params_dict,
@@ -32,19 +30,30 @@ from jinns.parameters._derivative_keys import (
32
30
  DerivativeKeysPDEStatio,
33
31
  DerivativeKeysPDENonStatio,
34
32
  )
33
+ from jinns.loss._abstract_loss import AbstractLoss
35
34
  from jinns.loss._loss_weights import (
36
35
  LossWeightsPDEStatio,
37
36
  LossWeightsPDENonStatio,
38
- LossWeightsPDEDict,
39
37
  )
40
- from jinns.loss._DynamicLossAbstract import PDEStatio, PDENonStatio
41
- from jinns.nn._pinn import PINN
42
- from jinns.nn._spinn import SPINN
43
38
  from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
44
39
 
45
40
 
46
41
  if TYPE_CHECKING:
47
- from jinns.utils._types import *
42
+ # imports for type hints only
43
+ from jinns.parameters._params import Params
44
+ from jinns.nn._abstract_pinn import AbstractPINN
45
+ from jinns.loss import PDENonStatio, PDEStatio
46
+ from jinns.utils._types import BoundaryConditionFun
47
+
48
+ class LossDictPDEStatio(TypedDict):
49
+ dyn_loss: Float[Array, " "]
50
+ norm_loss: Float[Array, " "]
51
+ boundary_loss: Float[Array, " "]
52
+ observations: Float[Array, " "]
53
+
54
+ class LossDictPDENonStatio(LossDictPDEStatio):
55
+ initial_condition: Float[Array, " "]
56
+
48
57
 
49
58
  _IMPLEMENTED_BOUNDARY_CONDITIONS = [
50
59
  "dirichlet",
@@ -53,8 +62,8 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
53
62
  ]
54
63
 
55
64
 
56
- class _LossPDEAbstract(eqx.Module):
57
- """
65
+ class _LossPDEAbstract(AbstractLoss):
66
+ r"""
58
67
  Parameters
59
68
  ----------
60
69
 
@@ -69,11 +78,11 @@ class _LossPDEAbstract(eqx.Module):
69
78
  Fields can be "nn_params", "eq_params" or "both". Those that should not
70
79
  be updated will have a `jax.lax.stop_gradient` called on them. Default
71
80
  is `"nn_params"` for each composant of the loss.
72
- omega_boundary_fun : Callable | Dict[str, Callable], default=None
81
+ omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
73
82
  The function to be matched in the border condition (can be None) or a
74
83
  dictionary of such functions as values and keys as described
75
84
  in `omega_boundary_condition`.
76
- omega_boundary_condition : str | Dict[str, str], default=None
85
+ omega_boundary_condition : str | dict[str, str], default=None
77
86
  Either None (no condition, by default), or a string defining
78
87
  the boundary condition (Dirichlet or Von Neumann),
79
88
  or a dictionary with such strings as values. In this case,
@@ -84,28 +93,29 @@ class _LossPDEAbstract(eqx.Module):
84
93
  a particular boundary condition on this facet.
85
94
  The facet called “xmin”, resp. “xmax” etc., in 2D,
86
95
  refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
87
- omega_boundary_dim : slice | Dict[str, slice], default=None
96
+ omega_boundary_dim : slice | dict[str, slice], default=None
88
97
  Either None, or a slice object or a dictionary of slice objects as
89
98
  values and keys as described in `omega_boundary_condition`.
90
99
  `omega_boundary_dim` indicates which dimension(s) of the PINN
91
100
  will be forced to match the boundary condition.
92
101
  Note that it must be a slice and not an integer
93
102
  (but a preprocessing of the user provided argument takes care of it)
94
- norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
103
+ norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
95
104
  Monte-Carlo sample points for computing the
96
105
  normalization constant. Default is None.
97
- norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
106
+ norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
98
107
  The importance sampling weights for Monte-Carlo integration of the
99
108
  normalization constant. Must be provided if `norm_samples` is provided.
100
- `norm_weights` should have the same leading dimension as
109
+ `norm_weights` should be broadcastble to
101
110
  `norm_samples`.
102
- Alternatively, the user can pass a float or an integer.
111
+ Alternatively, the user can pass a float or an integer that will be
112
+ made broadcastable to `norm_samples`.
103
113
  These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
104
114
  $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
105
- obs_slice : slice, default=None
115
+ obs_slice : EllipsisType | slice, default=None
106
116
  slice object specifying the begininning/ending of the PINN output
107
117
  that is observed (this is then useful for multidim PINN). Default is None.
108
- params : InitVar[Params], default=None
118
+ params : InitVar[Params[Array]], default=None
109
119
  The main Params object of the problem needed to instanciate the
110
120
  DerivativeKeysODE if the latter is not specified.
111
121
  """
@@ -119,26 +129,28 @@ class _LossPDEAbstract(eqx.Module):
119
129
  loss_weights: LossWeightsPDEStatio | LossWeightsPDENonStatio | None = eqx.field(
120
130
  kw_only=True, default=None
121
131
  )
122
- omega_boundary_fun: Callable | Dict[str, Callable] | None = eqx.field(
123
- kw_only=True, default=None, static=True
124
- )
125
- omega_boundary_condition: str | Dict[str, str] | None = eqx.field(
132
+ omega_boundary_fun: (
133
+ BoundaryConditionFun | dict[str, BoundaryConditionFun] | None
134
+ ) = eqx.field(kw_only=True, default=None, static=True)
135
+ omega_boundary_condition: str | dict[str, str] | None = eqx.field(
126
136
  kw_only=True, default=None, static=True
127
137
  )
128
- omega_boundary_dim: slice | Dict[str, slice] | None = eqx.field(
138
+ omega_boundary_dim: slice | dict[str, slice] | None = eqx.field(
129
139
  kw_only=True, default=None, static=True
130
140
  )
131
- norm_samples: Float[Array, "nb_norm_samples dimension"] | None = eqx.field(
141
+ norm_samples: Float[Array, " nb_norm_samples dimension"] | None = eqx.field(
132
142
  kw_only=True, default=None
133
143
  )
134
- norm_weights: Float[Array, "nb_norm_samples"] | float | int | None = eqx.field(
144
+ norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = eqx.field(
135
145
  kw_only=True, default=None
136
146
  )
137
- obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
147
+ obs_slice: EllipsisType | slice | None = eqx.field(
148
+ kw_only=True, default=None, static=True
149
+ )
138
150
 
139
- params: InitVar[Params] = eqx.field(kw_only=True, default=None)
151
+ params: InitVar[Params[Array]] = eqx.field(kw_only=True, default=None)
140
152
 
141
- def __post_init__(self, params=None):
153
+ def __post_init__(self, params: Params[Array] | None = None):
142
154
  """
143
155
  Note that neither __init__ or __post_init__ are called when udating a
144
156
  Module with eqx.tree_at
@@ -228,6 +240,11 @@ class _LossPDEAbstract(eqx.Module):
228
240
  )
229
241
 
230
242
  if isinstance(self.omega_boundary_fun, dict):
243
+ if not isinstance(self.omega_boundary_dim, dict):
244
+ raise ValueError(
245
+ "If omega_boundary_fun is a dict then"
246
+ " omega_boundary_dim should also be a dict"
247
+ )
231
248
  if self.omega_boundary_dim is None:
232
249
  self.omega_boundary_dim = {
233
250
  k: jnp.s_[::] for k in self.omega_boundary_fun.keys()
@@ -262,27 +279,29 @@ class _LossPDEAbstract(eqx.Module):
262
279
  raise ValueError(
263
280
  "`norm_weights` must be provided when `norm_samples` is used!"
264
281
  )
265
- try:
266
- assert self.norm_weights.shape[0] == self.norm_samples.shape[0]
267
- except (AssertionError, AttributeError):
268
- if isinstance(self.norm_weights, (int, float)):
269
- self.norm_weights = jnp.array(
270
- [self.norm_weights], dtype=jax.dtypes.canonicalize_dtype(float)
271
- )
272
- else:
282
+ if isinstance(self.norm_weights, (int, float)):
283
+ self.norm_weights = self.norm_weights * jnp.ones(
284
+ (self.norm_samples.shape[0],)
285
+ )
286
+ if isinstance(self.norm_weights, Array):
287
+ if not (self.norm_weights.shape[0] == self.norm_samples.shape[0]):
273
288
  raise ValueError(
274
- "`norm_weights` should have the same leading dimension"
275
- " as `norm_samples`,"
276
- f" got shape {self.norm_weights.shape} and"
277
- f" shape {self.norm_samples.shape}."
289
+ "self.norm_weights and "
290
+ "self.norm_samples must have the same leading dimension"
278
291
  )
292
+ else:
293
+ raise ValueError("Wrong type for self.norm_weights")
294
+
295
+ @abc.abstractmethod
296
+ def __call__(self, *_, **__):
297
+ pass
279
298
 
280
299
  @abc.abstractmethod
281
300
  def evaluate(
282
301
  self: eqx.Module,
283
- params: Params,
302
+ params: Params[Array],
284
303
  batch: PDEStatioBatch | PDENonStatioBatch,
285
- ) -> tuple[Float, dict]:
304
+ ) -> tuple[Float[Array, " "], LossDictPDEStatio | LossDictPDENonStatio]:
286
305
  raise NotImplementedError
287
306
 
288
307
 
@@ -299,9 +318,9 @@ class LossPDEStatio(_LossPDEAbstract):
299
318
 
300
319
  Parameters
301
320
  ----------
302
- u : eqx.Module
321
+ u : AbstractPINN
303
322
  the PINN
304
- dynamic_loss : DynamicLoss
323
+ dynamic_loss : PDEStatio
305
324
  the stationary PDE dynamic part of the loss, basically the differential
306
325
  operator $\mathcal{N}[u](x)$. Should implement a method
307
326
  `dynamic_loss.evaluate(x, u, params)`.
@@ -324,11 +343,11 @@ class LossPDEStatio(_LossPDEAbstract):
324
343
  Fields can be "nn_params", "eq_params" or "both". Those that should not
325
344
  be updated will have a `jax.lax.stop_gradient` called on them. Default
326
345
  is `"nn_params"` for each composant of the loss.
327
- omega_boundary_fun : Callable | Dict[str, Callable], default=None
346
+ omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
328
347
  The function to be matched in the border condition (can be None) or a
329
348
  dictionary of such functions as values and keys as described
330
349
  in `omega_boundary_condition`.
331
- omega_boundary_condition : str | Dict[str, str], default=None
350
+ omega_boundary_condition : str | dict[str, str], default=None
332
351
  Either None (no condition, by default), or a string defining
333
352
  the boundary condition (Dirichlet or Von Neumann),
334
353
  or a dictionary with such strings as values. In this case,
@@ -339,17 +358,17 @@ class LossPDEStatio(_LossPDEAbstract):
339
358
  a particular boundary condition on this facet.
340
359
  The facet called “xmin”, resp. “xmax” etc., in 2D,
341
360
  refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
342
- omega_boundary_dim : slice | Dict[str, slice], default=None
361
+ omega_boundary_dim : slice | dict[str, slice], default=None
343
362
  Either None, or a slice object or a dictionary of slice objects as
344
363
  values and keys as described in `omega_boundary_condition`.
345
364
  `omega_boundary_dim` indicates which dimension(s) of the PINN
346
365
  will be forced to match the boundary condition.
347
366
  Note that it must be a slice and not an integer
348
367
  (but a preprocessing of the user provided argument takes care of it)
349
- norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
368
+ norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
350
369
  Monte-Carlo sample points for computing the
351
370
  normalization constant. Default is None.
352
- norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
371
+ norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
353
372
  The importance sampling weights for Monte-Carlo integration of the
354
373
  normalization constant. Must be provided if `norm_samples` is provided.
355
374
  `norm_weights` should have the same leading dimension as
@@ -360,7 +379,7 @@ class LossPDEStatio(_LossPDEAbstract):
360
379
  obs_slice : slice, default=None
361
380
  slice object specifying the begininning/ending of the PINN output
362
381
  that is observed (this is then useful for multidim PINN). Default is None.
363
- params : InitVar[Params], default=None
382
+ params : InitVar[Params[Array]], default=None
364
383
  The main Params object of the problem needed to instanciate the
365
384
  DerivativeKeysODE if the latter is not specified.
366
385
 
@@ -375,13 +394,13 @@ class LossPDEStatio(_LossPDEAbstract):
375
394
  # NOTE static=True only for leaf attributes that are not valid JAX types
376
395
  # (ie. jax.Array cannot be static) and that we do not expect to change
377
396
 
378
- u: eqx.Module
379
- dynamic_loss: DynamicLoss | None
397
+ u: AbstractPINN
398
+ dynamic_loss: PDEStatio | None
380
399
  key: Key | None = eqx.field(kw_only=True, default=None)
381
400
 
382
401
  vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
383
402
 
384
- def __post_init__(self, params=None):
403
+ def __post_init__(self, params: Params[Array] | None = None):
385
404
  """
386
405
  Note that neither __init__ or __post_init__ are called when udating a
387
406
  Module with eqx.tree_at!
@@ -395,25 +414,27 @@ class LossPDEStatio(_LossPDEAbstract):
395
414
 
396
415
  def _get_dynamic_loss_batch(
397
416
  self, batch: PDEStatioBatch
398
- ) -> Float[Array, "batch_size dimension"]:
417
+ ) -> Float[Array, " batch_size dimension"]:
399
418
  return batch.domain_batch
400
419
 
401
420
  def _get_normalization_loss_batch(
402
421
  self, _
403
- ) -> Float[Array, "nb_norm_samples dimension"]:
404
- return (self.norm_samples,)
422
+ ) -> tuple[Float[Array, " nb_norm_samples dimension"]]:
423
+ return (self.norm_samples,) # type: ignore -> cannot narrow a class attr
424
+
425
+ # we could have used typing.cast though
405
426
 
406
427
  def _get_observations_loss_batch(
407
428
  self, batch: PDEStatioBatch
408
- ) -> Float[Array, "batch_size obs_dim"]:
409
- return (batch.obs_batch_dict["pinn_in"],)
429
+ ) -> Float[Array, " batch_size obs_dim"]:
430
+ return batch.obs_batch_dict["pinn_in"]
410
431
 
411
432
  def __call__(self, *args, **kwargs):
412
433
  return self.evaluate(*args, **kwargs)
413
434
 
414
435
  def evaluate(
415
- self, params: Params, batch: PDEStatioBatch
416
- ) -> tuple[Float[Array, "1"], dict[str, float]]:
436
+ self, params: Params[Array], batch: PDEStatioBatch
437
+ ) -> tuple[Float[Array, " "], LossDictPDEStatio]:
417
438
  """
418
439
  Evaluate the loss function at a batch of points for given parameters.
419
440
 
@@ -444,9 +465,9 @@ class LossPDEStatio(_LossPDEAbstract):
444
465
  self.dynamic_loss.evaluate,
445
466
  self.u,
446
467
  self._get_dynamic_loss_batch(batch),
447
- _set_derivatives(params, self.derivative_keys.dyn_loss),
468
+ _set_derivatives(params, self.derivative_keys.dyn_loss), # type: ignore
448
469
  self.vmap_in_axes + vmap_in_axes_params,
449
- self.loss_weights.dyn_loss,
470
+ self.loss_weights.dyn_loss, # type: ignore
450
471
  )
451
472
  else:
452
473
  mse_dyn_loss = jnp.array(0.0)
@@ -456,24 +477,28 @@ class LossPDEStatio(_LossPDEAbstract):
456
477
  mse_norm_loss = normalization_loss_apply(
457
478
  self.u,
458
479
  self._get_normalization_loss_batch(batch),
459
- _set_derivatives(params, self.derivative_keys.norm_loss),
480
+ _set_derivatives(params, self.derivative_keys.norm_loss), # type: ignore
460
481
  vmap_in_axes_params,
461
- self.norm_weights,
462
- self.loss_weights.norm_loss,
482
+ self.norm_weights, # type: ignore -> can't get the __post_init__ narrowing here
483
+ self.loss_weights.norm_loss, # type: ignore
463
484
  )
464
485
  else:
465
486
  mse_norm_loss = jnp.array(0.0)
466
487
 
467
488
  # boundary part
468
- if self.omega_boundary_condition is not None:
489
+ if (
490
+ self.omega_boundary_condition is not None
491
+ and self.omega_boundary_dim is not None
492
+ and self.omega_boundary_fun is not None
493
+ ): # pyright cannot narrow down the three None otherwise as it is class attribute
469
494
  mse_boundary_loss = boundary_condition_apply(
470
495
  self.u,
471
496
  batch,
472
- _set_derivatives(params, self.derivative_keys.boundary_loss),
497
+ _set_derivatives(params, self.derivative_keys.boundary_loss), # type: ignore
473
498
  self.omega_boundary_fun,
474
499
  self.omega_boundary_condition,
475
500
  self.omega_boundary_dim,
476
- self.loss_weights.boundary_loss,
501
+ self.loss_weights.boundary_loss, # type: ignore
477
502
  )
478
503
  else:
479
504
  mse_boundary_loss = jnp.array(0.0)
@@ -486,10 +511,10 @@ class LossPDEStatio(_LossPDEAbstract):
486
511
  mse_observation_loss = observations_loss_apply(
487
512
  self.u,
488
513
  self._get_observations_loss_batch(batch),
489
- _set_derivatives(params, self.derivative_keys.observations),
514
+ _set_derivatives(params, self.derivative_keys.observations), # type: ignore
490
515
  self.vmap_in_axes + vmap_in_axes_params,
491
516
  batch.obs_batch_dict["val"],
492
- self.loss_weights.observations,
517
+ self.loss_weights.observations, # type: ignore
493
518
  self.obs_slice,
494
519
  )
495
520
  else:
@@ -505,8 +530,6 @@ class LossPDEStatio(_LossPDEAbstract):
505
530
  "norm_loss": mse_norm_loss,
506
531
  "boundary_loss": mse_boundary_loss,
507
532
  "observations": mse_observation_loss,
508
- "initial_condition": jnp.array(0.0), # for compatibility in the
509
- # tree_map of SystemLoss
510
533
  }
511
534
  )
512
535
 
@@ -527,9 +550,9 @@ class LossPDENonStatio(LossPDEStatio):
527
550
 
528
551
  Parameters
529
552
  ----------
530
- u : eqx.Module
553
+ u : AbstractPINN
531
554
  the PINN
532
- dynamic_loss : DynamicLoss
555
+ dynamic_loss : PDENonStatio
533
556
  the non stationary PDE dynamic part of the loss, basically the differential
534
557
  operator $\mathcal{N}[u](t, x)$. Should implement a method
535
558
  `dynamic_loss.evaluate(t, x, u, params)`.
@@ -553,11 +576,11 @@ class LossPDENonStatio(LossPDEStatio):
553
576
  Fields can be "nn_params", "eq_params" or "both". Those that should not
554
577
  be updated will have a `jax.lax.stop_gradient` called on them. Default
555
578
  is `"nn_params"` for each composant of the loss.
556
- omega_boundary_fun : Callable | Dict[str, Callable], default=None
579
+ omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
557
580
  The function to be matched in the border condition (can be None) or a
558
581
  dictionary of such functions as values and keys as described
559
582
  in `omega_boundary_condition`.
560
- omega_boundary_condition : str | Dict[str, str], default=None
583
+ omega_boundary_condition : str | dict[str, str], default=None
561
584
  Either None (no condition, by default), or a string defining
562
585
  the boundary condition (Dirichlet or Von Neumann),
563
586
  or a dictionary with such strings as values. In this case,
@@ -568,17 +591,17 @@ class LossPDENonStatio(LossPDEStatio):
568
591
  a particular boundary condition on this facet.
569
592
  The facet called “xmin”, resp. “xmax” etc., in 2D,
570
593
  refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
571
- omega_boundary_dim : slice | Dict[str, slice], default=None
594
+ omega_boundary_dim : slice | dict[str, slice], default=None
572
595
  Either None, or a slice object or a dictionary of slice objects as
573
596
  values and keys as described in `omega_boundary_condition`.
574
597
  `omega_boundary_dim` indicates which dimension(s) of the PINN
575
598
  will be forced to match the boundary condition.
576
599
  Note that it must be a slice and not an integer
577
600
  (but a preprocessing of the user provided argument takes care of it)
578
- norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
601
+ norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
579
602
  Monte-Carlo sample points for computing the
580
603
  normalization constant. Default is None.
581
- norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
604
+ norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
582
605
  The importance sampling weights for Monte-Carlo integration of the
583
606
  normalization constant. Must be provided if `norm_samples` is provided.
584
607
  `norm_weights` should have the same leading dimension as
@@ -589,20 +612,25 @@ class LossPDENonStatio(LossPDEStatio):
589
612
  obs_slice : slice, default=None
590
613
  slice object specifying the begininning/ending of the PINN output
591
614
  that is observed (this is then useful for multidim PINN). Default is None.
615
+ t0 : float | Float[Array, " 1"], default=None
616
+ The time at which to apply the initial condition. If None, the time
617
+ is set to `0` by default.
592
618
  initial_condition_fun : Callable, default=None
593
- A function representing the temporal initial condition. If None
594
- (default) then no initial condition is applied
595
- params : InitVar[Params], default=None
596
- The main Params object of the problem needed to instanciate the
597
- DerivativeKeysODE if the latter is not specified.
619
+ A function representing the initial condition at `t0`. If None
620
+ (default) then no initial condition is applied.
621
+ params : InitVar[Params[Array]], default=None
622
+ The main `Params` object of the problem needed to instanciate the
623
+ `DerivativeKeysODE` if the latter is not specified.
598
624
 
599
625
  """
600
626
 
627
+ dynamic_loss: PDENonStatio | None
601
628
  # NOTE static=True only for leaf attributes that are not valid JAX types
602
629
  # (ie. jax.Array cannot be static) and that we do not expect to change
603
630
  initial_condition_fun: Callable | None = eqx.field(
604
631
  kw_only=True, default=None, static=True
605
632
  )
633
+ t0: float | Float[Array, " 1"] | None = eqx.field(kw_only=True, default=None)
606
634
 
607
635
  _max_norm_samples_omega: Int = eqx.field(init=False, static=True)
608
636
  _max_norm_time_slices: Int = eqx.field(init=False, static=True)
@@ -624,6 +652,21 @@ class LossPDENonStatio(LossPDEStatio):
624
652
  "Initial condition wasn't provided. Be sure to cover for that"
625
653
  "case (e.g by. hardcoding it into the PINN output)."
626
654
  )
655
+ # some checks for t0
656
+ if isinstance(self.t0, Array):
657
+ if not self.t0.shape: # e.g. user input: jnp.array(0.)
658
+ self.t0 = jnp.array([self.t0])
659
+ elif self.t0.shape != (1,):
660
+ raise ValueError(
661
+ f"Wrong self.t0 input. It should be"
662
+ f"a float or an array of shape (1,). Got shape: {self.t0.shape}"
663
+ )
664
+ elif isinstance(self.t0, float): # e.g. user input: 0
665
+ self.t0 = jnp.array([self.t0])
666
+ elif self.t0 is None:
667
+ self.t0 = jnp.array([0])
668
+ else:
669
+ raise ValueError("Wrong value for t0")
627
670
 
628
671
  # witht the variables below we avoid memory overflow since a cartesian
629
672
  # product is taken
@@ -632,28 +675,30 @@ class LossPDENonStatio(LossPDEStatio):
632
675
 
633
676
  def _get_dynamic_loss_batch(
634
677
  self, batch: PDENonStatioBatch
635
- ) -> Float[Array, "batch_size 1+dimension"]:
678
+ ) -> Float[Array, " batch_size 1+dimension"]:
636
679
  return batch.domain_batch
637
680
 
638
681
  def _get_normalization_loss_batch(
639
682
  self, batch: PDENonStatioBatch
640
- ) -> Float[Array, "nb_norm_time_slices nb_norm_samples dimension"]:
683
+ ) -> tuple[
684
+ Float[Array, " nb_norm_time_slices 1"], Float[Array, " nb_norm_samples dim"]
685
+ ]:
641
686
  return (
642
687
  batch.domain_batch[: self._max_norm_time_slices, 0:1],
643
- self.norm_samples[: self._max_norm_samples_omega],
688
+ self.norm_samples[: self._max_norm_samples_omega], # type: ignore -> cannot narrow a class attr
644
689
  )
645
690
 
646
691
  def _get_observations_loss_batch(
647
692
  self, batch: PDENonStatioBatch
648
- ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
649
- return (batch.obs_batch_dict["pinn_in"],)
693
+ ) -> Float[Array, " batch_size 1+dim"]:
694
+ return batch.obs_batch_dict["pinn_in"]
650
695
 
651
696
  def __call__(self, *args, **kwargs):
652
697
  return self.evaluate(*args, **kwargs)
653
698
 
654
699
  def evaluate(
655
- self, params: Params, batch: PDENonStatioBatch
656
- ) -> tuple[Float[Array, "1"], dict[str, float]]:
700
+ self, params: Params[Array], batch: PDENonStatioBatch
701
+ ) -> tuple[Float[Array, " "], LossDictPDENonStatio]:
657
702
  """
658
703
  Evaluate the loss function at a batch of points for given parameters.
659
704
 
@@ -670,6 +715,7 @@ class LossPDENonStatio(LossPDEStatio):
670
715
  inputs/outputs/parameters
671
716
  """
672
717
  omega_batch = batch.initial_batch
718
+ assert omega_batch is not None
673
719
 
674
720
  # Retrieve the optional eq_params_batch
675
721
  # and update eq_params with the latter
@@ -682,17 +728,19 @@ class LossPDENonStatio(LossPDEStatio):
682
728
 
683
729
  # For mse_dyn_loss, mse_norm_loss, mse_boundary_loss,
684
730
  # mse_observation_loss we use the evaluate from parent class
685
- partial_mse, partial_mse_terms = super().evaluate(params, batch)
731
+ partial_mse, partial_mse_terms = super().evaluate(params, batch) # type: ignore
732
+ # ignore because batch is not PDEStatioBatch. We could use typing.cast though
686
733
 
687
734
  # initial condition
688
735
  if self.initial_condition_fun is not None:
689
736
  mse_initial_condition = initial_condition_apply(
690
737
  self.u,
691
738
  omega_batch,
692
- _set_derivatives(params, self.derivative_keys.initial_condition),
739
+ _set_derivatives(params, self.derivative_keys.initial_condition), # type: ignore
693
740
  (0,) + vmap_in_axes_params,
694
741
  self.initial_condition_fun,
695
- self.loss_weights.initial_condition,
742
+ self.t0, # type: ignore can't get the narrowing in __post_init__
743
+ self.loss_weights.initial_condition, # type: ignore
696
744
  )
697
745
  else:
698
746
  mse_initial_condition = jnp.array(0.0)
@@ -704,425 +752,3 @@ class LossPDENonStatio(LossPDEStatio):
704
752
  **partial_mse_terms,
705
753
  "initial_condition": mse_initial_condition,
706
754
  }
707
-
708
-
709
- class SystemLossPDE(eqx.Module):
710
- r"""
711
- Class to implement a system of PDEs.
712
- The goal is to give maximum freedom to the user. The class is created with
713
- a dict of dynamic loss, and dictionaries of all the objects that are used
714
- in LossPDENonStatio and LossPDEStatio. When then iterate
715
- over the dynamic losses that compose the system. All the PINNs with all the
716
- parameter dictionaries are passed as arguments to each dynamic loss
717
- evaluate functions; it is inside the dynamic loss that specification are
718
- performed.
719
-
720
- **Note:** All the dictionaries (except `dynamic_loss_dict`) must have the same keys.
721
- Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
722
- solution.
723
-
724
- Parameters
725
- ----------
726
- u_dict : Dict[str, eqx.Module]
727
- dict of PINNs
728
- loss_weights : LossWeightsPDEDict
729
- A dictionary of LossWeightsODE
730
- derivative_keys_dict : Dict[str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio], default=None
731
- A dictionnary of DerivativeKeysPDEStatio or DerivativeKeysPDENonStatio
732
- specifying what field of `params`
733
- should be used during gradient computations for each of the terms of
734
- the total loss, for each of the loss in the system. Default is
735
- `"nn_params`" everywhere.
736
- dynamic_loss_dict : Dict[str, PDEStatio | PDENonStatio]
737
- A dict of dynamic part of the loss, basically the differential
738
- operator $\mathcal{N}[u](t, x)$ or $\mathcal{N}[u](x)$.
739
- key_dict : Dict[str, Key], default=None
740
- A dictionary of JAX PRNG keys. The dictionary keys of key_dict must
741
- match that of u_dict. See LossPDEStatio or LossPDENonStatio for
742
- more details.
743
- omega_boundary_fun_dict : Dict[str, Callable | Dict[str, Callable] | None], default=None
744
- A dict of of function or of dict of functions or of None
745
- (see doc for `omega_boundary_fun` in
746
- LossPDEStatio or LossPDENonStatio). Default is None.
747
- Must share the keys of `u_dict`.
748
- omega_boundary_condition_dict : Dict[str, str | Dict[str, str] | None], default=None
749
- A dict of strings or of dict of strings or of None
750
- (see doc for `omega_boundary_condition_dict` in
751
- LossPDEStatio or LossPDENonStatio). Default is None.
752
- Must share the keys of `u_dict`
753
- omega_boundary_dim_dict : Dict[str, slice | Dict[str, slice] | None], default=None
754
- A dict of slices or of dict of slices or of None
755
- (see doc for `omega_boundary_dim` in
756
- LossPDEStatio or LossPDENonStatio). Default is None.
757
- Must share the keys of `u_dict`
758
- initial_condition_fun_dict : Dict[str, Callable | None], default=None
759
- A dict of functions representing the temporal initial condition (None
760
- value is possible). If None
761
- (default) then no temporal boundary condition is applied
762
- Must share the keys of `u_dict`
763
- norm_samples_dict : Dict[str, Float[Array, "nb_norm_samples dimension"] | None, default=None
764
- A dict of Monte-Carlo sample points for computing the
765
- normalization constant. Default is None.
766
- Must share the keys of `u_dict`
767
- norm_weights_dict : Dict[str, Array[Float, "nb_norm_samples"] | float | int | None] | None, default=None
768
- A dict of jnp.array with the same keys as `u_dict`. The importance
769
- sampling weights for Monte-Carlo integration of the
770
- normalization constant for each element of u_dict. Must be provided if
771
- `norm_samples_dict` is provided.
772
- `norm_weights_dict[key]` should have the same leading dimension as
773
- `norm_samples_dict[key]` for each `key`.
774
- Alternatively, the user can pass a float or an integer.
775
- For each key, an array of similar shape to `norm_samples_dict[key]`
776
- or shape `(1,)` is expected. These corresponds to the weights $w_k =
777
- \frac{1}{q(x_k)}$ where $q(\cdot)$ is the proposal p.d.f. and $x_k$ are
778
- the Monte-Carlo samples. Default is None
779
- obs_slice_dict : Dict[str, slice | None] | None, default=None
780
- dict of obs_slice, with keys from `u_dict` to designate the
781
- output(s) channels that are forced to observed values, for each
782
- PINNs. Default is None. But if a value is given, all the entries of
783
- `u_dict` must be represented here with default value `jnp.s_[...]`
784
- if no particular slice is to be given
785
- params : InitVar[ParamsDict], default=None
786
- The main Params object of the problem needed to instanciate the
787
- DerivativeKeysODE if the latter is not specified.
788
-
789
- """
790
-
791
- # NOTE static=True only for leaf attributes that are not valid JAX types
792
- # (ie. jax.Array cannot be static) and that we do not expect to change
793
- u_dict: Dict[str, eqx.Module]
794
- dynamic_loss_dict: Dict[str, PDEStatio | PDENonStatio]
795
- key_dict: Dict[str, Key] | None = eqx.field(kw_only=True, default=None)
796
- derivative_keys_dict: Dict[
797
- str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None
798
- ] = eqx.field(kw_only=True, default=None)
799
- omega_boundary_fun_dict: Dict[str, Callable | Dict[str, Callable] | None] | None = (
800
- eqx.field(kw_only=True, default=None, static=True)
801
- )
802
- omega_boundary_condition_dict: Dict[str, str | Dict[str, str] | None] | None = (
803
- eqx.field(kw_only=True, default=None, static=True)
804
- )
805
- omega_boundary_dim_dict: Dict[str, slice | Dict[str, slice] | None] | None = (
806
- eqx.field(kw_only=True, default=None, static=True)
807
- )
808
- initial_condition_fun_dict: Dict[str, Callable | None] | None = eqx.field(
809
- kw_only=True, default=None, static=True
810
- )
811
- norm_samples_dict: Dict[str, Float[Array, "nb_norm_samples dimension"]] | None = (
812
- eqx.field(kw_only=True, default=None)
813
- )
814
- norm_weights_dict: (
815
- Dict[str, Float[Array, "nb_norm_samples dimension"] | float | int | None] | None
816
- ) = eqx.field(kw_only=True, default=None)
817
- obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
818
- kw_only=True, default=None, static=True
819
- )
820
-
821
- # For the user loss_weights are passed as a LossWeightsPDEDict (with internal
822
- # dictionary having keys in u_dict and / or dynamic_loss_dict)
823
- loss_weights: InitVar[LossWeightsPDEDict | None] = eqx.field(
824
- kw_only=True, default=None
825
- )
826
- params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
827
-
828
- # following have init=False and are set in the __post_init__
829
- u_constraints_dict: Dict[str, LossPDEStatio | LossPDENonStatio] = eqx.field(
830
- init=False
831
- )
832
- derivative_keys_dyn_loss: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio = (
833
- eqx.field(init=False)
834
- )
835
- u_dict_with_none: Dict[str, None] = eqx.field(init=False)
836
- # internally the loss weights are handled with a dictionary
837
- _loss_weights: Dict[str, dict] = eqx.field(init=False)
838
-
839
- def __post_init__(self, loss_weights=None, params_dict=None):
840
- # a dictionary that will be useful at different places
841
- self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
842
- # First, for all the optional dict,
843
- # if the user did not provide at all this optional argument,
844
- # we make sure there is a null ponderating loss_weight and we
845
- # create a dummy dict with the required keys and all the values to
846
- # None
847
- if self.key_dict is None:
848
- self.key_dict = self.u_dict_with_none
849
- if self.omega_boundary_fun_dict is None:
850
- self.omega_boundary_fun_dict = self.u_dict_with_none
851
- if self.omega_boundary_condition_dict is None:
852
- self.omega_boundary_condition_dict = self.u_dict_with_none
853
- if self.omega_boundary_dim_dict is None:
854
- self.omega_boundary_dim_dict = self.u_dict_with_none
855
- if self.initial_condition_fun_dict is None:
856
- self.initial_condition_fun_dict = self.u_dict_with_none
857
- if self.norm_samples_dict is None:
858
- self.norm_samples_dict = self.u_dict_with_none
859
- if self.norm_weights_dict is None:
860
- self.norm_weights_dict = self.u_dict_with_none
861
- if self.obs_slice_dict is None:
862
- self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
863
- if self.u_dict.keys() != self.obs_slice_dict.keys():
864
- raise ValueError("obs_slice_dict should have same keys as u_dict")
865
- if self.derivative_keys_dict is None:
866
- self.derivative_keys_dict = {
867
- k: None
868
- for k in set(
869
- list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
870
- )
871
- }
872
- # set() because we can have duplicate entries and in this case we
873
- # say it corresponds to the same derivative_keys_dict entry
874
- # we need both because the constraints (all but dyn_loss) will be
875
- # done by iterating on u_dict while the dyn_loss will be by
876
- # iterating on dynamic_loss_dict. So each time we will require dome
877
- # derivative_keys_dict
878
-
879
- # derivative keys for the u_constraints. Note that we create missing
880
- # DerivativeKeysODE around a Params object and not ParamsDict
881
- # this works because u_dict.keys == params_dict.nn_params.keys()
882
- for k in self.u_dict.keys():
883
- if self.derivative_keys_dict[k] is None:
884
- if self.u_dict[k].eq_type == "statio_PDE":
885
- self.derivative_keys_dict[k] = DerivativeKeysPDEStatio(
886
- params=params_dict.extract_params(k)
887
- )
888
- else:
889
- self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio(
890
- params=params_dict.extract_params(k)
891
- )
892
-
893
- # Second we make sure that all the dicts (except dynamic_loss_dict) have the same keys
894
- if (
895
- self.u_dict.keys() != self.key_dict.keys()
896
- or self.u_dict.keys() != self.omega_boundary_fun_dict.keys()
897
- or self.u_dict.keys() != self.omega_boundary_condition_dict.keys()
898
- or self.u_dict.keys() != self.omega_boundary_dim_dict.keys()
899
- or self.u_dict.keys() != self.initial_condition_fun_dict.keys()
900
- or self.u_dict.keys() != self.norm_samples_dict.keys()
901
- or self.u_dict.keys() != self.norm_weights_dict.keys()
902
- ):
903
- raise ValueError("All the dicts concerning the PINNs should have same keys")
904
-
905
- self._loss_weights = self.set_loss_weights(loss_weights)
906
-
907
- # Third, in order not to benefit from LossPDEStatio and
908
- # LossPDENonStatio and in order to factorize code, we create internally
909
- # some losses object to implement the constraints on the solutions.
910
- # We will not use the dynamic loss term
911
- self.u_constraints_dict = {}
912
- for i in self.u_dict.keys():
913
- if self.u_dict[i].eq_type == "statio_PDE":
914
- self.u_constraints_dict[i] = LossPDEStatio(
915
- u=self.u_dict[i],
916
- loss_weights=LossWeightsPDENonStatio(
917
- dyn_loss=0.0,
918
- norm_loss=1.0,
919
- boundary_loss=1.0,
920
- observations=1.0,
921
- initial_condition=1.0,
922
- ),
923
- dynamic_loss=None,
924
- key=self.key_dict[i],
925
- derivative_keys=self.derivative_keys_dict[i],
926
- omega_boundary_fun=self.omega_boundary_fun_dict[i],
927
- omega_boundary_condition=self.omega_boundary_condition_dict[i],
928
- omega_boundary_dim=self.omega_boundary_dim_dict[i],
929
- norm_samples=self.norm_samples_dict[i],
930
- norm_weights=self.norm_weights_dict[i],
931
- obs_slice=self.obs_slice_dict[i],
932
- )
933
- elif self.u_dict[i].eq_type == "nonstatio_PDE":
934
- self.u_constraints_dict[i] = LossPDENonStatio(
935
- u=self.u_dict[i],
936
- loss_weights=LossWeightsPDENonStatio(
937
- dyn_loss=0.0,
938
- norm_loss=1.0,
939
- boundary_loss=1.0,
940
- observations=1.0,
941
- initial_condition=1.0,
942
- ),
943
- dynamic_loss=None,
944
- key=self.key_dict[i],
945
- derivative_keys=self.derivative_keys_dict[i],
946
- omega_boundary_fun=self.omega_boundary_fun_dict[i],
947
- omega_boundary_condition=self.omega_boundary_condition_dict[i],
948
- omega_boundary_dim=self.omega_boundary_dim_dict[i],
949
- initial_condition_fun=self.initial_condition_fun_dict[i],
950
- norm_samples=self.norm_samples_dict[i],
951
- norm_weights=self.norm_weights_dict[i],
952
- obs_slice=self.obs_slice_dict[i],
953
- )
954
- else:
955
- raise ValueError(
956
- "Wrong value for self.u_dict[i].eq_type[i], "
957
- f"got {self.u_dict[i].eq_type[i]}"
958
- )
959
-
960
- # derivative keys for the dynamic loss. Note that we create a
961
- # DerivativeKeysODE around a ParamsDict object because a whole
962
- # params_dict is feed to DynamicLoss.evaluate functions (extract_params
963
- # happen inside it)
964
- self.derivative_keys_dyn_loss = DerivativeKeysPDENonStatio(params=params_dict)
965
-
966
- # also make sure we only have PINNs or SPINNs
967
- if not (
968
- all(isinstance(value, PINN) for value in self.u_dict.values())
969
- or all(isinstance(value, SPINN) for value in self.u_dict.values())
970
- ):
971
- raise ValueError(
972
- "We only accept dictionary of PINNs or dictionary of SPINNs"
973
- )
974
-
975
- def set_loss_weights(
976
- self, loss_weights_init: LossWeightsPDEDict
977
- ) -> dict[str, dict]:
978
- """
979
- This rather complex function enables the user to specify a simple
980
- loss_weights=LossWeightsPDEDict(dyn_loss=1., initial_condition=Tmax)
981
- for ponderating values being applied to all the equations of the
982
- system... So all the transformations are handled here
983
- """
984
- _loss_weights = {}
985
- for k in fields(loss_weights_init):
986
- v = getattr(loss_weights_init, k.name)
987
- if isinstance(v, dict):
988
- for vv in v.keys():
989
- if not isinstance(vv, (int, float)) and not (
990
- isinstance(vv, Array)
991
- and ((vv.shape == (1,) or len(vv.shape) == 0))
992
- ):
993
- # TODO improve that
994
- raise ValueError(
995
- f"loss values cannot be vectorial here, got {vv}"
996
- )
997
- if k.name == "dyn_loss":
998
- if v.keys() == self.dynamic_loss_dict.keys():
999
- _loss_weights[k.name] = v
1000
- else:
1001
- raise ValueError(
1002
- "Keys in nested dictionary of loss_weights"
1003
- " do not match dynamic_loss_dict keys"
1004
- )
1005
- else:
1006
- if v.keys() == self.u_dict.keys():
1007
- _loss_weights[k.name] = v
1008
- else:
1009
- raise ValueError(
1010
- "Keys in nested dictionary of loss_weights"
1011
- " do not match u_dict keys"
1012
- )
1013
- if v is None:
1014
- _loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
1015
- else:
1016
- if not isinstance(v, (int, float)) and not (
1017
- isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
1018
- ):
1019
- # TODO improve that
1020
- raise ValueError(f"loss values cannot be vectorial here, got {v}")
1021
- if k.name == "dyn_loss":
1022
- _loss_weights[k.name] = {
1023
- kk: v for kk in self.dynamic_loss_dict.keys()
1024
- }
1025
- else:
1026
- _loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
1027
- return _loss_weights
1028
-
1029
- def __call__(self, *args, **kwargs):
1030
- return self.evaluate(*args, **kwargs)
1031
-
1032
- def evaluate(
1033
- self,
1034
- params_dict: ParamsDict,
1035
- batch: PDEStatioBatch | PDENonStatioBatch,
1036
- ) -> tuple[Float[Array, "1"], dict[str, float]]:
1037
- """
1038
- Evaluate the loss function at a batch of points for given parameters.
1039
-
1040
-
1041
- Parameters
1042
- ---------
1043
- params_dict
1044
- Parameters at which the losses of the system are evaluated
1045
- batch
1046
- Such named tuples are composed of batch of points in the
1047
- domain, a batch of points in the domain
1048
- border, (a batch of time points a for PDENonStatioBatch) and an
1049
- optional additional batch of parameters (eg. for metamodeling)
1050
- and an optional additional batch of observed
1051
- inputs/outputs/parameters
1052
- """
1053
- if self.u_dict.keys() != params_dict.nn_params.keys():
1054
- raise ValueError("u_dict and params_dict[nn_params] should have same keys ")
1055
-
1056
- vmap_in_axes = (0,)
1057
-
1058
- # Retrieve the optional eq_params_batch
1059
- # and update eq_params with the latter
1060
- # and update vmap_in_axes
1061
- if batch.param_batch_dict is not None:
1062
- eq_params_batch_dict = batch.param_batch_dict
1063
-
1064
- # feed the eq_params with the batch
1065
- for k in eq_params_batch_dict.keys():
1066
- params_dict.eq_params[k] = eq_params_batch_dict[k]
1067
-
1068
- vmap_in_axes_params = _get_vmap_in_axes_params(
1069
- batch.param_batch_dict, params_dict
1070
- )
1071
-
1072
- def dyn_loss_for_one_key(dyn_loss, loss_weight):
1073
- """The function used in tree_map"""
1074
- return dynamic_loss_apply(
1075
- dyn_loss.evaluate,
1076
- self.u_dict,
1077
- (
1078
- batch.domain_batch
1079
- if isinstance(batch, PDEStatioBatch)
1080
- else batch.domain_batch
1081
- ),
1082
- _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
1083
- vmap_in_axes + vmap_in_axes_params,
1084
- loss_weight,
1085
- u_type=list(self.u_dict.values())[0].__class__.__base__,
1086
- )
1087
-
1088
- dyn_loss_mse_dict = jax.tree_util.tree_map(
1089
- dyn_loss_for_one_key,
1090
- self.dynamic_loss_dict,
1091
- self._loss_weights["dyn_loss"],
1092
- is_leaf=lambda x: isinstance(
1093
- x, (PDEStatio, PDENonStatio)
1094
- ), # before when dynamic losses
1095
- # where plain (unregister pytree) node classes, we could not traverse
1096
- # this level. Now that dynamic losses are eqx.Module they can be
1097
- # traversed by tree map recursion. Hence we need to specify to that
1098
- # we want to stop at this level
1099
- )
1100
- mse_dyn_loss = jax.tree_util.tree_reduce(
1101
- lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
1102
- )
1103
-
1104
- # boundary conditions, normalization conditions, observation_loss,
1105
- # initial condition... loss this is done via the internal
1106
- # LossPDEStatio and NonStatio
1107
- loss_weight_struct = {
1108
- "dyn_loss": "*",
1109
- "norm_loss": "*",
1110
- "boundary_loss": "*",
1111
- "observations": "*",
1112
- "initial_condition": "*",
1113
- }
1114
- # we need to do the following for the tree_mapping to work
1115
- if batch.obs_batch_dict is None:
1116
- batch = append_obs_batch(batch, self.u_dict_with_none)
1117
- total_loss, res_dict = constraints_system_loss_apply(
1118
- self.u_constraints_dict,
1119
- batch,
1120
- params_dict,
1121
- self._loss_weights,
1122
- loss_weight_struct,
1123
- )
1124
-
1125
- # Add the mse_dyn_loss from the previous computations
1126
- total_loss += mse_dyn_loss
1127
- res_dict["dyn_loss"] += mse_dyn_loss
1128
- return total_loss, res_dict