jinns 1.5.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. jinns/__init__.py +7 -7
  2. jinns/data/_AbstractDataGenerator.py +1 -1
  3. jinns/data/_Batchs.py +47 -13
  4. jinns/data/_CubicMeshPDENonStatio.py +203 -54
  5. jinns/data/_CubicMeshPDEStatio.py +190 -54
  6. jinns/data/_DataGeneratorODE.py +48 -22
  7. jinns/data/_DataGeneratorObservations.py +75 -32
  8. jinns/data/_DataGeneratorParameter.py +152 -101
  9. jinns/data/__init__.py +2 -1
  10. jinns/data/_utils.py +22 -10
  11. jinns/loss/_DynamicLoss.py +21 -20
  12. jinns/loss/_DynamicLossAbstract.py +51 -36
  13. jinns/loss/_LossODE.py +210 -191
  14. jinns/loss/_LossPDE.py +441 -368
  15. jinns/loss/_abstract_loss.py +60 -25
  16. jinns/loss/_loss_components.py +4 -25
  17. jinns/loss/_loss_utils.py +23 -0
  18. jinns/loss/_loss_weight_updates.py +6 -7
  19. jinns/loss/_loss_weights.py +34 -35
  20. jinns/nn/_abstract_pinn.py +0 -2
  21. jinns/nn/_hyperpinn.py +34 -23
  22. jinns/nn/_mlp.py +5 -4
  23. jinns/nn/_pinn.py +1 -16
  24. jinns/nn/_ppinn.py +5 -16
  25. jinns/nn/_save_load.py +11 -4
  26. jinns/nn/_spinn.py +1 -16
  27. jinns/nn/_spinn_mlp.py +5 -5
  28. jinns/nn/_utils.py +33 -38
  29. jinns/parameters/__init__.py +3 -1
  30. jinns/parameters/_derivative_keys.py +99 -41
  31. jinns/parameters/_params.py +58 -25
  32. jinns/solver/_solve.py +14 -8
  33. jinns/utils/_DictToModuleMeta.py +66 -0
  34. jinns/utils/_ItemizableModule.py +19 -0
  35. jinns/utils/__init__.py +2 -1
  36. jinns/utils/_types.py +25 -15
  37. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
  38. jinns-1.6.0.dist-info/RECORD +57 -0
  39. jinns-1.5.0.dist-info/RECORD +0 -55
  40. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
  41. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
  42. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
  43. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py CHANGED
@@ -8,23 +8,24 @@ from __future__ import (
8
8
 
9
9
  import abc
10
10
  from dataclasses import InitVar
11
- from typing import TYPE_CHECKING, Callable, TypedDict
11
+ from typing import TYPE_CHECKING, Callable, cast, Any, TypeVar, Generic
12
12
  from types import EllipsisType
13
13
  import warnings
14
14
  import jax
15
15
  import jax.numpy as jnp
16
16
  import equinox as eqx
17
- from jaxtyping import Float, Array, Key, Int
17
+ from jaxtyping import PRNGKeyArray, Float, Array
18
18
  from jinns.loss._loss_utils import (
19
19
  dynamic_loss_apply,
20
20
  boundary_condition_apply,
21
21
  normalization_loss_apply,
22
22
  observations_loss_apply,
23
23
  initial_condition_apply,
24
+ initial_condition_check,
24
25
  )
25
26
  from jinns.parameters._params import (
26
27
  _get_vmap_in_axes_params,
27
- _update_eq_params_dict,
28
+ update_eq_params,
28
29
  )
29
30
  from jinns.parameters._derivative_keys import (
30
31
  _set_derivatives,
@@ -39,24 +40,14 @@ from jinns.loss._loss_weights import (
39
40
  )
40
41
  from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
41
42
  from jinns.parameters._params import Params
43
+ from jinns.loss import PDENonStatio, PDEStatio
42
44
 
43
45
 
44
46
  if TYPE_CHECKING:
45
47
  # imports for type hints only
46
48
  from jinns.nn._abstract_pinn import AbstractPINN
47
- from jinns.loss import PDENonStatio, PDEStatio
48
49
  from jinns.utils._types import BoundaryConditionFun
49
50
 
50
- class LossDictPDEStatio(TypedDict):
51
- dyn_loss: Float[Array, " "]
52
- norm_loss: Float[Array, " "]
53
- boundary_loss: Float[Array, " "]
54
- observations: Float[Array, " "]
55
-
56
- class LossDictPDENonStatio(LossDictPDEStatio):
57
- initial_condition: Float[Array, " "]
58
-
59
-
60
51
  _IMPLEMENTED_BOUNDARY_CONDITIONS = [
61
52
  "dirichlet",
62
53
  "von neumann",
@@ -64,7 +55,24 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
64
55
  ]
65
56
 
66
57
 
67
- class _LossPDEAbstract(AbstractLoss):
58
+ # For the same reason that we have the TypeVar in _abstract_loss.py, we have them
59
+ # here, because _LossPDEAbstract is abtract and we cannot decide for several
60
+ # types between their statio and non-statio version.
61
+ # Assigning the type where it can be decide seems a better practice than
62
+ # assigning a type at a higher level depending on a child class type. This is
63
+ # why we now assign LossWeights and DerivativeKeys in the child class where
64
+ # they really can be decided.
65
+
66
+ L = TypeVar("L", bound=LossWeightsPDEStatio | LossWeightsPDENonStatio)
67
+ B = TypeVar("B", bound=PDEStatioBatch | PDENonStatioBatch)
68
+ C = TypeVar(
69
+ "C", bound=PDEStatioComponents[Array | None] | PDENonStatioComponents[Array | None]
70
+ )
71
+ D = TypeVar("D", bound=DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio)
72
+ Y = TypeVar("Y", bound=PDEStatio | PDENonStatio | None)
73
+
74
+
75
+ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
68
76
  r"""
69
77
  Parameters
70
78
  ----------
@@ -77,17 +85,11 @@ class _LossPDEAbstract(AbstractLoss):
77
85
  `update_weight_method`
78
86
  update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
79
87
  Default is None meaning no update for loss weights. Otherwise a string
80
- derivative_keys : DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio, default=None
81
- Specify which field of `params` should be differentiated for each
82
- composant of the total loss. Particularily useful for inverse problems.
83
- Fields can be "nn_params", "eq_params" or "both". Those that should not
84
- be updated will have a `jax.lax.stop_gradient` called on them. Default
85
- is `"nn_params"` for each composant of the loss.
86
88
  omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
87
89
  The function to be matched in the border condition (can be None) or a
88
90
  dictionary of such functions as values and keys as described
89
91
  in `omega_boundary_condition`.
90
- omega_boundary_condition : str | dict[str, str], default=None
92
+ omega_boundary_condition : str | dict[str, str | None], default=None
91
93
  Either None (no condition, by default), or a string defining
92
94
  the boundary condition (Dirichlet or Von Neumann),
93
95
  or a dictionary with such strings as values. In this case,
@@ -98,7 +100,7 @@ class _LossPDEAbstract(AbstractLoss):
98
100
  a particular boundary condition on this facet.
99
101
  The facet called “xmin”, resp. “xmax” etc., in 2D,
100
102
  refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
101
- omega_boundary_dim : slice | dict[str, slice], default=None
103
+ omega_boundary_dim : int | slice | dict[str, slice], default=None
102
104
  Either None, or a slice object or a dictionary of slice objects as
103
105
  values and keys as described in `omega_boundary_condition`.
104
106
  `omega_boundary_dim` indicates which dimension(s) of the PINN
@@ -120,197 +122,286 @@ class _LossPDEAbstract(AbstractLoss):
120
122
  obs_slice : EllipsisType | slice, default=None
121
123
  slice object specifying the begininning/ending of the PINN output
122
124
  that is observed (this is then useful for multidim PINN). Default is None.
123
- params : InitVar[Params[Array]], default=None
124
- The main Params object of the problem needed to instanciate the
125
- DerivativeKeysODE if the latter is not specified.
125
+ key : Key | None
126
+ A JAX PRNG Key for the loss class treated as an attribute. Default is
127
+ None. This field is provided for future developments and additional
128
+ losses that might need some randomness. Note that special care must be
129
+ taken when splitting the key because in-place updates are forbidden in
130
+ eqx.Modules.
126
131
  """
127
132
 
128
133
  # NOTE static=True only for leaf attributes that are not valid JAX types
129
134
  # (ie. jax.Array cannot be static) and that we do not expect to change
130
- # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
131
- derivative_keys: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None = (
132
- eqx.field(kw_only=True, default=None)
133
- )
134
- loss_weights: LossWeightsPDEStatio | LossWeightsPDENonStatio | None = eqx.field(
135
- kw_only=True, default=None
136
- )
135
+ u: eqx.AbstractVar[AbstractPINN]
136
+ dynamic_loss: eqx.AbstractVar[Y]
137
137
  omega_boundary_fun: (
138
138
  BoundaryConditionFun | dict[str, BoundaryConditionFun] | None
139
- ) = eqx.field(kw_only=True, default=None, static=True)
140
- omega_boundary_condition: str | dict[str, str] | None = eqx.field(
141
- kw_only=True, default=None, static=True
139
+ ) = eqx.field(static=True)
140
+ omega_boundary_condition: str | dict[str, str | None] | None = eqx.field(
141
+ static=True
142
142
  )
143
- omega_boundary_dim: slice | dict[str, slice] | None = eqx.field(
144
- kw_only=True, default=None, static=True
145
- )
146
- norm_samples: Float[Array, " nb_norm_samples dimension"] | None = eqx.field(
147
- kw_only=True, default=None
148
- )
149
- norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = eqx.field(
150
- kw_only=True, default=None
151
- )
152
- obs_slice: EllipsisType | slice | None = eqx.field(
153
- kw_only=True, default=None, static=True
154
- )
155
-
156
- params: InitVar[Params[Array]] = eqx.field(kw_only=True, default=None)
157
-
158
- def __post_init__(self, params: Params[Array] | None = None):
159
- """
160
- Note that neither __init__ or __post_init__ are called when udating a
161
- Module with eqx.tree_at
162
- """
163
- if self.derivative_keys is None:
164
- # be default we only take gradient wrt nn_params
165
- try:
166
- self.derivative_keys = (
167
- DerivativeKeysPDENonStatio(params=params)
168
- if isinstance(self, LossPDENonStatio)
169
- else DerivativeKeysPDEStatio(params=params)
170
- )
171
- except ValueError as exc:
172
- raise ValueError(
173
- "Problem at self.derivative_keys initialization "
174
- f"received {self.derivative_keys=} and {params=}"
175
- ) from exc
176
-
177
- if self.loss_weights is None:
178
- self.loss_weights = (
179
- LossWeightsPDENonStatio()
180
- if isinstance(self, LossPDENonStatio)
181
- else LossWeightsPDEStatio()
182
- )
183
-
184
- if self.obs_slice is None:
143
+ omega_boundary_dim: slice | dict[str, slice] = eqx.field(static=True)
144
+ norm_samples: Float[Array, " nb_norm_samples dimension"] | None
145
+ norm_weights: Float[Array, " nb_norm_samples"] | None
146
+ obs_slice: EllipsisType | slice = eqx.field(static=True)
147
+ key: PRNGKeyArray | None
148
+
149
+ def __init__(
150
+ self,
151
+ *,
152
+ omega_boundary_fun: BoundaryConditionFun
153
+ | dict[str, BoundaryConditionFun]
154
+ | None = None,
155
+ omega_boundary_condition: str | dict[str, str | None] | None = None,
156
+ omega_boundary_dim: int | slice | dict[str, int | slice] | None = None,
157
+ norm_samples: Float[Array, " nb_norm_samples dimension"] | None = None,
158
+ norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = None,
159
+ obs_slice: EllipsisType | slice | None = None,
160
+ key: PRNGKeyArray | None = None,
161
+ **kwargs: Any, # for arguments for super()
162
+ ):
163
+ super().__init__(loss_weights=self.loss_weights, **kwargs)
164
+
165
+ if obs_slice is None:
185
166
  self.obs_slice = jnp.s_[...]
167
+ else:
168
+ self.obs_slice = obs_slice
186
169
 
187
170
  if (
188
- isinstance(self.omega_boundary_fun, dict)
189
- and not isinstance(self.omega_boundary_condition, dict)
171
+ isinstance(omega_boundary_fun, dict)
172
+ and not isinstance(omega_boundary_condition, dict)
190
173
  ) or (
191
- not isinstance(self.omega_boundary_fun, dict)
192
- and isinstance(self.omega_boundary_condition, dict)
174
+ not isinstance(omega_boundary_fun, dict)
175
+ and isinstance(omega_boundary_condition, dict)
193
176
  ):
194
177
  raise ValueError(
195
- "if one of self.omega_boundary_fun or "
196
- "self.omega_boundary_condition is dict, the other should be too."
178
+ "if one of omega_boundary_fun or "
179
+ "omega_boundary_condition is dict, the other should be too."
197
180
  )
198
181
 
199
- if self.omega_boundary_condition is None or self.omega_boundary_fun is None:
182
+ if omega_boundary_condition is None or omega_boundary_fun is None:
200
183
  warnings.warn(
201
184
  "Missing boundary function or no boundary condition."
202
185
  "Boundary function is thus ignored."
203
186
  )
204
187
  else:
205
- if isinstance(self.omega_boundary_condition, dict):
206
- for _, v in self.omega_boundary_condition.items():
188
+ if isinstance(omega_boundary_condition, dict):
189
+ for _, v in omega_boundary_condition.items():
207
190
  if v is not None and not any(
208
191
  v.lower() in s for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
209
192
  ):
210
193
  raise NotImplementedError(
211
- f"The boundary condition {self.omega_boundary_condition} is not"
194
+ f"The boundary condition {omega_boundary_condition} is not"
212
195
  f"implemented yet. Try one of :"
213
196
  f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
214
197
  )
215
198
  else:
216
199
  if not any(
217
- self.omega_boundary_condition.lower() in s
200
+ omega_boundary_condition.lower() in s
218
201
  for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
219
202
  ):
220
203
  raise NotImplementedError(
221
- f"The boundary condition {self.omega_boundary_condition} is not"
204
+ f"The boundary condition {omega_boundary_condition} is not"
222
205
  f"implemented yet. Try one of :"
223
206
  f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
224
207
  )
225
- if isinstance(self.omega_boundary_fun, dict) and isinstance(
226
- self.omega_boundary_condition, dict
208
+ if isinstance(omega_boundary_fun, dict) and isinstance(
209
+ omega_boundary_condition, dict
210
+ ):
211
+ keys_omega_boundary_fun = cast(str, omega_boundary_fun.keys())
212
+ if (
213
+ not (
214
+ list(keys_omega_boundary_fun) == ["xmin", "xmax"]
215
+ and list(omega_boundary_condition.keys()) == ["xmin", "xmax"]
216
+ )
217
+ ) and (
218
+ not (
219
+ list(keys_omega_boundary_fun)
220
+ == ["xmin", "xmax", "ymin", "ymax"]
221
+ and list(omega_boundary_condition.keys())
222
+ == ["xmin", "xmax", "ymin", "ymax"]
223
+ )
227
224
  ):
228
- if (
229
- not (
230
- list(self.omega_boundary_fun.keys()) == ["xmin", "xmax"]
231
- and list(self.omega_boundary_condition.keys())
232
- == ["xmin", "xmax"]
233
- )
234
- ) or (
235
- not (
236
- list(self.omega_boundary_fun.keys())
237
- == ["xmin", "xmax", "ymin", "ymax"]
238
- and list(self.omega_boundary_condition.keys())
239
- == ["xmin", "xmax", "ymin", "ymax"]
240
- )
241
- ):
242
- raise ValueError(
243
- "The key order (facet order) in the "
244
- "boundary condition dictionaries is incorrect"
245
- )
225
+ raise ValueError(
226
+ "The key order (facet order) in the "
227
+ "boundary condition dictionaries is incorrect"
228
+ )
229
+
230
+ self.omega_boundary_fun = omega_boundary_fun
231
+ self.omega_boundary_condition = omega_boundary_condition
246
232
 
247
- if isinstance(self.omega_boundary_fun, dict):
248
- if not isinstance(self.omega_boundary_dim, dict):
233
+ if isinstance(omega_boundary_fun, dict):
234
+ keys_omega_boundary_fun: str = cast(str, omega_boundary_fun.keys())
235
+ if omega_boundary_dim is None:
236
+ self.omega_boundary_dim = {
237
+ k: jnp.s_[::] for k in keys_omega_boundary_fun
238
+ }
239
+ if not isinstance(omega_boundary_dim, dict):
249
240
  raise ValueError(
250
241
  "If omega_boundary_fun is a dict then"
251
242
  " omega_boundary_dim should also be a dict"
252
243
  )
253
- if self.omega_boundary_dim is None:
254
- self.omega_boundary_dim = {
255
- k: jnp.s_[::] for k in self.omega_boundary_fun.keys()
256
- }
257
- if list(self.omega_boundary_dim.keys()) != list(
258
- self.omega_boundary_fun.keys()
259
- ):
244
+ if list(omega_boundary_dim.keys()) != list(keys_omega_boundary_fun):
260
245
  raise ValueError(
261
246
  "If omega_boundary_fun is a dict,"
262
247
  " omega_boundary_dim should be a dict with the same keys"
263
248
  )
264
- for k, v in self.omega_boundary_dim.items():
249
+ self.omega_boundary_dim = {}
250
+ for k, v in omega_boundary_dim.items():
265
251
  if isinstance(v, int):
266
252
  # rewrite it as a slice to ensure that axis does not disappear when
267
253
  # indexing
268
254
  self.omega_boundary_dim[k] = jnp.s_[v : v + 1]
255
+ else:
256
+ self.omega_boundary_dim[k] = v
269
257
 
270
258
  else:
271
- if self.omega_boundary_dim is None:
259
+ assert not isinstance(omega_boundary_dim, dict)
260
+ if omega_boundary_dim is None:
272
261
  self.omega_boundary_dim = jnp.s_[::]
273
- if isinstance(self.omega_boundary_dim, int):
262
+ elif isinstance(omega_boundary_dim, int):
274
263
  # rewrite it as a slice to ensure that axis does not disappear when
275
264
  # indexing
276
265
  self.omega_boundary_dim = jnp.s_[
277
- self.omega_boundary_dim : self.omega_boundary_dim + 1
266
+ omega_boundary_dim : omega_boundary_dim + 1
278
267
  ]
279
- if not isinstance(self.omega_boundary_dim, slice):
280
- raise ValueError("self.omega_boundary_dim must be a jnp.s_ object")
268
+ else:
269
+ assert isinstance(omega_boundary_dim, slice)
270
+ self.omega_boundary_dim = omega_boundary_dim
281
271
 
282
- if self.norm_samples is not None:
283
- if self.norm_weights is None:
272
+ if norm_samples is not None:
273
+ self.norm_samples = norm_samples
274
+ if norm_weights is None:
284
275
  raise ValueError(
285
276
  "`norm_weights` must be provided when `norm_samples` is used!"
286
277
  )
287
- if isinstance(self.norm_weights, (int, float)):
288
- self.norm_weights = self.norm_weights * jnp.ones(
278
+ if isinstance(norm_weights, (int, float)):
279
+ self.norm_weights = norm_weights * jnp.ones(
289
280
  (self.norm_samples.shape[0],)
290
281
  )
291
- if isinstance(self.norm_weights, Array):
292
- if not (self.norm_weights.shape[0] == self.norm_samples.shape[0]):
282
+ else:
283
+ assert isinstance(norm_weights, Array)
284
+ if not (norm_weights.shape[0] == norm_samples.shape[0]):
293
285
  raise ValueError(
294
- "self.norm_weights and "
295
- "self.norm_samples must have the same leading dimension"
286
+ "norm_weights and "
287
+ "norm_samples must have the same leading dimension"
296
288
  )
297
- else:
298
- raise ValueError("Wrong type for self.norm_weights")
289
+ self.norm_weights = norm_weights
290
+ else:
291
+ self.norm_samples = norm_samples
292
+ self.norm_weights = None
293
+
294
+ self.key = key
299
295
 
300
296
  @abc.abstractmethod
301
- def __call__(self, *_, **__):
297
+ def _get_dynamic_loss_batch(self, batch: B) -> Array:
302
298
  pass
303
299
 
304
300
  @abc.abstractmethod
305
- def evaluate(
306
- self: eqx.Module,
301
+ def _get_normalization_loss_batch(self, batch: B) -> tuple[Array | None, ...]:
302
+ pass
303
+
304
+ def _get_dyn_loss_fun(
305
+ self, batch: B, vmap_in_axes_params: tuple[Params[int | None] | None]
306
+ ) -> Callable[[Params[Array]], Array] | None:
307
+ if self.dynamic_loss is not None:
308
+ dyn_loss_eval = self.dynamic_loss.evaluate
309
+ dyn_loss_fun: Callable[[Params[Array]], Array] | None = (
310
+ lambda p: dynamic_loss_apply(
311
+ dyn_loss_eval,
312
+ self.u,
313
+ self._get_dynamic_loss_batch(batch),
314
+ _set_derivatives(p, self.derivative_keys.dyn_loss),
315
+ self.vmap_in_axes + vmap_in_axes_params,
316
+ )
317
+ )
318
+ else:
319
+ dyn_loss_fun = None
320
+
321
+ return dyn_loss_fun
322
+
323
+ def _get_norm_loss_fun(
324
+ self, batch: B, vmap_in_axes_params: tuple[Params[int | None] | None]
325
+ ) -> Callable[[Params[Array]], Array] | None:
326
+ if self.norm_samples is not None:
327
+ norm_loss_fun: Callable[[Params[Array]], Array] | None = (
328
+ lambda p: normalization_loss_apply(
329
+ self.u,
330
+ cast(
331
+ tuple[Array, Array], self._get_normalization_loss_batch(batch)
332
+ ),
333
+ _set_derivatives(p, self.derivative_keys.norm_loss),
334
+ vmap_in_axes_params,
335
+ self.norm_weights, # type: ignore -> can't get the __post_init__ narrowing here
336
+ )
337
+ )
338
+ else:
339
+ norm_loss_fun = None
340
+ return norm_loss_fun
341
+
342
+ def _get_boundary_loss_fun(
343
+ self, batch: B
344
+ ) -> Callable[[Params[Array]], Array] | None:
345
+ if (
346
+ self.omega_boundary_condition is not None
347
+ and self.omega_boundary_fun is not None
348
+ ):
349
+ boundary_loss_fun: Callable[[Params[Array]], Array] | None = (
350
+ lambda p: boundary_condition_apply(
351
+ self.u,
352
+ batch,
353
+ _set_derivatives(p, self.derivative_keys.boundary_loss),
354
+ self.omega_boundary_fun, # type: ignore (we are in lambda)
355
+ self.omega_boundary_condition, # type: ignore
356
+ self.omega_boundary_dim, # type: ignore
357
+ )
358
+ )
359
+ else:
360
+ boundary_loss_fun = None
361
+
362
+ return boundary_loss_fun
363
+
364
+ def _get_obs_params_and_obs_loss_fun(
365
+ self,
366
+ batch: B,
367
+ vmap_in_axes_params: tuple[Params[int | None] | None],
307
368
  params: Params[Array],
308
- batch: PDEStatioBatch | PDENonStatioBatch,
309
- ) -> tuple[Float[Array, " "], LossDictPDEStatio | LossDictPDENonStatio]:
310
- raise NotImplementedError
369
+ ) -> tuple[Params[Array] | None, Callable[[Params[Array]], Array] | None]:
370
+ if batch.obs_batch_dict is not None:
371
+ # update params with the batches of observed params
372
+ params_obs = update_eq_params(params, batch.obs_batch_dict["eq_params"])
373
+
374
+ pinn_in, val = (
375
+ batch.obs_batch_dict["pinn_in"],
376
+ batch.obs_batch_dict["val"],
377
+ )
378
+
379
+ obs_loss_fun: Callable[[Params[Array]], Array] | None = (
380
+ lambda po: observations_loss_apply(
381
+ self.u,
382
+ pinn_in,
383
+ _set_derivatives(po, self.derivative_keys.observations),
384
+ self.vmap_in_axes + vmap_in_axes_params,
385
+ val,
386
+ self.obs_slice,
387
+ )
388
+ )
389
+ else:
390
+ params_obs = None
391
+ obs_loss_fun = None
392
+
393
+ return params_obs, obs_loss_fun
311
394
 
312
395
 
313
- class LossPDEStatio(_LossPDEAbstract):
396
+ class LossPDEStatio(
397
+ _LossPDEAbstract[
398
+ LossWeightsPDEStatio,
399
+ PDEStatioBatch,
400
+ PDEStatioComponents[Array | None],
401
+ DerivativeKeysPDEStatio,
402
+ PDEStatio | None,
403
+ ]
404
+ ):
314
405
  r"""Loss object for a stationary partial differential equation
315
406
 
316
407
  $$
@@ -325,13 +416,13 @@ class LossPDEStatio(_LossPDEAbstract):
325
416
  ----------
326
417
  u : AbstractPINN
327
418
  the PINN
328
- dynamic_loss : PDEStatio
419
+ dynamic_loss : PDEStatio | None
329
420
  the stationary PDE dynamic part of the loss, basically the differential
330
421
  operator $\mathcal{N}[u](x)$. Should implement a method
331
422
  `dynamic_loss.evaluate(x, u, params)`.
332
423
  Can be None in order to access only some part of the evaluate call
333
424
  results.
334
- key : Key
425
+ key : PRNGKeyArray
335
426
  A JAX PRNG Key for the loss class treated as an attribute. Default is
336
427
  None. This field is provided for future developments and additional
337
428
  losses that might need some randomness. Note that special care must be
@@ -343,14 +434,18 @@ class LossPDEStatio(_LossPDEAbstract):
343
434
  observations if any.
344
435
  Can be updated according to a specific algorithm. See
345
436
  `update_weight_method`
346
- update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
347
- Default is None meaning no update for loss weights. Otherwise a string
348
437
  derivative_keys : DerivativeKeysPDEStatio, default=None
349
438
  Specify which field of `params` should be differentiated for each
350
439
  composant of the total loss. Particularily useful for inverse problems.
351
440
  Fields can be "nn_params", "eq_params" or "both". Those that should not
352
441
  be updated will have a `jax.lax.stop_gradient` called on them. Default
353
442
  is `"nn_params"` for each composant of the loss.
443
+ params : InitVar[Params[Array]], default=None
444
+ The main Params object of the problem needed to instanciate the
445
+ DerivativeKeysODE if the latter is not specified.
446
+
447
+ update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
448
+ Default is None meaning no update for loss weights. Otherwise a string
354
449
  omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
355
450
  The function to be matched in the border condition (can be None) or a
356
451
  dictionary of such functions as values and keys as described
@@ -366,7 +461,7 @@ class LossPDEStatio(_LossPDEAbstract):
366
461
  a particular boundary condition on this facet.
367
462
  The facet called “xmin”, resp. “xmax” etc., in 2D,
368
463
  refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
369
- omega_boundary_dim : slice | dict[str, slice], default=None
464
+ omega_boundary_dim : int | slice | dict[str, slice], default=None
370
465
  Either None, or a slice object or a dictionary of slice objects as
371
466
  values and keys as described in `omega_boundary_condition`.
372
467
  `omega_boundary_dim` indicates which dimension(s) of the PINN
@@ -387,10 +482,6 @@ class LossPDEStatio(_LossPDEAbstract):
387
482
  obs_slice : slice, default=None
388
483
  slice object specifying the begininning/ending of the PINN output
389
484
  that is observed (this is then useful for multidim PINN). Default is None.
390
- params : InitVar[Params[Array]], default=None
391
- The main Params object of the problem needed to instanciate the
392
- DerivativeKeysODE if the latter is not specified.
393
-
394
485
 
395
486
  Raises
396
487
  ------
@@ -404,21 +495,46 @@ class LossPDEStatio(_LossPDEAbstract):
404
495
 
405
496
  u: AbstractPINN
406
497
  dynamic_loss: PDEStatio | None
407
- key: Key | None = eqx.field(kw_only=True, default=None)
498
+ loss_weights: LossWeightsPDEStatio
499
+ derivative_keys: DerivativeKeysPDEStatio
500
+ vmap_in_axes: tuple[int] = eqx.field(static=True)
501
+
502
+ params: InitVar[Params[Array] | None]
503
+
504
+ def __init__(
505
+ self,
506
+ *,
507
+ u: AbstractPINN,
508
+ dynamic_loss: PDEStatio | None,
509
+ loss_weights: LossWeightsPDEStatio | None = None,
510
+ derivative_keys: DerivativeKeysPDEStatio | None = None,
511
+ params: Params[Array] | None = None,
512
+ **kwargs: Any,
513
+ ):
514
+ self.u = u
515
+ if loss_weights is None:
516
+ self.loss_weights = LossWeightsPDEStatio()
517
+ else:
518
+ self.loss_weights = loss_weights
519
+ self.dynamic_loss = dynamic_loss
408
520
 
409
- vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
521
+ super().__init__(
522
+ **kwargs,
523
+ )
410
524
 
411
- def __post_init__(self, params: Params[Array] | None = None):
412
- """
413
- Note that neither __init__ or __post_init__ are called when udating a
414
- Module with eqx.tree_at!
415
- """
416
- super().__post_init__(
417
- params=params
418
- ) # because __init__ or __post_init__ of Base
419
- # class is not automatically called
525
+ if derivative_keys is None:
526
+ # be default we only take gradient wrt nn_params
527
+ try:
528
+ self.derivative_keys = DerivativeKeysPDEStatio(params=params)
529
+ except ValueError as exc:
530
+ raise ValueError(
531
+ "Problem at derivative_keys initialization "
532
+ f"received {derivative_keys=} and {params=}"
533
+ ) from exc
534
+ else:
535
+ self.derivative_keys = derivative_keys
420
536
 
421
- self.vmap_in_axes = (0,) # for x only here
537
+ self.vmap_in_axes = (0,)
422
538
 
423
539
  def _get_dynamic_loss_batch(
424
540
  self, batch: PDEStatioBatch
@@ -426,23 +542,18 @@ class LossPDEStatio(_LossPDEAbstract):
426
542
  return batch.domain_batch
427
543
 
428
544
  def _get_normalization_loss_batch(
429
- self, _
430
- ) -> tuple[Float[Array, " nb_norm_samples dimension"]]:
431
- return (self.norm_samples,) # type: ignore -> cannot narrow a class attr
432
-
433
- # we could have used typing.cast though
434
-
435
- def _get_observations_loss_batch(
436
545
  self, batch: PDEStatioBatch
437
- ) -> Float[Array, " batch_size obs_dim"]:
438
- return batch.obs_batch_dict["pinn_in"]
546
+ ) -> tuple[Float[Array, " nb_norm_samples dimension"] | None,]:
547
+ return (self.norm_samples,)
439
548
 
440
- def __call__(self, *args, **kwargs):
441
- return self.evaluate(*args, **kwargs)
549
+ # we could have used typing.cast though
442
550
 
443
551
  def evaluate_by_terms(
444
552
  self, params: Params[Array], batch: PDEStatioBatch
445
- ) -> tuple[PDEStatioComponents[Array | None], PDEStatioComponents[Array | None]]:
553
+ ) -> tuple[
554
+ PDEStatioComponents[Float[Array, ""] | None],
555
+ PDEStatioComponents[Float[Array, ""] | None],
556
+ ]:
446
557
  """
447
558
  Evaluate the loss function at a batch of points for given parameters.
448
559
 
@@ -464,69 +575,23 @@ class LossPDEStatio(_LossPDEAbstract):
464
575
  # and update vmap_in_axes
465
576
  if batch.param_batch_dict is not None:
466
577
  # update eq_params with the batches of generated params
467
- params = _update_eq_params_dict(params, batch.param_batch_dict)
578
+ params = update_eq_params(params, batch.param_batch_dict)
468
579
 
469
580
  vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
470
581
 
471
582
  # dynamic part
472
- if self.dynamic_loss is not None:
473
- dyn_loss_fun = lambda p: dynamic_loss_apply(
474
- self.dynamic_loss.evaluate, # type: ignore
475
- self.u,
476
- self._get_dynamic_loss_batch(batch),
477
- _set_derivatives(p, self.derivative_keys.dyn_loss), # type: ignore
478
- self.vmap_in_axes + vmap_in_axes_params,
479
- )
480
- else:
481
- dyn_loss_fun = None
583
+ dyn_loss_fun = self._get_dyn_loss_fun(batch, vmap_in_axes_params)
482
584
 
483
585
  # normalization part
484
- if self.norm_samples is not None:
485
- norm_loss_fun = lambda p: normalization_loss_apply(
486
- self.u,
487
- self._get_normalization_loss_batch(batch),
488
- _set_derivatives(p, self.derivative_keys.norm_loss), # type: ignore
489
- vmap_in_axes_params,
490
- self.norm_weights, # type: ignore -> can't get the __post_init__ narrowing here
491
- )
492
- else:
493
- norm_loss_fun = None
586
+ norm_loss_fun = self._get_norm_loss_fun(batch, vmap_in_axes_params)
494
587
 
495
588
  # boundary part
496
- if (
497
- self.omega_boundary_condition is not None
498
- and self.omega_boundary_dim is not None
499
- and self.omega_boundary_fun is not None
500
- ): # pyright cannot narrow down the three None otherwise as it is class attribute
501
- boundary_loss_fun = lambda p: boundary_condition_apply(
502
- self.u,
503
- batch,
504
- _set_derivatives(p, self.derivative_keys.boundary_loss), # type: ignore
505
- self.omega_boundary_fun, # type: ignore
506
- self.omega_boundary_condition, # type: ignore
507
- self.omega_boundary_dim, # type: ignore
508
- )
509
- else:
510
- boundary_loss_fun = None
589
+ boundary_loss_fun = self._get_boundary_loss_fun(batch)
511
590
 
512
591
  # Observation mse
513
- if batch.obs_batch_dict is not None:
514
- # update params with the batches of observed params
515
- params_obs = _update_eq_params_dict(
516
- params, batch.obs_batch_dict["eq_params"]
517
- )
518
-
519
- obs_loss_fun = lambda po: observations_loss_apply(
520
- self.u,
521
- self._get_observations_loss_batch(batch),
522
- _set_derivatives(po, self.derivative_keys.observations), # type: ignore
523
- self.vmap_in_axes + vmap_in_axes_params,
524
- batch.obs_batch_dict["val"],
525
- self.obs_slice,
526
- )
527
- else:
528
- params_obs = None
529
- obs_loss_fun = None
592
+ params_obs, obs_loss_fun = self._get_obs_params_and_obs_loss_fun(
593
+ batch, vmap_in_axes_params, params
594
+ )
530
595
 
531
596
  # get the unweighted mses for each loss term as well as the gradients
532
597
  all_funs: PDEStatioComponents[Callable[[Params[Array]], Array] | None] = (
@@ -538,47 +603,34 @@ class LossPDEStatio(_LossPDEAbstract):
538
603
  params, params, params, params_obs
539
604
  )
540
605
  mses_grads = jax.tree.map(
541
- lambda fun, params: self.get_gradients(fun, params),
606
+ self.get_gradients,
542
607
  all_funs,
543
608
  all_params,
544
609
  is_leaf=lambda x: x is None,
545
610
  )
546
611
  mses = jax.tree.map(
547
- lambda leaf: leaf[0], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
612
+ lambda leaf: leaf[0], # type: ignore
613
+ mses_grads,
614
+ is_leaf=lambda x: isinstance(x, tuple),
548
615
  )
549
616
  grads = jax.tree.map(
550
- lambda leaf: leaf[1], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
617
+ lambda leaf: leaf[1], # type: ignore
618
+ mses_grads,
619
+ is_leaf=lambda x: isinstance(x, tuple),
551
620
  )
552
621
 
553
622
  return mses, grads
554
623
 
555
- def evaluate(
556
- self, params: Params[Array], batch: PDEStatioBatch
557
- ) -> tuple[Float[Array, " "], PDEStatioComponents[Float[Array, " "] | None]]:
558
- """
559
- Evaluate the loss function at a batch of points for given parameters.
560
-
561
- We retrieve the total value itself and a PyTree with loss values for each term
562
-
563
- Parameters
564
- ---------
565
- params
566
- Parameters at which the loss is evaluated
567
- batch
568
- Composed of a batch of points in the
569
- domain, a batch of points in the domain
570
- border and an optional additional batch of parameters (eg. for
571
- metamodeling) and an optional additional batch of observed
572
- inputs/outputs/parameters
573
- """
574
- loss_terms, _ = self.evaluate_by_terms(params, batch)
575
-
576
- loss_val = self.ponderate_and_sum_loss(loss_terms)
577
-
578
- return loss_val, loss_terms
579
-
580
624
 
581
- class LossPDENonStatio(LossPDEStatio):
625
+ class LossPDENonStatio(
626
+ _LossPDEAbstract[
627
+ LossWeightsPDENonStatio,
628
+ PDENonStatioBatch,
629
+ PDENonStatioComponents[Array | None],
630
+ DerivativeKeysPDENonStatio,
631
+ PDENonStatio | None,
632
+ ]
633
+ ):
582
634
  r"""Loss object for a stationary partial differential equation
583
635
 
584
636
  $$
@@ -602,7 +654,7 @@ class LossPDENonStatio(LossPDEStatio):
602
654
  `dynamic_loss.evaluate(t, x, u, params)`.
603
655
  Can be None in order to access only some part of the evaluate call
604
656
  results.
605
- key : Key
657
+ key : PRNGKeyArray
606
658
  A JAX PRNG Key for the loss class treated as an attribute. Default is
607
659
  None. This field is provided for future developments and additional
608
660
  losses that might need some randomness. Note that special care must be
@@ -615,14 +667,31 @@ class LossPDENonStatio(LossPDEStatio):
615
667
  observations if any.
616
668
  Can be updated according to a specific algorithm. See
617
669
  `update_weight_method`
618
- update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
619
- Default is None meaning no update for loss weights. Otherwise a string
620
670
  derivative_keys : DerivativeKeysPDENonStatio, default=None
621
671
  Specify which field of `params` should be differentiated for each
622
672
  composant of the total loss. Particularily useful for inverse problems.
623
673
  Fields can be "nn_params", "eq_params" or "both". Those that should not
624
674
  be updated will have a `jax.lax.stop_gradient` called on them. Default
625
675
  is `"nn_params"` for each composant of the loss.
676
+ initial_condition_fun : Callable, default=None
677
+ A function representing the initial condition at `t0`. If None
678
+ (default) then no initial condition is applied.
679
+ t0 : float | Float[Array, " 1"], default=None
680
+ The time at which to apply the initial condition. If None, the time
681
+ is set to `0` by default.
682
+ max_norm_time_slices : int, default=100
683
+ The maximum number of time points in the Cartesian product with the
684
+ omega points to create the set of collocation points upon which the
685
+ normalization constant is computed.
686
+ max_norm_samples_omega : int, default=1000
687
+ The maximum number of omega points in the Cartesian product with the
688
+ time points to create the set of collocation points upon which the
689
+ normalization constant is computed.
690
+ params : InitVar[Params[Array]], default=None
691
+ The main `Params` object of the problem needed to instanciate the
692
+ `DerivativeKeysODE` if the latter is not specified.
693
+ update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
694
+ Default is None meaning no update for loss weights. Otherwise a string
626
695
  omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
627
696
  The function to be matched in the border condition (can be None) or a
628
697
  dictionary of such functions as values and keys as described
@@ -659,68 +728,81 @@ class LossPDENonStatio(LossPDEStatio):
659
728
  obs_slice : slice, default=None
660
729
  slice object specifying the begininning/ending of the PINN output
661
730
  that is observed (this is then useful for multidim PINN). Default is None.
662
- t0 : float | Float[Array, " 1"], default=None
663
- The time at which to apply the initial condition. If None, the time
664
- is set to `0` by default.
665
- initial_condition_fun : Callable, default=None
666
- A function representing the initial condition at `t0`. If None
667
- (default) then no initial condition is applied.
668
- params : InitVar[Params[Array]], default=None
669
- The main `Params` object of the problem needed to instanciate the
670
- `DerivativeKeysODE` if the latter is not specified.
671
731
 
672
732
  """
673
733
 
734
+ u: AbstractPINN
674
735
  dynamic_loss: PDENonStatio | None
675
- # NOTE static=True only for leaf attributes that are not valid JAX types
676
- # (ie. jax.Array cannot be static) and that we do not expect to change
677
- initial_condition_fun: Callable | None = eqx.field(
678
- kw_only=True, default=None, static=True
736
+ loss_weights: LossWeightsPDENonStatio
737
+ derivative_keys: DerivativeKeysPDENonStatio
738
+ params: InitVar[Params[Array] | None]
739
+ t0: Float[Array, " "]
740
+ initial_condition_fun: Callable[[Float[Array, " dimension"]], Array] | None = (
741
+ eqx.field(static=True)
679
742
  )
680
- t0: float | Float[Array, " 1"] | None = eqx.field(kw_only=True, default=None)
743
+ vmap_in_axes: tuple[int] = eqx.field(static=True)
744
+ max_norm_samples_omega: int = eqx.field(static=True)
745
+ max_norm_time_slices: int = eqx.field(static=True)
746
+
747
+ params: InitVar[Params[Array] | None]
748
+
749
+ def __init__(
750
+ self,
751
+ *,
752
+ u: AbstractPINN,
753
+ dynamic_loss: PDENonStatio | None,
754
+ loss_weights: LossWeightsPDENonStatio | None = None,
755
+ derivative_keys: DerivativeKeysPDENonStatio | None = None,
756
+ initial_condition_fun: Callable[[Float[Array, " dimension"]], Array]
757
+ | None = None,
758
+ t0: int | float | Float[Array, " "] | None = None,
759
+ max_norm_time_slices: int = 100,
760
+ max_norm_samples_omega: int = 1000,
761
+ params: Params[Array] | None = None,
762
+ **kwargs: Any,
763
+ ):
764
+ self.u = u
765
+ if loss_weights is None:
766
+ self.loss_weights = LossWeightsPDENonStatio()
767
+ else:
768
+ self.loss_weights = loss_weights
769
+ self.dynamic_loss = dynamic_loss
681
770
 
682
- _max_norm_samples_omega: Int = eqx.field(init=False, static=True)
683
- _max_norm_time_slices: Int = eqx.field(init=False, static=True)
771
+ super().__init__(
772
+ **kwargs,
773
+ )
684
774
 
685
- def __post_init__(self, params=None):
686
- """
687
- Note that neither __init__ or __post_init__ are called when udating a
688
- Module with eqx.tree_at!
689
- """
690
- super().__post_init__(
691
- params=params
692
- ) # because __init__ or __post_init__ of Base
693
- # class is not automatically called
775
+ if derivative_keys is None:
776
+ # be default we only take gradient wrt nn_params
777
+ try:
778
+ self.derivative_keys = DerivativeKeysPDENonStatio(params=params)
779
+ except ValueError as exc:
780
+ raise ValueError(
781
+ "Problem at derivative_keys initialization "
782
+ f"received {derivative_keys=} and {params=}"
783
+ ) from exc
784
+ else:
785
+ self.derivative_keys = derivative_keys
694
786
 
695
787
  self.vmap_in_axes = (0,) # for t_x
696
788
 
697
- if self.initial_condition_fun is None:
789
+ if initial_condition_fun is None:
698
790
  warnings.warn(
699
791
  "Initial condition wasn't provided. Be sure to cover for that"
700
792
  "case (e.g by. hardcoding it into the PINN output)."
701
793
  )
702
794
  # some checks for t0
703
- if isinstance(self.t0, Array):
704
- if not self.t0.shape: # e.g. user input: jnp.array(0.)
705
- self.t0 = jnp.array([self.t0])
706
- elif self.t0.shape != (1,):
707
- raise ValueError(
708
- f"Wrong self.t0 input. It should be"
709
- f"a float or an array of shape (1,). Got shape: {self.t0.shape}"
710
- )
711
- elif isinstance(self.t0, float): # e.g. user input: 0.
712
- self.t0 = jnp.array([self.t0])
713
- elif isinstance(self.t0, int): # e.g. user input: 0
714
- self.t0 = jnp.array([float(self.t0)])
715
- elif self.t0 is None:
795
+ if t0 is None:
716
796
  self.t0 = jnp.array([0])
717
797
  else:
718
- raise ValueError("Wrong value for t0")
798
+ self.t0 = initial_condition_check(t0, dim_size=1)
719
799
 
720
- # witht the variables below we avoid memory overflow since a cartesian
800
+ self.initial_condition_fun = initial_condition_fun
801
+
802
+ # with the variables below we avoid memory overflow since a cartesian
721
803
  # product is taken
722
- self._max_norm_time_slices = 100
723
- self._max_norm_samples_omega = 1000
804
+ self.max_norm_time_slices = max_norm_time_slices
805
+ self.max_norm_samples_omega = max_norm_samples_omega
724
806
 
725
807
  def _get_dynamic_loss_batch(
726
808
  self, batch: PDENonStatioBatch
@@ -733,18 +815,10 @@ class LossPDENonStatio(LossPDEStatio):
733
815
  Float[Array, " nb_norm_time_slices 1"], Float[Array, " nb_norm_samples dim"]
734
816
  ]:
735
817
  return (
736
- batch.domain_batch[: self._max_norm_time_slices, 0:1],
737
- self.norm_samples[: self._max_norm_samples_omega], # type: ignore -> cannot narrow a class attr
818
+ batch.domain_batch[: self.max_norm_time_slices, 0:1],
819
+ self.norm_samples[: self.max_norm_samples_omega], # type: ignore -> cannot narrow a class attr
738
820
  )
739
821
 
740
- def _get_observations_loss_batch(
741
- self, batch: PDENonStatioBatch
742
- ) -> Float[Array, " batch_size 1+dim"]:
743
- return batch.obs_batch_dict["pinn_in"]
744
-
745
- def __call__(self, *args, **kwargs):
746
- return self.evaluate(*args, **kwargs)
747
-
748
822
  def evaluate_by_terms(
749
823
  self, params: Params[Array], batch: PDENonStatioBatch
750
824
  ) -> tuple[
@@ -766,76 +840,75 @@ class LossPDENonStatio(LossPDEStatio):
766
840
  metamodeling) and an optional additional batch of observed
767
841
  inputs/outputs/parameters
768
842
  """
769
- omega_batch = batch.initial_batch
770
- assert omega_batch is not None
843
+ omega_initial_batch = batch.initial_batch
844
+ assert omega_initial_batch is not None
771
845
 
772
846
  # Retrieve the optional eq_params_batch
773
847
  # and update eq_params with the latter
774
848
  # and update vmap_in_axes
775
849
  if batch.param_batch_dict is not None:
776
850
  # update eq_params with the batches of generated params
777
- params = _update_eq_params_dict(params, batch.param_batch_dict)
851
+ params = update_eq_params(params, batch.param_batch_dict)
778
852
 
779
853
  vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
780
854
 
781
- # For mse_dyn_loss, mse_norm_loss, mse_boundary_loss,
782
- # mse_observation_loss we use the evaluate from parent class
783
- # As well as for their gradients
784
- partial_mses, partial_grads = super().evaluate_by_terms(params, batch) # type: ignore
785
- # ignore because batch is not PDEStatioBatch. We could use typing.cast though
855
+ # dynamic part
856
+ dyn_loss_fun = self._get_dyn_loss_fun(batch, vmap_in_axes_params)
857
+
858
+ # normalization part
859
+ norm_loss_fun = self._get_norm_loss_fun(batch, vmap_in_axes_params)
860
+
861
+ # boundary part
862
+ boundary_loss_fun = self._get_boundary_loss_fun(batch)
863
+
864
+ # Observation mse
865
+ params_obs, obs_loss_fun = self._get_obs_params_and_obs_loss_fun(
866
+ batch, vmap_in_axes_params, params
867
+ )
786
868
 
787
869
  # initial condition
788
870
  if self.initial_condition_fun is not None:
789
- mse_initial_condition_fun = lambda p: initial_condition_apply(
790
- self.u,
791
- omega_batch,
792
- _set_derivatives(p, self.derivative_keys.initial_condition), # type: ignore
793
- (0,) + vmap_in_axes_params,
794
- self.initial_condition_fun, # type: ignore
795
- self.t0, # type: ignore can't get the narrowing in __post_init__
796
- )
797
- mse_initial_condition, grad_initial_condition = self.get_gradients(
798
- mse_initial_condition_fun, params
871
+ mse_initial_condition_fun: Callable[[Params[Array]], Array] | None = (
872
+ lambda p: initial_condition_apply(
873
+ self.u,
874
+ omega_initial_batch,
875
+ _set_derivatives(p, self.derivative_keys.initial_condition),
876
+ (0,) + vmap_in_axes_params,
877
+ self.initial_condition_fun, # type: ignore
878
+ self.t0,
879
+ )
799
880
  )
800
881
  else:
801
- mse_initial_condition = None
802
- grad_initial_condition = None
803
-
804
- mses = PDENonStatioComponents(
805
- partial_mses.dyn_loss,
806
- partial_mses.norm_loss,
807
- partial_mses.boundary_loss,
808
- partial_mses.observations,
809
- mse_initial_condition,
810
- )
882
+ mse_initial_condition_fun = None
811
883
 
812
- grads = PDENonStatioComponents(
813
- partial_grads.dyn_loss,
814
- partial_grads.norm_loss,
815
- partial_grads.boundary_loss,
816
- partial_grads.observations,
817
- grad_initial_condition,
884
+ # get the unweighted mses for each loss term as well as the gradients
885
+ all_funs: PDENonStatioComponents[Callable[[Params[Array]], Array] | None] = (
886
+ PDENonStatioComponents(
887
+ dyn_loss_fun,
888
+ norm_loss_fun,
889
+ boundary_loss_fun,
890
+ obs_loss_fun,
891
+ mse_initial_condition_fun,
892
+ )
893
+ )
894
+ all_params: PDENonStatioComponents[Params[Array] | None] = (
895
+ PDENonStatioComponents(params, params, params, params_obs, params)
896
+ )
897
+ mses_grads = jax.tree.map(
898
+ self.get_gradients,
899
+ all_funs,
900
+ all_params,
901
+ is_leaf=lambda x: x is None,
902
+ )
903
+ mses = jax.tree.map(
904
+ lambda leaf: leaf[0], # type: ignore
905
+ mses_grads,
906
+ is_leaf=lambda x: isinstance(x, tuple),
907
+ )
908
+ grads = jax.tree.map(
909
+ lambda leaf: leaf[1], # type: ignore
910
+ mses_grads,
911
+ is_leaf=lambda x: isinstance(x, tuple),
818
912
  )
819
913
 
820
914
  return mses, grads
821
-
822
- def evaluate(
823
- self, params: Params[Array], batch: PDENonStatioBatch
824
- ) -> tuple[Float[Array, " "], PDENonStatioComponents[Float[Array, " "] | None]]:
825
- """
826
- Evaluate the loss function at a batch of points for given parameters.
827
- We retrieve the total value itself and a PyTree with loss values for each term
828
-
829
-
830
- Parameters
831
- ---------
832
- params
833
- Parameters at which the loss is evaluated
834
- batch
835
- Composed of a batch of points in
836
- the domain, a batch of points in the domain
837
- border, a batch of time points and an optional additional batch
838
- of parameters (eg. for metamodeling) and an optional additional batch of observed
839
- inputs/outputs/parameters
840
- """
841
- return super().evaluate(params, batch) # type: ignore