jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
Files changed (57)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +116 -189
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +176 -513
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +22 -21
  19. jinns/loss/_loss_utils.py +98 -173
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -76
  22. jinns/nn/__init__.py +22 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +434 -0
  25. jinns/nn/_mlp.py +217 -0
  26. jinns/nn/_pinn.py +204 -0
  27. jinns/nn/_ppinn.py +239 -0
  28. jinns/{utils → nn}/_save_load.py +39 -53
  29. jinns/nn/_spinn.py +123 -0
  30. jinns/nn/_spinn_mlp.py +202 -0
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +38 -37
  37. jinns/solver/_rar.py +82 -65
  38. jinns/solver/_solve.py +111 -71
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -5
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns/utils/_hyperpinn.py +0 -420
  51. jinns/utils/_pinn.py +0 -324
  52. jinns/utils/_ppinn.py +0 -227
  53. jinns/utils/_spinn.py +0 -249
  54. jinns-1.2.0.dist-info/RECORD +0 -41
  55. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  56. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  57. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
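
Two changes in this release are most likely to affect downstream code, and both are visible in the `_LossPDE.py` diff below: the neural-network classes moved from `jinns/utils/` to the new `jinns/nn/` package (entries 22–31 and 50–53 above), and the `norm_int_length` argument of the PDE losses was replaced by Monte-Carlo importance weights `norm_weights`, defined in the new docstrings as $w_k = 1/q(x_k)$ for a proposal p.d.f. $q$. The sketch below is illustrative code, not taken from jinns itself; it checks that for a uniform proposal over a 1D domain of length `L`, where $q(x) = 1/L$, a constant weight `norm_weights = L` reproduces the old `norm_int_length`-style estimate of the normalization integral:

```python
import jax
import jax.numpy as jnp

L = 2.0                                    # domain [0, L]; uniform proposal q(x) = 1/L
key = jax.random.PRNGKey(0)
norm_samples = jax.random.uniform(key, (1024, 1)) * L
u = lambda x: jnp.exp(-x)                  # stand-in for a PINN output

# jinns 1.2.0 style: domain length times the plain Monte-Carlo mean
integral_old = L * jnp.mean(u(norm_samples))
# jinns 1.4.0 style: mean of w_k * u(x_k) with w_k = 1 / q(x_k) = L
norm_weights = L * jnp.ones((norm_samples.shape[0],))
integral_new = jnp.mean(norm_weights * u(norm_samples)[:, 0])

assert jnp.allclose(integral_old, integral_new)
```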
jinns/loss/_LossPDE.py CHANGED
@@ -1,16 +1,16 @@
-# pylint: disable=unsubscriptable-object, no-member
 """
 Main module to implement a PDE loss in jinns
 """
+
 from __future__ import (
     annotations,
 )  # https://docs.python.org/3/library/typing.html#constant
 
 import abc
-from dataclasses import InitVar, fields
-from typing import TYPE_CHECKING, Dict, Callable
+from dataclasses import InitVar
+from typing import TYPE_CHECKING, Callable, TypedDict
+from types import EllipsisType
 import warnings
-import jax
 import jax.numpy as jnp
 import equinox as eqx
 from jaxtyping import Float, Array, Key, Int
@@ -20,9 +20,7 @@ from jinns.loss._loss_utils import (
     normalization_loss_apply,
     observations_loss_apply,
     initial_condition_apply,
-    constraints_system_loss_apply,
 )
-from jinns.data._DataGenerators import append_obs_batch
 from jinns.parameters._params import (
     _get_vmap_in_axes_params,
     _update_eq_params_dict,
@@ -32,19 +30,30 @@ from jinns.parameters._derivative_keys import (
     DerivativeKeysPDEStatio,
     DerivativeKeysPDENonStatio,
 )
+from jinns.loss._abstract_loss import AbstractLoss
 from jinns.loss._loss_weights import (
     LossWeightsPDEStatio,
     LossWeightsPDENonStatio,
-    LossWeightsPDEDict,
 )
-from jinns.loss._DynamicLossAbstract import PDEStatio, PDENonStatio
-from jinns.utils._pinn import PINN
-from jinns.utils._spinn import SPINN
 from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
 
 
 if TYPE_CHECKING:
-    from jinns.utils._types import *
+    # imports for type hints only
+    from jinns.parameters._params import Params
+    from jinns.nn._abstract_pinn import AbstractPINN
+    from jinns.loss import PDENonStatio, PDEStatio
+    from jinns.utils._types import BoundaryConditionFun
+
+    class LossDictPDEStatio(TypedDict):
+        dyn_loss: Float[Array, " "]
+        norm_loss: Float[Array, " "]
+        boundary_loss: Float[Array, " "]
+        observations: Float[Array, " "]
+
+    class LossDictPDENonStatio(LossDictPDEStatio):
+        initial_condition: Float[Array, " "]
+
 
 _IMPLEMENTED_BOUNDARY_CONDITIONS = [
     "dirichlet",
@@ -53,8 +62,8 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
 ]
 
 
-class _LossPDEAbstract(eqx.Module):
-    """
+class _LossPDEAbstract(AbstractLoss):
+    r"""
     Parameters
     ----------
 
@@ -69,11 +78,11 @@ class _LossPDEAbstract(eqx.Module):
         Fields can be "nn_params", "eq_params" or "both". Those that should not
         be updated will have a `jax.lax.stop_gradient` called on them. Default
         is `"nn_params"` for each composant of the loss.
-    omega_boundary_fun : Callable | Dict[str, Callable], default=None
+    omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
         The function to be matched in the border condition (can be None) or a
         dictionary of such functions as values and keys as described
         in `omega_boundary_condition`.
-    omega_boundary_condition : str | Dict[str, str], default=None
+    omega_boundary_condition : str | dict[str, str], default=None
         Either None (no condition, by default), or a string defining
         the boundary condition (Dirichlet or Von Neumann),
         or a dictionary with such strings as values. In this case,
@@ -84,24 +93,29 @@ class _LossPDEAbstract(eqx.Module):
         a particular boundary condition on this facet.
         The facet called “xmin”, resp. “xmax” etc., in 2D,
         refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
-    omega_boundary_dim : slice | Dict[str, slice], default=None
+    omega_boundary_dim : slice | dict[str, slice], default=None
         Either None, or a slice object or a dictionary of slice objects as
         values and keys as described in `omega_boundary_condition`.
         `omega_boundary_dim` indicates which dimension(s) of the PINN
         will be forced to match the boundary condition.
         Note that it must be a slice and not an integer
         (but a preprocessing of the user provided argument takes care of it)
-    norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
-        Fixed sample point in the space over which to compute the
+    norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
+        Monte-Carlo sample points for computing the
         normalization constant. Default is None.
-    norm_int_length : float, default=None
-        A float. Must be provided if `norm_samples` is provided. The domain area
-        (or interval length in 1D) upon which we perform the numerical
-        integration. Default None
-    obs_slice : slice, default=None
+    norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
+        The importance sampling weights for Monte-Carlo integration of the
+        normalization constant. Must be provided if `norm_samples` is provided.
+        `norm_weights` should be broadcastble to
+        `norm_samples`.
+        Alternatively, the user can pass a float or an integer that will be
+        made broadcastable to `norm_samples`.
+        These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
+        $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
+    obs_slice : EllipsisType | slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
-    params : InitVar[Params], default=None
+    params : InitVar[Params[Array]], default=None
         The main Params object of the problem needed to instanciate the
         DerivativeKeysODE if the latter is not specified.
     """
@@ -115,24 +129,28 @@ class _LossPDEAbstract(eqx.Module):
     loss_weights: LossWeightsPDEStatio | LossWeightsPDENonStatio | None = eqx.field(
         kw_only=True, default=None
     )
-    omega_boundary_fun: Callable | Dict[str, Callable] | None = eqx.field(
+    omega_boundary_fun: (
+        BoundaryConditionFun | dict[str, BoundaryConditionFun] | None
+    ) = eqx.field(kw_only=True, default=None, static=True)
+    omega_boundary_condition: str | dict[str, str] | None = eqx.field(
         kw_only=True, default=None, static=True
     )
-    omega_boundary_condition: str | Dict[str, str] | None = eqx.field(
+    omega_boundary_dim: slice | dict[str, slice] | None = eqx.field(
         kw_only=True, default=None, static=True
     )
-    omega_boundary_dim: slice | Dict[str, slice] | None = eqx.field(
-        kw_only=True, default=None, static=True
+    norm_samples: Float[Array, " nb_norm_samples dimension"] | None = eqx.field(
+        kw_only=True, default=None
     )
-    norm_samples: Float[Array, "nb_norm_samples dimension"] | None = eqx.field(
+    norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = eqx.field(
         kw_only=True, default=None
     )
-    norm_int_length: float | None = eqx.field(kw_only=True, default=None)
-    obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
+    obs_slice: EllipsisType | slice | None = eqx.field(
+        kw_only=True, default=None, static=True
+    )
 
-    params: InitVar[Params] = eqx.field(kw_only=True, default=None)
+    params: InitVar[Params[Array]] = eqx.field(kw_only=True, default=None)
 
-    def __post_init__(self, params=None):
+    def __post_init__(self, params: Params[Array] | None = None):
         """
         Note that neither __init__ or __post_init__ are called when udating a
         Module with eqx.tree_at
@@ -222,6 +240,11 @@ class _LossPDEAbstract(eqx.Module):
             )
 
         if isinstance(self.omega_boundary_fun, dict):
+            if not isinstance(self.omega_boundary_dim, dict):
+                raise ValueError(
+                    "If omega_boundary_fun is a dict then"
+                    " omega_boundary_dim should also be a dict"
+                )
             if self.omega_boundary_dim is None:
                 self.omega_boundary_dim = {
                     k: jnp.s_[::] for k in self.omega_boundary_fun.keys()
@@ -251,15 +274,34 @@ class _LossPDEAbstract(eqx.Module):
             if not isinstance(self.omega_boundary_dim, slice):
                 raise ValueError("self.omega_boundary_dim must be a jnp.s_ object")
 
-        if self.norm_samples is not None and self.norm_int_length is None:
-            raise ValueError("self.norm_samples and norm_int_length must be provided")
+        if self.norm_samples is not None:
+            if self.norm_weights is None:
+                raise ValueError(
+                    "`norm_weights` must be provided when `norm_samples` is used!"
+                )
+            if isinstance(self.norm_weights, (int, float)):
+                self.norm_weights = self.norm_weights * jnp.ones(
+                    (self.norm_samples.shape[0],)
+                )
+            if isinstance(self.norm_weights, Array):
+                if not (self.norm_weights.shape[0] == self.norm_samples.shape[0]):
+                    raise ValueError(
+                        "self.norm_weights and "
+                        "self.norm_samples must have the same leading dimension"
+                    )
+            else:
+                raise ValueError("Wrong type for self.norm_weights")
+
+    @abc.abstractmethod
+    def __call__(self, *_, **__):
+        pass
 
     @abc.abstractmethod
     def evaluate(
         self: eqx.Module,
-        params: Params,
+        params: Params[Array],
         batch: PDEStatioBatch | PDENonStatioBatch,
-    ) -> tuple[Float, dict]:
+    ) -> tuple[Float[Array, " "], LossDictPDEStatio | LossDictPDENonStatio]:
         raise NotImplementedError
 
 
@@ -276,9 +318,9 @@ class LossPDEStatio(_LossPDEAbstract):
 
     Parameters
     ----------
-    u : eqx.Module
+    u : AbstractPINN
         the PINN
-    dynamic_loss : DynamicLoss
+    dynamic_loss : PDEStatio
         the stationary PDE dynamic part of the loss, basically the differential
         operator $\mathcal{N}[u](x)$. Should implement a method
         `dynamic_loss.evaluate(x, u, params)`.
@@ -301,11 +343,11 @@ class LossPDEStatio(_LossPDEAbstract):
         Fields can be "nn_params", "eq_params" or "both". Those that should not
         be updated will have a `jax.lax.stop_gradient` called on them. Default
         is `"nn_params"` for each composant of the loss.
-    omega_boundary_fun : Callable | Dict[str, Callable], default=None
+    omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
         The function to be matched in the border condition (can be None) or a
         dictionary of such functions as values and keys as described
         in `omega_boundary_condition`.
-    omega_boundary_condition : str | Dict[str, str], default=None
+    omega_boundary_condition : str | dict[str, str], default=None
         Either None (no condition, by default), or a string defining
         the boundary condition (Dirichlet or Von Neumann),
         or a dictionary with such strings as values. In this case,
@@ -316,24 +358,28 @@ class LossPDEStatio(_LossPDEAbstract):
         a particular boundary condition on this facet.
         The facet called “xmin”, resp. “xmax” etc., in 2D,
         refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
-    omega_boundary_dim : slice | Dict[str, slice], default=None
+    omega_boundary_dim : slice | dict[str, slice], default=None
         Either None, or a slice object or a dictionary of slice objects as
         values and keys as described in `omega_boundary_condition`.
         `omega_boundary_dim` indicates which dimension(s) of the PINN
         will be forced to match the boundary condition.
         Note that it must be a slice and not an integer
         (but a preprocessing of the user provided argument takes care of it)
-    norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
-        Fixed sample point in the space over which to compute the
+    norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
+        Monte-Carlo sample points for computing the
         normalization constant. Default is None.
-    norm_int_length : float, default=None
-        A float. Must be provided if `norm_samples` is provided. The domain area
-        (or interval length in 1D) upon which we perform the numerical
-        integration. Default None
+    norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
+        The importance sampling weights for Monte-Carlo integration of the
+        normalization constant. Must be provided if `norm_samples` is provided.
+        `norm_weights` should have the same leading dimension as
+        `norm_samples`.
+        Alternatively, the user can pass a float or an integer.
+        These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
+        $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
     obs_slice : slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
-    params : InitVar[Params], default=None
+    params : InitVar[Params[Array]], default=None
         The main Params object of the problem needed to instanciate the
         DerivativeKeysODE if the latter is not specified.
 
@@ -348,13 +394,13 @@ class LossPDEStatio(_LossPDEAbstract):
     # NOTE static=True only for leaf attributes that are not valid JAX types
     # (ie. jax.Array cannot be static) and that we do not expect to change
 
-    u: eqx.Module
-    dynamic_loss: DynamicLoss | None
+    u: AbstractPINN
+    dynamic_loss: PDEStatio | None
     key: Key | None = eqx.field(kw_only=True, default=None)
 
     vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
 
-    def __post_init__(self, params=None):
+    def __post_init__(self, params: Params[Array] | None = None):
         """
         Note that neither __init__ or __post_init__ are called when udating a
         Module with eqx.tree_at!
@@ -368,25 +414,27 @@ class LossPDEStatio(_LossPDEAbstract):
 
     def _get_dynamic_loss_batch(
         self, batch: PDEStatioBatch
-    ) -> Float[Array, "batch_size dimension"]:
+    ) -> Float[Array, " batch_size dimension"]:
         return batch.domain_batch
 
     def _get_normalization_loss_batch(
         self, _
-    ) -> Float[Array, "nb_norm_samples dimension"]:
-        return (self.norm_samples,)
+    ) -> tuple[Float[Array, " nb_norm_samples dimension"]]:
+        return (self.norm_samples,)  # type: ignore -> cannot narrow a class attr
+
+    # we could have used typing.cast though
 
     def _get_observations_loss_batch(
         self, batch: PDEStatioBatch
-    ) -> Float[Array, "batch_size obs_dim"]:
-        return (batch.obs_batch_dict["pinn_in"],)
+    ) -> Float[Array, " batch_size obs_dim"]:
+        return batch.obs_batch_dict["pinn_in"]
 
     def __call__(self, *args, **kwargs):
         return self.evaluate(*args, **kwargs)
 
     def evaluate(
-        self, params: Params, batch: PDEStatioBatch
-    ) -> tuple[Float[Array, "1"], dict[str, float]]:
+        self, params: Params[Array], batch: PDEStatioBatch
+    ) -> tuple[Float[Array, " "], LossDictPDEStatio]:
         """
         Evaluate the loss function at a batch of points for given parameters.
 
@@ -417,9 +465,9 @@ class LossPDEStatio(_LossPDEAbstract):
                 self.dynamic_loss.evaluate,
                 self.u,
                 self._get_dynamic_loss_batch(batch),
-                _set_derivatives(params, self.derivative_keys.dyn_loss),
+                _set_derivatives(params, self.derivative_keys.dyn_loss),  # type: ignore
                 self.vmap_in_axes + vmap_in_axes_params,
-                self.loss_weights.dyn_loss,
+                self.loss_weights.dyn_loss,  # type: ignore
             )
         else:
             mse_dyn_loss = jnp.array(0.0)
@@ -429,24 +477,28 @@ class LossPDEStatio(_LossPDEAbstract):
             mse_norm_loss = normalization_loss_apply(
                 self.u,
                 self._get_normalization_loss_batch(batch),
-                _set_derivatives(params, self.derivative_keys.norm_loss),
+                _set_derivatives(params, self.derivative_keys.norm_loss),  # type: ignore
                 vmap_in_axes_params,
-                self.norm_int_length,
-                self.loss_weights.norm_loss,
+                self.norm_weights,  # type: ignore -> can't get the __post_init__ narrowing here
+                self.loss_weights.norm_loss,  # type: ignore
             )
         else:
             mse_norm_loss = jnp.array(0.0)
 
         # boundary part
-        if self.omega_boundary_condition is not None:
+        if (
+            self.omega_boundary_condition is not None
+            and self.omega_boundary_dim is not None
+            and self.omega_boundary_fun is not None
+        ):  # pyright cannot narrow down the three None otherwise as it is class attribute
             mse_boundary_loss = boundary_condition_apply(
                 self.u,
                 batch,
-                _set_derivatives(params, self.derivative_keys.boundary_loss),
+                _set_derivatives(params, self.derivative_keys.boundary_loss),  # type: ignore
                 self.omega_boundary_fun,
                 self.omega_boundary_condition,
                 self.omega_boundary_dim,
-                self.loss_weights.boundary_loss,
+                self.loss_weights.boundary_loss,  # type: ignore
             )
         else:
             mse_boundary_loss = jnp.array(0.0)
@@ -459,10 +511,10 @@ class LossPDEStatio(_LossPDEAbstract):
             mse_observation_loss = observations_loss_apply(
                 self.u,
                 self._get_observations_loss_batch(batch),
-                _set_derivatives(params, self.derivative_keys.observations),
+                _set_derivatives(params, self.derivative_keys.observations),  # type: ignore
                 self.vmap_in_axes + vmap_in_axes_params,
                 batch.obs_batch_dict["val"],
-                self.loss_weights.observations,
+                self.loss_weights.observations,  # type: ignore
                 self.obs_slice,
             )
         else:
@@ -478,8 +530,6 @@ class LossPDEStatio(_LossPDEAbstract):
                 "norm_loss": mse_norm_loss,
                 "boundary_loss": mse_boundary_loss,
                 "observations": mse_observation_loss,
-                "initial_condition": jnp.array(0.0),  # for compatibility in the
-                # tree_map of SystemLoss
             }
         )
 
@@ -500,9 +550,9 @@ class LossPDENonStatio(LossPDEStatio):
 
     Parameters
    ----------
-    u : eqx.Module
+    u : AbstractPINN
         the PINN
-    dynamic_loss : DynamicLoss
+    dynamic_loss : PDENonStatio
         the non stationary PDE dynamic part of the loss, basically the differential
         operator $\mathcal{N}[u](t, x)$. Should implement a method
         `dynamic_loss.evaluate(t, x, u, params)`.
@@ -526,11 +576,11 @@ class LossPDENonStatio(LossPDEStatio):
         Fields can be "nn_params", "eq_params" or "both". Those that should not
         be updated will have a `jax.lax.stop_gradient` called on them. Default
         is `"nn_params"` for each composant of the loss.
-    omega_boundary_fun : Callable | Dict[str, Callable], default=None
+    omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
         The function to be matched in the border condition (can be None) or a
         dictionary of such functions as values and keys as described
         in `omega_boundary_condition`.
-    omega_boundary_condition : str | Dict[str, str], default=None
+    omega_boundary_condition : str | dict[str, str], default=None
         Either None (no condition, by default), or a string defining
         the boundary condition (Dirichlet or Von Neumann),
         or a dictionary with such strings as values. In this case,
@@ -541,37 +591,46 @@ class LossPDENonStatio(LossPDEStatio):
         a particular boundary condition on this facet.
         The facet called “xmin”, resp. “xmax” etc., in 2D,
         refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
-    omega_boundary_dim : slice | Dict[str, slice], default=None
+    omega_boundary_dim : slice | dict[str, slice], default=None
         Either None, or a slice object or a dictionary of slice objects as
         values and keys as described in `omega_boundary_condition`.
         `omega_boundary_dim` indicates which dimension(s) of the PINN
         will be forced to match the boundary condition.
         Note that it must be a slice and not an integer
         (but a preprocessing of the user provided argument takes care of it)
-    norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
-        Fixed sample point in the space over which to compute the
+    norm_samples : Float[Array, " nb_norm_samples dimension"], default=None
+        Monte-Carlo sample points for computing the
         normalization constant. Default is None.
-    norm_int_length : float, default=None
-        A float. Must be provided if `norm_samples` is provided. The domain area
-        (or interval length in 1D) upon which we perform the numerical
-        integration. Default None
+    norm_weights : Float[Array, " nb_norm_samples"] | float | int, default=None
+        The importance sampling weights for Monte-Carlo integration of the
+        normalization constant. Must be provided if `norm_samples` is provided.
+        `norm_weights` should have the same leading dimension as
+        `norm_samples`.
+        Alternatively, the user can pass a float or an integer.
+        These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
+        $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
     obs_slice : slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
+    t0 : float | Float[Array, " 1"], default=None
+        The time at which to apply the initial condition. If None, the time
+        is set to `0` by default.
     initial_condition_fun : Callable, default=None
-        A function representing the temporal initial condition. If None
-        (default) then no initial condition is applied
-    params : InitVar[Params], default=None
-        The main Params object of the problem needed to instanciate the
-        DerivativeKeysODE if the latter is not specified.
+        A function representing the initial condition at `t0`. If None
+        (default) then no initial condition is applied.
+    params : InitVar[Params[Array]], default=None
+        The main `Params` object of the problem needed to instanciate the
+        `DerivativeKeysODE` if the latter is not specified.
 
     """
 
+    dynamic_loss: PDENonStatio | None
     # NOTE static=True only for leaf attributes that are not valid JAX types
     # (ie. jax.Array cannot be static) and that we do not expect to change
     initial_condition_fun: Callable | None = eqx.field(
         kw_only=True, default=None, static=True
     )
+    t0: float | Float[Array, " 1"] | None = eqx.field(kw_only=True, default=None)
 
     _max_norm_samples_omega: Int = eqx.field(init=False, static=True)
     _max_norm_time_slices: Int = eqx.field(init=False, static=True)
@@ -593,6 +652,21 @@ class LossPDENonStatio(LossPDEStatio):
                 "Initial condition wasn't provided. Be sure to cover for that"
                 "case (e.g by. hardcoding it into the PINN output)."
             )
+        # some checks for t0
+        if isinstance(self.t0, Array):
+            if not self.t0.shape:  # e.g. user input: jnp.array(0.)
+                self.t0 = jnp.array([self.t0])
+            elif self.t0.shape != (1,):
+                raise ValueError(
+                    f"Wrong self.t0 input. It should be"
+                    f"a float or an array of shape (1,). Got shape: {self.t0.shape}"
+                )
+        elif isinstance(self.t0, float):  # e.g. user input: 0
+            self.t0 = jnp.array([self.t0])
+        elif self.t0 is None:
+            self.t0 = jnp.array([0])
+        else:
+            raise ValueError("Wrong value for t0")
 
         # witht the variables below we avoid memory overflow since a cartesian
         # product is taken
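
The `t0` handling added in this hunk accepts a float, a 0-d array, a shape-(1,) array, or None, and normalizes all of them to a shape-(1,) array. Below is a standalone restatement of those checks, a sketch for illustration rather than the library code itself:

```python
import jax.numpy as jnp
from jax import Array

def normalize_t0(t0) -> Array:
    # Mirrors the branches of LossPDENonStatio.__post_init__ shown above.
    if isinstance(t0, Array):
        if not t0.shape:  # 0-d array, e.g. jnp.array(0.)
            return jnp.array([t0])
        if t0.shape != (1,):
            raise ValueError(f"t0 must be a float or a shape-(1,) array, got {t0.shape}")
        return t0
    if isinstance(t0, float):  # e.g. user input: 0.
        return jnp.array([t0])
    if t0 is None:
        return jnp.array([0])  # default: initial condition applied at t = 0
    raise ValueError("Wrong value for t0")

assert normalize_t0(None).shape == (1,)
assert normalize_t0(jnp.array(0.5)).shape == (1,)
assert normalize_t0(0.5)[0] == 0.5
```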
@@ -601,28 +675,30 @@ class LossPDENonStatio(LossPDEStatio):
 
     def _get_dynamic_loss_batch(
         self, batch: PDENonStatioBatch
-    ) -> Float[Array, "batch_size 1+dimension"]:
+    ) -> Float[Array, " batch_size 1+dimension"]:
         return batch.domain_batch
 
     def _get_normalization_loss_batch(
         self, batch: PDENonStatioBatch
-    ) -> Float[Array, "nb_norm_time_slices nb_norm_samples dimension"]:
+    ) -> tuple[
+        Float[Array, " nb_norm_time_slices 1"], Float[Array, " nb_norm_samples dim"]
+    ]:
         return (
             batch.domain_batch[: self._max_norm_time_slices, 0:1],
-            self.norm_samples[: self._max_norm_samples_omega],
+            self.norm_samples[: self._max_norm_samples_omega],  # type: ignore -> cannot narrow a class attr
         )
 
     def _get_observations_loss_batch(
         self, batch: PDENonStatioBatch
-    ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
-        return (batch.obs_batch_dict["pinn_in"],)
+    ) -> Float[Array, " batch_size 1+dim"]:
+        return batch.obs_batch_dict["pinn_in"]
 
     def __call__(self, *args, **kwargs):
         return self.evaluate(*args, **kwargs)
 
     def evaluate(
-        self, params: Params, batch: PDENonStatioBatch
-    ) -> tuple[Float[Array, "1"], dict[str, float]]:
+        self, params: Params[Array], batch: PDENonStatioBatch
+    ) -> tuple[Float[Array, " "], LossDictPDENonStatio]:
         """
         Evaluate the loss function at a batch of points for given parameters.
 
@@ -639,6 +715,7 @@ class LossPDENonStatio(LossPDEStatio):
         inputs/outputs/parameters
         """
         omega_batch = batch.initial_batch
+        assert omega_batch is not None
 
         # Retrieve the optional eq_params_batch
         # and update eq_params with the latter
@@ -651,17 +728,19 @@ class LossPDENonStatio(LossPDEStatio):
 
         # For mse_dyn_loss, mse_norm_loss, mse_boundary_loss,
         # mse_observation_loss we use the evaluate from parent class
-        partial_mse, partial_mse_terms = super().evaluate(params, batch)
+        partial_mse, partial_mse_terms = super().evaluate(params, batch)  # type: ignore
+        # ignore because batch is not PDEStatioBatch. We could use typing.cast though
 
         # initial condition
         if self.initial_condition_fun is not None:
             mse_initial_condition = initial_condition_apply(
                 self.u,
                 omega_batch,
-                _set_derivatives(params, self.derivative_keys.initial_condition),
+                _set_derivatives(params, self.derivative_keys.initial_condition),  # type: ignore
                 (0,) + vmap_in_axes_params,
                 self.initial_condition_fun,
-                self.loss_weights.initial_condition,
+                self.t0,  # type: ignore can't get the narrowing in __post_init__
+                self.loss_weights.initial_condition,  # type: ignore
             )
         else:
             mse_initial_condition = jnp.array(0.0)
@@ -673,419 +752,3 @@ class LossPDENonStatio(LossPDEStatio):
             **partial_mse_terms,
             "initial_condition": mse_initial_condition,
         }
-
-
-class SystemLossPDE(eqx.Module):
-    r"""
-    Class to implement a system of PDEs.
-    The goal is to give maximum freedom to the user. The class is created with
-    a dict of dynamic loss, and dictionaries of all the objects that are used
-    in LossPDENonStatio and LossPDEStatio. When then iterate
-    over the dynamic losses that compose the system. All the PINNs with all the
-    parameter dictionaries are passed as arguments to each dynamic loss
-    evaluate functions; it is inside the dynamic loss that specification are
-    performed.
-
-    **Note:** All the dictionaries (except `dynamic_loss_dict`) must have the same keys.
-    Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
-    solution.
-
-    Parameters
-    ----------
-    u_dict : Dict[str, eqx.Module]
-        dict of PINNs
-    loss_weights : LossWeightsPDEDict
-        A dictionary of LossWeightsODE
-    derivative_keys_dict : Dict[str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio], default=None
-        A dictionnary of DerivativeKeysPDEStatio or DerivativeKeysPDENonStatio
-        specifying what field of `params`
-        should be used during gradient computations for each of the terms of
-        the total loss, for each of the loss in the system. Default is
-        `"nn_params`" everywhere.
-    dynamic_loss_dict : Dict[str, PDEStatio | PDENonStatio]
-        A dict of dynamic part of the loss, basically the differential
-        operator $\mathcal{N}[u](t, x)$ or $\mathcal{N}[u](x)$.
-    key_dict : Dict[str, Key], default=None
-        A dictionary of JAX PRNG keys. The dictionary keys of key_dict must
-        match that of u_dict. See LossPDEStatio or LossPDENonStatio for
-        more details.
-    omega_boundary_fun_dict : Dict[str, Callable | Dict[str, Callable] | None], default=None
-        A dict of of function or of dict of functions or of None
-        (see doc for `omega_boundary_fun` in
-        LossPDEStatio or LossPDENonStatio). Default is None.
-        Must share the keys of `u_dict`.
-    omega_boundary_condition_dict : Dict[str, str | Dict[str, str] | None], default=None
-        A dict of strings or of dict of strings or of None
-        (see doc for `omega_boundary_condition_dict` in
-        LossPDEStatio or LossPDENonStatio). Default is None.
-        Must share the keys of `u_dict`
-    omega_boundary_dim_dict : Dict[str, slice | Dict[str, slice] | None], default=None
-        A dict of slices or of dict of slices or of None
-        (see doc for `omega_boundary_dim` in
-        LossPDEStatio or LossPDENonStatio). Default is None.
-        Must share the keys of `u_dict`
-    initial_condition_fun_dict : Dict[str, Callable | None], default=None
-        A dict of functions representing the temporal initial condition (None
-        value is possible). If None
-        (default) then no temporal boundary condition is applied
-        Must share the keys of `u_dict`
-    norm_samples_dict : Dict[str, Float[Array, "nb_norm_samples dimension"] | None, default=None
-        A dict of fixed sample point in the space over which to compute the
-        normalization constant. Default is None
-        Must share the keys of `u_dict`
-    norm_int_length_dict : Dict[str, float | None] | None, default=None
-        A dict of Float. The domain area
-        (or interval length in 1D) upon which we perform the numerical
-        integration for each element of u_dict.
-        Default is None
-        Must share the keys of `u_dict`
-    obs_slice_dict : Dict[str, slice | None] | None, default=None
-        dict of obs_slice, with keys from `u_dict` to designate the
-        output(s) channels that are forced to observed values, for each
-        PINNs. Default is None. But if a value is given, all the entries of
-        `u_dict` must be represented here with default value `jnp.s_[...]`
-        if no particular slice is to be given
-    params : InitVar[ParamsDict], default=None
-        The main Params object of the problem needed to instanciate the
-        DerivativeKeysODE if the latter is not specified.
-
-    """
-
-    # NOTE static=True only for leaf attributes that are not valid JAX types
-    # (ie. jax.Array cannot be static) and that we do not expect to change
-    u_dict: Dict[str, eqx.Module]
-    dynamic_loss_dict: Dict[str, PDEStatio | PDENonStatio]
-    key_dict: Dict[str, Key] | None = eqx.field(kw_only=True, default=None)
-    derivative_keys_dict: Dict[
-        str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None
-    ] = eqx.field(kw_only=True, default=None)
-    omega_boundary_fun_dict: Dict[str, Callable | Dict[str, Callable] | None] | None = (
-        eqx.field(kw_only=True, default=None, static=True)
-    )
-    omega_boundary_condition_dict: Dict[str, str | Dict[str, str] | None] | None = (
-        eqx.field(kw_only=True, default=None, static=True)
-    )
-    omega_boundary_dim_dict: Dict[str, slice | Dict[str, slice] | None] | None = (
-        eqx.field(kw_only=True, default=None, static=True)
-    )
-    initial_condition_fun_dict: Dict[str, Callable | None] | None = eqx.field(
-        kw_only=True, default=None, static=True
-    )
-    norm_samples_dict: Dict[str, Float[Array, "nb_norm_samples dimension"]] | None = (
-        eqx.field(kw_only=True, default=None)
-    )
-    norm_int_length_dict: Dict[str, float | None] | None = eqx.field(
-        kw_only=True, default=None
-    )
-    obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
-        kw_only=True, default=None, static=True
-    )
-
-    # For the user loss_weights are passed as a LossWeightsPDEDict (with internal
-    # dictionary having keys in u_dict and / or dynamic_loss_dict)
-    loss_weights: InitVar[LossWeightsPDEDict | None] = eqx.field(
-        kw_only=True, default=None
-    )
-    params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
-
-    # following have init=False and are set in the __post_init__
-    u_constraints_dict: Dict[str, LossPDEStatio | LossPDENonStatio] = eqx.field(
-        init=False
-    )
-    derivative_keys_dyn_loss: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio = (
-        eqx.field(init=False)
-    )
-    u_dict_with_none: Dict[str, None] = eqx.field(init=False)
-    # internally the loss weights are handled with a dictionary
-    _loss_weights: Dict[str, dict] = eqx.field(init=False)
-
-    def __post_init__(self, loss_weights=None, params_dict=None):
-        # a dictionary that will be useful at different places
-        self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
-        # First, for all the optional dict,
-        # if the user did not provide at all this optional argument,
-        # we make sure there is a null ponderating loss_weight and we
-        # create a dummy dict with the required keys and all the values to
-        # None
-        if self.key_dict is None:
-            self.key_dict = self.u_dict_with_none
-        if self.omega_boundary_fun_dict is None:
-            self.omega_boundary_fun_dict = self.u_dict_with_none
-        if self.omega_boundary_condition_dict is None:
-            self.omega_boundary_condition_dict = self.u_dict_with_none
-        if self.omega_boundary_dim_dict is None:
-            self.omega_boundary_dim_dict = self.u_dict_with_none
-        if self.initial_condition_fun_dict is None:
-            self.initial_condition_fun_dict = self.u_dict_with_none
-        if self.norm_samples_dict is None:
-            self.norm_samples_dict = self.u_dict_with_none
-        if self.norm_int_length_dict is None:
-            self.norm_int_length_dict = self.u_dict_with_none
-        if self.obs_slice_dict is None:
-            self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
-        if self.u_dict.keys() != self.obs_slice_dict.keys():
-            raise ValueError("obs_slice_dict should have same keys as u_dict")
-        if self.derivative_keys_dict is None:
-            self.derivative_keys_dict = {
-                k: None
-                for k in set(
-                    list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
-                )
-            }
-            # set() because we can have duplicate entries and in this case we
-            # say it corresponds to the same derivative_keys_dict entry
-            # we need both because the constraints (all but dyn_loss) will be
-            # done by iterating on u_dict while the dyn_loss will be by
-            # iterating on dynamic_loss_dict. So each time we will require dome
-            # derivative_keys_dict
-
-        # derivative keys for the u_constraints. Note that we create missing
-        # DerivativeKeysODE around a Params object and not ParamsDict
-        # this works because u_dict.keys == params_dict.nn_params.keys()
-        for k in self.u_dict.keys():
-            if self.derivative_keys_dict[k] is None:
-                if self.u_dict[k].eq_type == "statio_PDE":
-                    self.derivative_keys_dict[k] = DerivativeKeysPDEStatio(
-                        params=params_dict.extract_params(k)
-                    )
-                else:
-                    self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio(
-                        params=params_dict.extract_params(k)
-                    )
-
-        # Second we make sure that all the dicts (except dynamic_loss_dict) have the same keys
-        if (
-            self.u_dict.keys() != self.key_dict.keys()
-            or self.u_dict.keys() != self.omega_boundary_fun_dict.keys()
-            or self.u_dict.keys() != self.omega_boundary_condition_dict.keys()
-            or self.u_dict.keys() != self.omega_boundary_dim_dict.keys()
-            or self.u_dict.keys() != self.initial_condition_fun_dict.keys()
-            or self.u_dict.keys() != self.norm_samples_dict.keys()
-            or self.u_dict.keys() != self.norm_int_length_dict.keys()
-        ):
-            raise ValueError("All the dicts concerning the PINNs should have same keys")
-
-        self._loss_weights = self.set_loss_weights(loss_weights)
-
-        # Third, in order not to benefit from LossPDEStatio and
-        # LossPDENonStatio and in order to factorize code, we create internally
-        # some losses object to implement the constraints on the solutions.
-        # We will not use the dynamic loss term
-        self.u_constraints_dict = {}
-        for i in self.u_dict.keys():
-            if self.u_dict[i].eq_type == "statio_PDE":
-                self.u_constraints_dict[i] = LossPDEStatio(
-                    u=self.u_dict[i],
-                    loss_weights=LossWeightsPDENonStatio(
-                        dyn_loss=0.0,
-                        norm_loss=1.0,
-                        boundary_loss=1.0,
-                        observations=1.0,
-                        initial_condition=1.0,
-                    ),
-                    dynamic_loss=None,
-                    key=self.key_dict[i],
-                    derivative_keys=self.derivative_keys_dict[i],
-                    omega_boundary_fun=self.omega_boundary_fun_dict[i],
-                    omega_boundary_condition=self.omega_boundary_condition_dict[i],
-                    omega_boundary_dim=self.omega_boundary_dim_dict[i],
-                    norm_samples=self.norm_samples_dict[i],
-                    norm_int_length=self.norm_int_length_dict[i],
-                    obs_slice=self.obs_slice_dict[i],
-                )
-            elif self.u_dict[i].eq_type == "nonstatio_PDE":
-                self.u_constraints_dict[i] = LossPDENonStatio(
-                    u=self.u_dict[i],
-                    loss_weights=LossWeightsPDENonStatio(
-                        dyn_loss=0.0,
-                        norm_loss=1.0,
-                        boundary_loss=1.0,
-                        observations=1.0,
-                        initial_condition=1.0,
-                    ),
-                    dynamic_loss=None,
-                    key=self.key_dict[i],
-                    derivative_keys=self.derivative_keys_dict[i],
-                    omega_boundary_fun=self.omega_boundary_fun_dict[i],
-                    omega_boundary_condition=self.omega_boundary_condition_dict[i],
-                    omega_boundary_dim=self.omega_boundary_dim_dict[i],
-                    initial_condition_fun=self.initial_condition_fun_dict[i],
-                    norm_samples=self.norm_samples_dict[i],
-                    norm_int_length=self.norm_int_length_dict[i],
-                    obs_slice=self.obs_slice_dict[i],
-                )
-            else:
-                raise ValueError(
-                    "Wrong value for self.u_dict[i].eq_type[i], "
-                    f"got {self.u_dict[i].eq_type[i]}"
-                )
-
-        # derivative keys for the dynamic loss. Note that we create a
-        # DerivativeKeysODE around a ParamsDict object because a whole
-        # params_dict is feed to DynamicLoss.evaluate functions (extract_params
-        # happen inside it)
-        self.derivative_keys_dyn_loss = DerivativeKeysPDENonStatio(params=params_dict)
-
-        # also make sure we only have PINNs or SPINNs
-        if not (
-            all(isinstance(value, PINN) for value in self.u_dict.values())
-            or all(isinstance(value, SPINN) for value in self.u_dict.values())
-        ):
-            raise ValueError(
-                "We only accept dictionary of PINNs or dictionary of SPINNs"
-            )
-
-    def set_loss_weights(
-        self, loss_weights_init: LossWeightsPDEDict
-    ) -> dict[str, dict]:
-        """
-        This rather complex function enables the user to specify a simple
-        loss_weights=LossWeightsPDEDict(dyn_loss=1., initial_condition=Tmax)
-        for ponderating values being applied to all the equations of the
-        system... So all the transformations are handled here
-        """
-        _loss_weights = {}
-        for k in fields(loss_weights_init):
-            v = getattr(loss_weights_init, k.name)
-            if isinstance(v, dict):
-                for vv in v.keys():
-                    if not isinstance(vv, (int, float)) and not (
-                        isinstance(vv, Array)
-                        and ((vv.shape == (1,) or len(vv.shape) == 0))
-                    ):
-                        # TODO improve that
-                        raise ValueError(
-                            f"loss values cannot be vectorial here, got {vv}"
-                        )
-                if k.name == "dyn_loss":
-                    if v.keys() == self.dynamic_loss_dict.keys():
-                        _loss_weights[k.name] = v
-                    else:
-                        raise ValueError(
-                            "Keys in nested dictionary of loss_weights"
-                            " do not match dynamic_loss_dict keys"
-                        )
-                else:
-                    if v.keys() == self.u_dict.keys():
-                        _loss_weights[k.name] = v
-                    else:
-                        raise ValueError(
-                            "Keys in nested dictionary of loss_weights"
-                            " do not match u_dict keys"
-                        )
-            if v is None:
-                _loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
-            else:
-                if not isinstance(v, (int, float)) and not (
-                    isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
-                ):
-                    # TODO improve that
-                    raise ValueError(f"loss values cannot be vectorial here, got {v}")
-                if k.name == "dyn_loss":
-                    _loss_weights[k.name] = {
-                        kk: v for kk in self.dynamic_loss_dict.keys()
-                    }
-                else:
-                    _loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
-        return _loss_weights
-
-    def __call__(self, *args, **kwargs):
-        return self.evaluate(*args, **kwargs)
-
-    def evaluate(
-        self,
-        params_dict: ParamsDict,
-        batch: PDEStatioBatch | PDENonStatioBatch,
-    ) -> tuple[Float[Array, "1"], dict[str, float]]:
-        """
-        Evaluate the loss function at a batch of points for given parameters.
-
-
-        Parameters
-        ---------
-        params_dict
-            Parameters at which the losses of the system are evaluated
-        batch
-            Such named tuples are composed of batch of points in the
-            domain, a batch of points in the domain
-            border, (a batch of time points a for PDENonStatioBatch) and an
-            optional additional batch of parameters (eg. for metamodeling)
-            and an optional additional batch of observed
-            inputs/outputs/parameters
-        """
-        if self.u_dict.keys() != params_dict.nn_params.keys():
-            raise ValueError("u_dict and params_dict[nn_params] should have same keys ")
-
-        vmap_in_axes = (0,)
-
-        # Retrieve the optional eq_params_batch
-        # and update eq_params with the latter
-        # and update vmap_in_axes
-        if batch.param_batch_dict is not None:
-            eq_params_batch_dict = batch.param_batch_dict
-
-            # feed the eq_params with the batch
-            for k in eq_params_batch_dict.keys():
-                params_dict.eq_params[k] = eq_params_batch_dict[k]
-
-        vmap_in_axes_params = _get_vmap_in_axes_params(
-            batch.param_batch_dict, params_dict
-        )
-
-        def dyn_loss_for_one_key(dyn_loss, loss_weight):
-            """The function used in tree_map"""
-            return dynamic_loss_apply(
-                dyn_loss.evaluate,
-                self.u_dict,
-                (
-                    batch.domain_batch
-                    if isinstance(batch, PDEStatioBatch)
-                    else batch.domain_batch
-                ),
-                _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
-                vmap_in_axes + vmap_in_axes_params,
-                loss_weight,
-                u_type=type(list(self.u_dict.values())[0]),
-            )
-
-        dyn_loss_mse_dict = jax.tree_util.tree_map(
-            dyn_loss_for_one_key,
-            self.dynamic_loss_dict,
-            self._loss_weights["dyn_loss"],
-            is_leaf=lambda x: isinstance(
-                x, (PDEStatio, PDENonStatio)
-            ),  # before when dynamic losses
-            # where plain (unregister pytree) node classes, we could not traverse
-            # this level. Now that dynamic losses are eqx.Module they can be
-            # traversed by tree map recursion. Hence we need to specify to that
-            # we want to stop at this level
-        )
-        mse_dyn_loss = jax.tree_util.tree_reduce(
-            lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
-        )
-
-        # boundary conditions, normalization conditions, observation_loss,
-        # initial condition... loss this is done via the internal
-        # LossPDEStatio and NonStatio
-        loss_weight_struct = {
-            "dyn_loss": "*",
-            "norm_loss": "*",
-            "boundary_loss": "*",
-            "observations": "*",
-            "initial_condition": "*",
-        }
-        # we need to do the following for the tree_mapping to work
-        if batch.obs_batch_dict is None:
-            batch = append_obs_batch(batch, self.u_dict_with_none)
-        total_loss, res_dict = constraints_system_loss_apply(
-            self.u_constraints_dict,
-            batch,
-            params_dict,
-            self._loss_weights,
-            loss_weight_struct,
-        )
-
-        # Add the mse_dyn_loss from the previous computations
-        total_loss += mse_dyn_loss
-        res_dict["dyn_loss"] += mse_dyn_loss
-        return total_loss, res_dict