jinns 1.5.1__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. jinns/data/_AbstractDataGenerator.py +1 -1
  2. jinns/data/_Batchs.py +47 -13
  3. jinns/data/_CubicMeshPDENonStatio.py +55 -34
  4. jinns/data/_CubicMeshPDEStatio.py +63 -35
  5. jinns/data/_DataGeneratorODE.py +48 -22
  6. jinns/data/_DataGeneratorObservations.py +75 -32
  7. jinns/data/_DataGeneratorParameter.py +152 -101
  8. jinns/data/__init__.py +2 -1
  9. jinns/data/_utils.py +22 -10
  10. jinns/loss/_DynamicLoss.py +21 -20
  11. jinns/loss/_DynamicLossAbstract.py +51 -36
  12. jinns/loss/_LossODE.py +139 -184
  13. jinns/loss/_LossPDE.py +440 -358
  14. jinns/loss/_abstract_loss.py +60 -25
  15. jinns/loss/_loss_components.py +4 -25
  16. jinns/loss/_loss_weight_updates.py +6 -7
  17. jinns/loss/_loss_weights.py +34 -35
  18. jinns/nn/_abstract_pinn.py +0 -2
  19. jinns/nn/_hyperpinn.py +34 -23
  20. jinns/nn/_mlp.py +5 -4
  21. jinns/nn/_pinn.py +1 -16
  22. jinns/nn/_ppinn.py +5 -16
  23. jinns/nn/_save_load.py +11 -4
  24. jinns/nn/_spinn.py +1 -16
  25. jinns/nn/_spinn_mlp.py +5 -5
  26. jinns/nn/_utils.py +33 -38
  27. jinns/parameters/__init__.py +3 -1
  28. jinns/parameters/_derivative_keys.py +99 -41
  29. jinns/parameters/_params.py +50 -25
  30. jinns/solver/_solve.py +3 -3
  31. jinns/utils/_DictToModuleMeta.py +66 -0
  32. jinns/utils/_ItemizableModule.py +19 -0
  33. jinns/utils/__init__.py +2 -1
  34. jinns/utils/_types.py +25 -15
  35. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
  36. jinns-1.6.0.dist-info/RECORD +57 -0
  37. jinns-1.5.1.dist-info/RECORD +0 -55
  38. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
  39. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
  40. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
  41. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py CHANGED
@@ -8,13 +8,13 @@ from __future__ import (
8
8
 
9
9
  import abc
10
10
  from dataclasses import InitVar
11
- from typing import TYPE_CHECKING, Callable, TypedDict
11
+ from typing import TYPE_CHECKING, Callable, cast, Any, TypeVar, Generic
12
12
  from types import EllipsisType
13
13
  import warnings
14
14
  import jax
15
15
  import jax.numpy as jnp
16
16
  import equinox as eqx
17
- from jaxtyping import Float, Array, Key, Int
17
+ from jaxtyping import PRNGKeyArray, Float, Array
18
18
  from jinns.loss._loss_utils import (
19
19
  dynamic_loss_apply,
20
20
  boundary_condition_apply,
@@ -25,7 +25,7 @@ from jinns.loss._loss_utils import (
25
25
  )
26
26
  from jinns.parameters._params import (
27
27
  _get_vmap_in_axes_params,
28
- _update_eq_params_dict,
28
+ update_eq_params,
29
29
  )
30
30
  from jinns.parameters._derivative_keys import (
31
31
  _set_derivatives,
@@ -40,24 +40,14 @@ from jinns.loss._loss_weights import (
40
40
  )
41
41
  from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
42
42
  from jinns.parameters._params import Params
43
+ from jinns.loss import PDENonStatio, PDEStatio
43
44
 
44
45
 
45
46
  if TYPE_CHECKING:
46
47
  # imports for type hints only
47
48
  from jinns.nn._abstract_pinn import AbstractPINN
48
- from jinns.loss import PDENonStatio, PDEStatio
49
49
  from jinns.utils._types import BoundaryConditionFun
50
50
 
51
- class LossDictPDEStatio(TypedDict):
52
- dyn_loss: Float[Array, " "]
53
- norm_loss: Float[Array, " "]
54
- boundary_loss: Float[Array, " "]
55
- observations: Float[Array, " "]
56
-
57
- class LossDictPDENonStatio(LossDictPDEStatio):
58
- initial_condition: Float[Array, " "]
59
-
60
-
61
51
  _IMPLEMENTED_BOUNDARY_CONDITIONS = [
62
52
  "dirichlet",
63
53
  "von neumann",
@@ -65,7 +55,24 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
65
55
  ]
66
56
 
67
57
 
68
- class _LossPDEAbstract(AbstractLoss):
58
+ # For the same reason that we have the TypeVar in _abstract_loss.py, we have them
59
+ # here, because _LossPDEAbstract is abstract and we cannot decide for several
60
+ # types between their statio and non-statio version.
61
+ # Assigning the type where it can be decided seems a better practice than
62
+ # assigning a type at a higher level depending on a child class type. This is
63
+ # why we now assign LossWeights and DerivativeKeys in the child class where
64
+ # they really can be decided.
65
+
66
+ L = TypeVar("L", bound=LossWeightsPDEStatio | LossWeightsPDENonStatio)
67
+ B = TypeVar("B", bound=PDEStatioBatch | PDENonStatioBatch)
68
+ C = TypeVar(
69
+ "C", bound=PDEStatioComponents[Array | None] | PDENonStatioComponents[Array | None]
70
+ )
71
+ D = TypeVar("D", bound=DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio)
72
+ Y = TypeVar("Y", bound=PDEStatio | PDENonStatio | None)
73
+
74
+
75
+ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
69
76
  r"""
70
77
  Parameters
71
78
  ----------
@@ -78,17 +85,11 @@ class _LossPDEAbstract(AbstractLoss):
78
85
  `update_weight_method`
79
86
  update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
80
87
  Default is None meaning no update for loss weights. Otherwise a string
81
- derivative_keys : DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio, default=None
82
- Specify which field of `params` should be differentiated for each
83
- composant of the total loss. Particularily useful for inverse problems.
84
- Fields can be "nn_params", "eq_params" or "both". Those that should not
85
- be updated will have a `jax.lax.stop_gradient` called on them. Default
86
- is `"nn_params"` for each composant of the loss.
87
88
  omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
88
89
  The function to be matched in the border condition (can be None) or a
89
90
  dictionary of such functions as values and keys as described
90
91
  in `omega_boundary_condition`.
91
- omega_boundary_condition : str | dict[str, str], default=None
92
+ omega_boundary_condition : str | dict[str, str | None], default=None
92
93
  Either None (no condition, by default), or a string defining
93
94
  the boundary condition (Dirichlet or Von Neumann),
94
95
  or a dictionary with such strings as values. In this case,
@@ -99,7 +100,7 @@ class _LossPDEAbstract(AbstractLoss):
99
100
  a particular boundary condition on this facet.
100
101
  The facet called “xmin”, resp. “xmax” etc., in 2D,
101
102
  refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
102
- omega_boundary_dim : slice | dict[str, slice], default=None
103
+ omega_boundary_dim : int | slice | dict[str, slice], default=None
103
104
  Either None, or a slice object or a dictionary of slice objects as
104
105
  values and keys as described in `omega_boundary_condition`.
105
106
  `omega_boundary_dim` indicates which dimension(s) of the PINN
@@ -121,197 +122,286 @@ class _LossPDEAbstract(AbstractLoss):
121
122
  obs_slice : EllipsisType | slice, default=None
122
123
  slice object specifying the beginning/ending of the PINN output
123
124
  that is observed (this is then useful for multidim PINN). Default is None.
124
- params : InitVar[Params[Array]], default=None
125
- The main Params object of the problem needed to instanciate the
126
- DerivativeKeysODE if the latter is not specified.
125
+ key : PRNGKeyArray | None
126
+ A JAX PRNG Key for the loss class treated as an attribute. Default is
127
+ None. This field is provided for future developments and additional
128
+ losses that might need some randomness. Note that special care must be
129
+ taken when splitting the key because in-place updates are forbidden in
130
+ eqx.Modules.
127
131
  """
128
132
 
129
133
  # NOTE static=True only for leaf attributes that are not valid JAX types
130
134
  # (ie. jax.Array cannot be static) and that we do not expect to change
131
- # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
132
- derivative_keys: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None = (
133
- eqx.field(kw_only=True, default=None)
134
- )
135
- loss_weights: LossWeightsPDEStatio | LossWeightsPDENonStatio | None = eqx.field(
136
- kw_only=True, default=None
137
- )
135
+ u: eqx.AbstractVar[AbstractPINN]
136
+ dynamic_loss: eqx.AbstractVar[Y]
138
137
  omega_boundary_fun: (
139
138
  BoundaryConditionFun | dict[str, BoundaryConditionFun] | None
140
- ) = eqx.field(kw_only=True, default=None, static=True)
141
- omega_boundary_condition: str | dict[str, str] | None = eqx.field(
142
- kw_only=True, default=None, static=True
139
+ ) = eqx.field(static=True)
140
+ omega_boundary_condition: str | dict[str, str | None] | None = eqx.field(
141
+ static=True
143
142
  )
144
- omega_boundary_dim: slice | dict[str, slice] | None = eqx.field(
145
- kw_only=True, default=None, static=True
146
- )
147
- norm_samples: Float[Array, " nb_norm_samples dimension"] | None = eqx.field(
148
- kw_only=True, default=None
149
- )
150
- norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = eqx.field(
151
- kw_only=True, default=None
152
- )
153
- obs_slice: EllipsisType | slice | None = eqx.field(
154
- kw_only=True, default=None, static=True
155
- )
156
-
157
- params: InitVar[Params[Array]] = eqx.field(kw_only=True, default=None)
158
-
159
- def __post_init__(self, params: Params[Array] | None = None):
160
- """
161
- Note that neither __init__ or __post_init__ are called when udating a
162
- Module with eqx.tree_at
163
- """
164
- if self.derivative_keys is None:
165
- # be default we only take gradient wrt nn_params
166
- try:
167
- self.derivative_keys = (
168
- DerivativeKeysPDENonStatio(params=params)
169
- if isinstance(self, LossPDENonStatio)
170
- else DerivativeKeysPDEStatio(params=params)
171
- )
172
- except ValueError as exc:
173
- raise ValueError(
174
- "Problem at self.derivative_keys initialization "
175
- f"received {self.derivative_keys=} and {params=}"
176
- ) from exc
177
-
178
- if self.loss_weights is None:
179
- self.loss_weights = (
180
- LossWeightsPDENonStatio()
181
- if isinstance(self, LossPDENonStatio)
182
- else LossWeightsPDEStatio()
183
- )
184
-
185
- if self.obs_slice is None:
143
+ omega_boundary_dim: slice | dict[str, slice] = eqx.field(static=True)
144
+ norm_samples: Float[Array, " nb_norm_samples dimension"] | None
145
+ norm_weights: Float[Array, " nb_norm_samples"] | None
146
+ obs_slice: EllipsisType | slice = eqx.field(static=True)
147
+ key: PRNGKeyArray | None
148
+
149
+ def __init__(
150
+ self,
151
+ *,
152
+ omega_boundary_fun: BoundaryConditionFun
153
+ | dict[str, BoundaryConditionFun]
154
+ | None = None,
155
+ omega_boundary_condition: str | dict[str, str | None] | None = None,
156
+ omega_boundary_dim: int | slice | dict[str, int | slice] | None = None,
157
+ norm_samples: Float[Array, " nb_norm_samples dimension"] | None = None,
158
+ norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = None,
159
+ obs_slice: EllipsisType | slice | None = None,
160
+ key: PRNGKeyArray | None = None,
161
+ **kwargs: Any, # for arguments for super()
162
+ ):
163
+ super().__init__(loss_weights=self.loss_weights, **kwargs)
164
+
165
+ if obs_slice is None:
186
166
  self.obs_slice = jnp.s_[...]
167
+ else:
168
+ self.obs_slice = obs_slice
187
169
 
188
170
  if (
189
- isinstance(self.omega_boundary_fun, dict)
190
- and not isinstance(self.omega_boundary_condition, dict)
171
+ isinstance(omega_boundary_fun, dict)
172
+ and not isinstance(omega_boundary_condition, dict)
191
173
  ) or (
192
- not isinstance(self.omega_boundary_fun, dict)
193
- and isinstance(self.omega_boundary_condition, dict)
174
+ not isinstance(omega_boundary_fun, dict)
175
+ and isinstance(omega_boundary_condition, dict)
194
176
  ):
195
177
  raise ValueError(
196
- "if one of self.omega_boundary_fun or "
197
- "self.omega_boundary_condition is dict, the other should be too."
178
+ "if one of omega_boundary_fun or "
179
+ "omega_boundary_condition is dict, the other should be too."
198
180
  )
199
181
 
200
- if self.omega_boundary_condition is None or self.omega_boundary_fun is None:
182
+ if omega_boundary_condition is None or omega_boundary_fun is None:
201
183
  warnings.warn(
202
184
  "Missing boundary function or no boundary condition."
203
185
  "Boundary function is thus ignored."
204
186
  )
205
187
  else:
206
- if isinstance(self.omega_boundary_condition, dict):
207
- for _, v in self.omega_boundary_condition.items():
188
+ if isinstance(omega_boundary_condition, dict):
189
+ for _, v in omega_boundary_condition.items():
208
190
  if v is not None and not any(
209
191
  v.lower() in s for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
210
192
  ):
211
193
  raise NotImplementedError(
212
- f"The boundary condition {self.omega_boundary_condition} is not"
194
+ f"The boundary condition {omega_boundary_condition} is not"
213
195
  f"implemented yet. Try one of :"
214
196
  f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
215
197
  )
216
198
  else:
217
199
  if not any(
218
- self.omega_boundary_condition.lower() in s
200
+ omega_boundary_condition.lower() in s
219
201
  for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
220
202
  ):
221
203
  raise NotImplementedError(
222
- f"The boundary condition {self.omega_boundary_condition} is not"
204
+ f"The boundary condition {omega_boundary_condition} is not"
223
205
  f"implemented yet. Try one of :"
224
206
  f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
225
207
  )
226
- if isinstance(self.omega_boundary_fun, dict) and isinstance(
227
- self.omega_boundary_condition, dict
208
+ if isinstance(omega_boundary_fun, dict) and isinstance(
209
+ omega_boundary_condition, dict
210
+ ):
211
+ keys_omega_boundary_fun = cast(str, omega_boundary_fun.keys())
212
+ if (
213
+ not (
214
+ list(keys_omega_boundary_fun) == ["xmin", "xmax"]
215
+ and list(omega_boundary_condition.keys()) == ["xmin", "xmax"]
216
+ )
217
+ ) and (
218
+ not (
219
+ list(keys_omega_boundary_fun)
220
+ == ["xmin", "xmax", "ymin", "ymax"]
221
+ and list(omega_boundary_condition.keys())
222
+ == ["xmin", "xmax", "ymin", "ymax"]
223
+ )
228
224
  ):
229
- if (
230
- not (
231
- list(self.omega_boundary_fun.keys()) == ["xmin", "xmax"]
232
- and list(self.omega_boundary_condition.keys())
233
- == ["xmin", "xmax"]
234
- )
235
- ) or (
236
- not (
237
- list(self.omega_boundary_fun.keys())
238
- == ["xmin", "xmax", "ymin", "ymax"]
239
- and list(self.omega_boundary_condition.keys())
240
- == ["xmin", "xmax", "ymin", "ymax"]
241
- )
242
- ):
243
- raise ValueError(
244
- "The key order (facet order) in the "
245
- "boundary condition dictionaries is incorrect"
246
- )
225
+ raise ValueError(
226
+ "The key order (facet order) in the "
227
+ "boundary condition dictionaries is incorrect"
228
+ )
247
229
 
248
- if isinstance(self.omega_boundary_fun, dict):
249
- if not isinstance(self.omega_boundary_dim, dict):
230
+ self.omega_boundary_fun = omega_boundary_fun
231
+ self.omega_boundary_condition = omega_boundary_condition
232
+
233
+ if isinstance(omega_boundary_fun, dict):
234
+ keys_omega_boundary_fun: str = cast(str, omega_boundary_fun.keys())
235
+ if omega_boundary_dim is None:
236
+ self.omega_boundary_dim = {
237
+ k: jnp.s_[::] for k in keys_omega_boundary_fun
238
+ }
239
+ if not isinstance(omega_boundary_dim, dict):
250
240
  raise ValueError(
251
241
  "If omega_boundary_fun is a dict then"
252
242
  " omega_boundary_dim should also be a dict"
253
243
  )
254
- if self.omega_boundary_dim is None:
255
- self.omega_boundary_dim = {
256
- k: jnp.s_[::] for k in self.omega_boundary_fun.keys()
257
- }
258
- if list(self.omega_boundary_dim.keys()) != list(
259
- self.omega_boundary_fun.keys()
260
- ):
244
+ if list(omega_boundary_dim.keys()) != list(keys_omega_boundary_fun):
261
245
  raise ValueError(
262
246
  "If omega_boundary_fun is a dict,"
263
247
  " omega_boundary_dim should be a dict with the same keys"
264
248
  )
265
- for k, v in self.omega_boundary_dim.items():
249
+ self.omega_boundary_dim = {}
250
+ for k, v in omega_boundary_dim.items():
266
251
  if isinstance(v, int):
267
252
  # rewrite it as a slice to ensure that axis does not disappear when
268
253
  # indexing
269
254
  self.omega_boundary_dim[k] = jnp.s_[v : v + 1]
255
+ else:
256
+ self.omega_boundary_dim[k] = v
270
257
 
271
258
  else:
272
- if self.omega_boundary_dim is None:
259
+ assert not isinstance(omega_boundary_dim, dict)
260
+ if omega_boundary_dim is None:
273
261
  self.omega_boundary_dim = jnp.s_[::]
274
- if isinstance(self.omega_boundary_dim, int):
262
+ elif isinstance(omega_boundary_dim, int):
275
263
  # rewrite it as a slice to ensure that axis does not disappear when
276
264
  # indexing
277
265
  self.omega_boundary_dim = jnp.s_[
278
- self.omega_boundary_dim : self.omega_boundary_dim + 1
266
+ omega_boundary_dim : omega_boundary_dim + 1
279
267
  ]
280
- if not isinstance(self.omega_boundary_dim, slice):
281
- raise ValueError("self.omega_boundary_dim must be a jnp.s_ object")
268
+ else:
269
+ assert isinstance(omega_boundary_dim, slice)
270
+ self.omega_boundary_dim = omega_boundary_dim
282
271
 
283
- if self.norm_samples is not None:
284
- if self.norm_weights is None:
272
+ if norm_samples is not None:
273
+ self.norm_samples = norm_samples
274
+ if norm_weights is None:
285
275
  raise ValueError(
286
276
  "`norm_weights` must be provided when `norm_samples` is used!"
287
277
  )
288
- if isinstance(self.norm_weights, (int, float)):
289
- self.norm_weights = self.norm_weights * jnp.ones(
278
+ if isinstance(norm_weights, (int, float)):
279
+ self.norm_weights = norm_weights * jnp.ones(
290
280
  (self.norm_samples.shape[0],)
291
281
  )
292
- if isinstance(self.norm_weights, Array):
293
- if not (self.norm_weights.shape[0] == self.norm_samples.shape[0]):
282
+ else:
283
+ assert isinstance(norm_weights, Array)
284
+ if not (norm_weights.shape[0] == norm_samples.shape[0]):
294
285
  raise ValueError(
295
- "self.norm_weights and "
296
- "self.norm_samples must have the same leading dimension"
286
+ "norm_weights and "
287
+ "norm_samples must have the same leading dimension"
297
288
  )
298
- else:
299
- raise ValueError("Wrong type for self.norm_weights")
289
+ self.norm_weights = norm_weights
290
+ else:
291
+ self.norm_samples = norm_samples
292
+ self.norm_weights = None
293
+
294
+ self.key = key
300
295
 
301
296
  @abc.abstractmethod
302
- def __call__(self, *_, **__):
297
+ def _get_dynamic_loss_batch(self, batch: B) -> Array:
303
298
  pass
304
299
 
305
300
  @abc.abstractmethod
306
- def evaluate(
307
- self: eqx.Module,
301
+ def _get_normalization_loss_batch(self, batch: B) -> tuple[Array | None, ...]:
302
+ pass
303
+
304
+ def _get_dyn_loss_fun(
305
+ self, batch: B, vmap_in_axes_params: tuple[Params[int | None] | None]
306
+ ) -> Callable[[Params[Array]], Array] | None:
307
+ if self.dynamic_loss is not None:
308
+ dyn_loss_eval = self.dynamic_loss.evaluate
309
+ dyn_loss_fun: Callable[[Params[Array]], Array] | None = (
310
+ lambda p: dynamic_loss_apply(
311
+ dyn_loss_eval,
312
+ self.u,
313
+ self._get_dynamic_loss_batch(batch),
314
+ _set_derivatives(p, self.derivative_keys.dyn_loss),
315
+ self.vmap_in_axes + vmap_in_axes_params,
316
+ )
317
+ )
318
+ else:
319
+ dyn_loss_fun = None
320
+
321
+ return dyn_loss_fun
322
+
323
+ def _get_norm_loss_fun(
324
+ self, batch: B, vmap_in_axes_params: tuple[Params[int | None] | None]
325
+ ) -> Callable[[Params[Array]], Array] | None:
326
+ if self.norm_samples is not None:
327
+ norm_loss_fun: Callable[[Params[Array]], Array] | None = (
328
+ lambda p: normalization_loss_apply(
329
+ self.u,
330
+ cast(
331
+ tuple[Array, Array], self._get_normalization_loss_batch(batch)
332
+ ),
333
+ _set_derivatives(p, self.derivative_keys.norm_loss),
334
+ vmap_in_axes_params,
335
+ self.norm_weights, # type: ignore -> can't get the __post_init__ narrowing here
336
+ )
337
+ )
338
+ else:
339
+ norm_loss_fun = None
340
+ return norm_loss_fun
341
+
342
+ def _get_boundary_loss_fun(
343
+ self, batch: B
344
+ ) -> Callable[[Params[Array]], Array] | None:
345
+ if (
346
+ self.omega_boundary_condition is not None
347
+ and self.omega_boundary_fun is not None
348
+ ):
349
+ boundary_loss_fun: Callable[[Params[Array]], Array] | None = (
350
+ lambda p: boundary_condition_apply(
351
+ self.u,
352
+ batch,
353
+ _set_derivatives(p, self.derivative_keys.boundary_loss),
354
+ self.omega_boundary_fun, # type: ignore (we are in lambda)
355
+ self.omega_boundary_condition, # type: ignore
356
+ self.omega_boundary_dim, # type: ignore
357
+ )
358
+ )
359
+ else:
360
+ boundary_loss_fun = None
361
+
362
+ return boundary_loss_fun
363
+
364
+ def _get_obs_params_and_obs_loss_fun(
365
+ self,
366
+ batch: B,
367
+ vmap_in_axes_params: tuple[Params[int | None] | None],
308
368
  params: Params[Array],
309
- batch: PDEStatioBatch | PDENonStatioBatch,
310
- ) -> tuple[Float[Array, " "], LossDictPDEStatio | LossDictPDENonStatio]:
311
- raise NotImplementedError
369
+ ) -> tuple[Params[Array] | None, Callable[[Params[Array]], Array] | None]:
370
+ if batch.obs_batch_dict is not None:
371
+ # update params with the batches of observed params
372
+ params_obs = update_eq_params(params, batch.obs_batch_dict["eq_params"])
373
+
374
+ pinn_in, val = (
375
+ batch.obs_batch_dict["pinn_in"],
376
+ batch.obs_batch_dict["val"],
377
+ )
312
378
 
379
+ obs_loss_fun: Callable[[Params[Array]], Array] | None = (
380
+ lambda po: observations_loss_apply(
381
+ self.u,
382
+ pinn_in,
383
+ _set_derivatives(po, self.derivative_keys.observations),
384
+ self.vmap_in_axes + vmap_in_axes_params,
385
+ val,
386
+ self.obs_slice,
387
+ )
388
+ )
389
+ else:
390
+ params_obs = None
391
+ obs_loss_fun = None
392
+
393
+ return params_obs, obs_loss_fun
313
394
 
314
- class LossPDEStatio(_LossPDEAbstract):
395
+
396
+ class LossPDEStatio(
397
+ _LossPDEAbstract[
398
+ LossWeightsPDEStatio,
399
+ PDEStatioBatch,
400
+ PDEStatioComponents[Array | None],
401
+ DerivativeKeysPDEStatio,
402
+ PDEStatio | None,
403
+ ]
404
+ ):
315
405
  r"""Loss object for a stationary partial differential equation
316
406
 
317
407
  $$
@@ -326,13 +416,13 @@ class LossPDEStatio(_LossPDEAbstract):
326
416
  ----------
327
417
  u : AbstractPINN
328
418
  the PINN
329
- dynamic_loss : PDEStatio
419
+ dynamic_loss : PDEStatio | None
330
420
  the stationary PDE dynamic part of the loss, basically the differential
331
421
  operator $\mathcal{N}[u](x)$. Should implement a method
332
422
  `dynamic_loss.evaluate(x, u, params)`.
333
423
  Can be None in order to access only some part of the evaluate call
334
424
  results.
335
- key : Key
425
+ key : PRNGKeyArray
336
426
  A JAX PRNG Key for the loss class treated as an attribute. Default is
337
427
  None. This field is provided for future developments and additional
338
428
  losses that might need some randomness. Note that special care must be
@@ -344,14 +434,18 @@ class LossPDEStatio(_LossPDEAbstract):
344
434
  observations if any.
345
435
  Can be updated according to a specific algorithm. See
346
436
  `update_weight_method`
347
- update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
348
- Default is None meaning no update for loss weights. Otherwise a string
349
437
  derivative_keys : DerivativeKeysPDEStatio, default=None
350
438
  Specify which field of `params` should be differentiated for each
351
439
  component of the total loss. Particularly useful for inverse problems.
352
440
  Fields can be "nn_params", "eq_params" or "both". Those that should not
353
441
  be updated will have a `jax.lax.stop_gradient` called on them. Default
354
442
  is `"nn_params"` for each composant of the loss.
443
+ params : InitVar[Params[Array]], default=None
444
+ The main Params object of the problem needed to instantiate the
445
+ DerivativeKeysPDEStatio if the latter is not specified.
446
+
447
+ update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
448
+ Default is None meaning no update for loss weights. Otherwise a string
355
449
  omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
356
450
  The function to be matched in the border condition (can be None) or a
357
451
  dictionary of such functions as values and keys as described
@@ -367,7 +461,7 @@ class LossPDEStatio(_LossPDEAbstract):
367
461
  a particular boundary condition on this facet.
368
462
  The facet called “xmin”, resp. “xmax” etc., in 2D,
369
463
  refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
370
- omega_boundary_dim : slice | dict[str, slice], default=None
464
+ omega_boundary_dim : int | slice | dict[str, slice], default=None
371
465
  Either None, or a slice object or a dictionary of slice objects as
372
466
  values and keys as described in `omega_boundary_condition`.
373
467
  `omega_boundary_dim` indicates which dimension(s) of the PINN
@@ -388,10 +482,6 @@ class LossPDEStatio(_LossPDEAbstract):
388
482
  obs_slice : slice, default=None
389
483
  slice object specifying the beginning/ending of the PINN output
390
484
  that is observed (this is then useful for multidim PINN). Default is None.
391
- params : InitVar[Params[Array]], default=None
392
- The main Params object of the problem needed to instanciate the
393
- DerivativeKeysODE if the latter is not specified.
394
-
395
485
 
396
486
  Raises
397
487
  ------
@@ -405,21 +495,46 @@ class LossPDEStatio(_LossPDEAbstract):
405
495
 
406
496
  u: AbstractPINN
407
497
  dynamic_loss: PDEStatio | None
408
- key: Key | None = eqx.field(kw_only=True, default=None)
498
+ loss_weights: LossWeightsPDEStatio
499
+ derivative_keys: DerivativeKeysPDEStatio
500
+ vmap_in_axes: tuple[int] = eqx.field(static=True)
501
+
502
+ params: InitVar[Params[Array] | None]
503
+
504
+ def __init__(
505
+ self,
506
+ *,
507
+ u: AbstractPINN,
508
+ dynamic_loss: PDEStatio | None,
509
+ loss_weights: LossWeightsPDEStatio | None = None,
510
+ derivative_keys: DerivativeKeysPDEStatio | None = None,
511
+ params: Params[Array] | None = None,
512
+ **kwargs: Any,
513
+ ):
514
+ self.u = u
515
+ if loss_weights is None:
516
+ self.loss_weights = LossWeightsPDEStatio()
517
+ else:
518
+ self.loss_weights = loss_weights
519
+ self.dynamic_loss = dynamic_loss
409
520
 
410
- vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
521
+ super().__init__(
522
+ **kwargs,
523
+ )
411
524
 
412
- def __post_init__(self, params: Params[Array] | None = None):
413
- """
414
- Note that neither __init__ or __post_init__ are called when udating a
415
- Module with eqx.tree_at!
416
- """
417
- super().__post_init__(
418
- params=params
419
- ) # because __init__ or __post_init__ of Base
420
- # class is not automatically called
525
+ if derivative_keys is None:
526
+ # by default we only take gradient wrt nn_params
527
+ try:
528
+ self.derivative_keys = DerivativeKeysPDEStatio(params=params)
529
+ except ValueError as exc:
530
+ raise ValueError(
531
+ "Problem at derivative_keys initialization "
532
+ f"received {derivative_keys=} and {params=}"
533
+ ) from exc
534
+ else:
535
+ self.derivative_keys = derivative_keys
421
536
 
422
- self.vmap_in_axes = (0,) # for x only here
537
+ self.vmap_in_axes = (0,)
423
538
 
424
539
  def _get_dynamic_loss_batch(
425
540
  self, batch: PDEStatioBatch
@@ -427,23 +542,18 @@ class LossPDEStatio(_LossPDEAbstract):
427
542
  return batch.domain_batch
428
543
 
429
544
  def _get_normalization_loss_batch(
430
- self, _
431
- ) -> tuple[Float[Array, " nb_norm_samples dimension"]]:
432
- return (self.norm_samples,) # type: ignore -> cannot narrow a class attr
433
-
434
- # we could have used typing.cast though
435
-
436
- def _get_observations_loss_batch(
437
545
  self, batch: PDEStatioBatch
438
- ) -> Float[Array, " batch_size obs_dim"]:
439
- return batch.obs_batch_dict["pinn_in"]
546
+ ) -> tuple[Float[Array, " nb_norm_samples dimension"] | None,]:
547
+ return (self.norm_samples,)
440
548
 
441
- def __call__(self, *args, **kwargs):
442
- return self.evaluate(*args, **kwargs)
549
+ # we could have used typing.cast though
443
550
 
444
551
  def evaluate_by_terms(
445
552
  self, params: Params[Array], batch: PDEStatioBatch
446
- ) -> tuple[PDEStatioComponents[Array | None], PDEStatioComponents[Array | None]]:
553
+ ) -> tuple[
554
+ PDEStatioComponents[Float[Array, ""] | None],
555
+ PDEStatioComponents[Float[Array, ""] | None],
556
+ ]:
447
557
  """
448
558
  Evaluate the loss function at a batch of points for given parameters.
449
559
 
@@ -465,69 +575,23 @@ class LossPDEStatio(_LossPDEAbstract):
465
575
  # and update vmap_in_axes
466
576
  if batch.param_batch_dict is not None:
467
577
  # update eq_params with the batches of generated params
468
- params = _update_eq_params_dict(params, batch.param_batch_dict)
578
+ params = update_eq_params(params, batch.param_batch_dict)
469
579
 
470
580
  vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
471
581
 
472
582
  # dynamic part
473
- if self.dynamic_loss is not None:
474
- dyn_loss_fun = lambda p: dynamic_loss_apply(
475
- self.dynamic_loss.evaluate, # type: ignore
476
- self.u,
477
- self._get_dynamic_loss_batch(batch),
478
- _set_derivatives(p, self.derivative_keys.dyn_loss), # type: ignore
479
- self.vmap_in_axes + vmap_in_axes_params,
480
- )
481
- else:
482
- dyn_loss_fun = None
583
+ dyn_loss_fun = self._get_dyn_loss_fun(batch, vmap_in_axes_params)
483
584
 
484
585
  # normalization part
485
- if self.norm_samples is not None:
486
- norm_loss_fun = lambda p: normalization_loss_apply(
487
- self.u,
488
- self._get_normalization_loss_batch(batch),
489
- _set_derivatives(p, self.derivative_keys.norm_loss), # type: ignore
490
- vmap_in_axes_params,
491
- self.norm_weights, # type: ignore -> can't get the __post_init__ narrowing here
492
- )
493
- else:
494
- norm_loss_fun = None
586
+ norm_loss_fun = self._get_norm_loss_fun(batch, vmap_in_axes_params)
495
587
 
496
588
  # boundary part
497
- if (
498
- self.omega_boundary_condition is not None
499
- and self.omega_boundary_dim is not None
500
- and self.omega_boundary_fun is not None
501
- ): # pyright cannot narrow down the three None otherwise as it is class attribute
502
- boundary_loss_fun = lambda p: boundary_condition_apply(
503
- self.u,
504
- batch,
505
- _set_derivatives(p, self.derivative_keys.boundary_loss), # type: ignore
506
- self.omega_boundary_fun, # type: ignore
507
- self.omega_boundary_condition, # type: ignore
508
- self.omega_boundary_dim, # type: ignore
509
- )
510
- else:
511
- boundary_loss_fun = None
589
+ boundary_loss_fun = self._get_boundary_loss_fun(batch)
512
590
 
513
591
  # Observation mse
514
- if batch.obs_batch_dict is not None:
515
- # update params with the batches of observed params
516
- params_obs = _update_eq_params_dict(
517
- params, batch.obs_batch_dict["eq_params"]
518
- )
519
-
520
- obs_loss_fun = lambda po: observations_loss_apply(
521
- self.u,
522
- self._get_observations_loss_batch(batch),
523
- _set_derivatives(po, self.derivative_keys.observations), # type: ignore
524
- self.vmap_in_axes + vmap_in_axes_params,
525
- batch.obs_batch_dict["val"],
526
- self.obs_slice,
527
- )
528
- else:
529
- params_obs = None
530
- obs_loss_fun = None
592
+ params_obs, obs_loss_fun = self._get_obs_params_and_obs_loss_fun(
593
+ batch, vmap_in_axes_params, params
594
+ )
531
595
 
532
596
  # get the unweighted mses for each loss term as well as the gradients
533
597
  all_funs: PDEStatioComponents[Callable[[Params[Array]], Array] | None] = (
@@ -539,47 +603,34 @@ class LossPDEStatio(_LossPDEAbstract):
539
603
  params, params, params, params_obs
540
604
  )
541
605
  mses_grads = jax.tree.map(
542
- lambda fun, params: self.get_gradients(fun, params),
606
+ self.get_gradients,
543
607
  all_funs,
544
608
  all_params,
545
609
  is_leaf=lambda x: x is None,
546
610
  )
547
611
  mses = jax.tree.map(
548
- lambda leaf: leaf[0], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
612
+ lambda leaf: leaf[0], # type: ignore
613
+ mses_grads,
614
+ is_leaf=lambda x: isinstance(x, tuple),
549
615
  )
550
616
  grads = jax.tree.map(
551
- lambda leaf: leaf[1], mses_grads, is_leaf=lambda x: isinstance(x, tuple)
617
+ lambda leaf: leaf[1], # type: ignore
618
+ mses_grads,
619
+ is_leaf=lambda x: isinstance(x, tuple),
552
620
  )
553
621
 
554
622
  return mses, grads
555
623
 
556
- def evaluate(
557
- self, params: Params[Array], batch: PDEStatioBatch
558
- ) -> tuple[Float[Array, " "], PDEStatioComponents[Float[Array, " "] | None]]:
559
- """
560
- Evaluate the loss function at a batch of points for given parameters.
561
-
562
- We retrieve the total value itself and a PyTree with loss values for each term
563
-
564
- Parameters
565
- ---------
566
- params
567
- Parameters at which the loss is evaluated
568
- batch
569
- Composed of a batch of points in the
570
- domain, a batch of points in the domain
571
- border and an optional additional batch of parameters (eg. for
572
- metamodeling) and an optional additional batch of observed
573
- inputs/outputs/parameters
574
- """
575
- loss_terms, _ = self.evaluate_by_terms(params, batch)
576
-
577
- loss_val = self.ponderate_and_sum_loss(loss_terms)
578
-
579
- return loss_val, loss_terms
580
624
 
581
-
582
- class LossPDENonStatio(LossPDEStatio):
625
+ class LossPDENonStatio(
626
+ _LossPDEAbstract[
627
+ LossWeightsPDENonStatio,
628
+ PDENonStatioBatch,
629
+ PDENonStatioComponents[Array | None],
630
+ DerivativeKeysPDENonStatio,
631
+ PDENonStatio | None,
632
+ ]
633
+ ):
583
634
  r"""Loss object for a non-stationary partial differential equation
584
635
 
585
636
  $$
@@ -603,7 +654,7 @@ class LossPDENonStatio(LossPDEStatio):
603
654
  `dynamic_loss.evaluate(t, x, u, params)`.
604
655
  Can be None in order to access only some part of the evaluate call
605
656
  results.
606
- key : Key
657
+ key : PRNGKeyArray
607
658
  A JAX PRNG Key for the loss class treated as an attribute. Default is
608
659
  None. This field is provided for future developments and additional
609
660
  losses that might need some randomness. Note that special care must be
@@ -616,14 +667,31 @@ class LossPDENonStatio(LossPDEStatio):
616
667
  observations if any.
617
668
  Can be updated according to a specific algorithm. See
618
669
  `update_weight_method`
619
- update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
620
- Default is None meaning no update for loss weights. Otherwise a string
621
670
  derivative_keys : DerivativeKeysPDENonStatio, default=None
622
671
  Specify which field of `params` should be differentiated for each
623
672
  component of the total loss. Particularly useful for inverse problems.
624
673
  Fields can be "nn_params", "eq_params" or "both". Those that should not
625
674
  be updated will have a `jax.lax.stop_gradient` called on them. Default
626
675
  is `"nn_params"` for each component of the loss.
676
+ initial_condition_fun : Callable, default=None
677
+ A function representing the initial condition at `t0`. If None
678
+ (default) then no initial condition is applied.
679
+ t0 : float | Float[Array, " 1"], default=None
680
+ The time at which to apply the initial condition. If None, the time
681
+ is set to `0` by default.
682
+ max_norm_time_slices : int, default=100
683
+ The maximum number of time points in the Cartesian product with the
684
+ omega points to create the set of collocation points upon which the
685
+ normalization constant is computed.
686
+ max_norm_samples_omega : int, default=1000
687
+ The maximum number of omega points in the Cartesian product with the
688
+ time points to create the set of collocation points upon which the
689
+ normalization constant is computed.
690
+ params : InitVar[Params[Array]], default=None
691
+ The main `Params` object of the problem needed to instantiate the
692
+ `DerivativeKeysPDENonStatio` if the latter is not specified.
693
+ update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
694
+ Default is None meaning no update for loss weights. Otherwise a string
627
695
  omega_boundary_fun : BoundaryConditionFun | dict[str, BoundaryConditionFun], default=None
628
696
  The function to be matched in the border condition (can be None) or a
629
697
  dictionary of such functions as values and keys as described
@@ -660,58 +728,81 @@ class LossPDENonStatio(LossPDEStatio):
660
728
  obs_slice : slice, default=None
661
729
  slice object specifying the beginning/ending of the PINN output
662
730
  that is observed (this is then useful for multidim PINN). Default is None.
663
- t0 : float | Float[Array, " 1"], default=None
664
- The time at which to apply the initial condition. If None, the time
665
- is set to `0` by default.
666
- initial_condition_fun : Callable, default=None
667
- A function representing the initial condition at `t0`. If None
668
- (default) then no initial condition is applied.
669
- params : InitVar[Params[Array]], default=None
670
- The main `Params` object of the problem needed to instanciate the
671
- `DerivativeKeysODE` if the latter is not specified.
672
731
 
673
732
  """
674
733
 
734
+ u: AbstractPINN
675
735
  dynamic_loss: PDENonStatio | None
676
- # NOTE static=True only for leaf attributes that are not valid JAX types
677
- # (ie. jax.Array cannot be static) and that we do not expect to change
678
- initial_condition_fun: Callable | None = eqx.field(
679
- kw_only=True, default=None, static=True
736
+ loss_weights: LossWeightsPDENonStatio
737
+ derivative_keys: DerivativeKeysPDENonStatio
738
+ params: InitVar[Params[Array] | None]
739
+ t0: Float[Array, " "]
740
+ initial_condition_fun: Callable[[Float[Array, " dimension"]], Array] | None = (
741
+ eqx.field(static=True)
680
742
  )
681
- t0: float | Float[Array, " 1"] | None = eqx.field(kw_only=True, default=None)
743
+ vmap_in_axes: tuple[int] = eqx.field(static=True)
744
+ max_norm_samples_omega: int = eqx.field(static=True)
745
+ max_norm_time_slices: int = eqx.field(static=True)
746
+
747
+ params: InitVar[Params[Array] | None]
748
+
749
+ def __init__(
750
+ self,
751
+ *,
752
+ u: AbstractPINN,
753
+ dynamic_loss: PDENonStatio | None,
754
+ loss_weights: LossWeightsPDENonStatio | None = None,
755
+ derivative_keys: DerivativeKeysPDENonStatio | None = None,
756
+ initial_condition_fun: Callable[[Float[Array, " dimension"]], Array]
757
+ | None = None,
758
+ t0: int | float | Float[Array, " "] | None = None,
759
+ max_norm_time_slices: int = 100,
760
+ max_norm_samples_omega: int = 1000,
761
+ params: Params[Array] | None = None,
762
+ **kwargs: Any,
763
+ ):
764
+ self.u = u
765
+ if loss_weights is None:
766
+ self.loss_weights = LossWeightsPDENonStatio()
767
+ else:
768
+ self.loss_weights = loss_weights
769
+ self.dynamic_loss = dynamic_loss
682
770
 
683
- _max_norm_samples_omega: Int = eqx.field(init=False, static=True)
684
- _max_norm_time_slices: Int = eqx.field(init=False, static=True)
771
+ super().__init__(
772
+ **kwargs,
773
+ )
685
774
 
686
- def __post_init__(self, params=None):
687
- """
688
- Note that neither __init__ or __post_init__ are called when udating a
689
- Module with eqx.tree_at!
690
- """
691
- super().__post_init__(
692
- params=params
693
- ) # because __init__ or __post_init__ of Base
694
- # class is not automatically called
775
+ if derivative_keys is None:
776
+ # by default we only take gradient wrt nn_params
777
+ try:
778
+ self.derivative_keys = DerivativeKeysPDENonStatio(params=params)
779
+ except ValueError as exc:
780
+ raise ValueError(
781
+ "Problem at derivative_keys initialization "
782
+ f"received {derivative_keys=} and {params=}"
783
+ ) from exc
784
+ else:
785
+ self.derivative_keys = derivative_keys
695
786
 
696
787
  self.vmap_in_axes = (0,) # for t_x
697
788
 
698
- if self.initial_condition_fun is None:
789
+ if initial_condition_fun is None:
699
790
  warnings.warn(
700
791
  "Initial condition wasn't provided. Be sure to cover for that"
701
792
  "case (e.g by. hardcoding it into the PINN output)."
702
793
  )
703
794
  # some checks for t0
704
- t0 = self.t0
705
795
  if t0 is None:
706
- t0 = jnp.array([0])
796
+ self.t0 = jnp.array([0])
707
797
  else:
708
- t0 = initial_condition_check(t0, dim_size=1)
709
- self.t0 = t0
798
+ self.t0 = initial_condition_check(t0, dim_size=1)
799
+
800
+ self.initial_condition_fun = initial_condition_fun
710
801
 
711
- # witht the variables below we avoid memory overflow since a cartesian
802
+ # with the variables below we avoid memory overflow since a cartesian
712
803
  # product is taken
713
- self._max_norm_time_slices = 100
714
- self._max_norm_samples_omega = 1000
804
+ self.max_norm_time_slices = max_norm_time_slices
805
+ self.max_norm_samples_omega = max_norm_samples_omega
715
806
 
716
807
  def _get_dynamic_loss_batch(
717
808
  self, batch: PDENonStatioBatch
@@ -724,18 +815,10 @@ class LossPDENonStatio(LossPDEStatio):
724
815
  Float[Array, " nb_norm_time_slices 1"], Float[Array, " nb_norm_samples dim"]
725
816
  ]:
726
817
  return (
727
- batch.domain_batch[: self._max_norm_time_slices, 0:1],
728
- self.norm_samples[: self._max_norm_samples_omega], # type: ignore -> cannot narrow a class attr
818
+ batch.domain_batch[: self.max_norm_time_slices, 0:1],
819
+ self.norm_samples[: self.max_norm_samples_omega], # type: ignore -> cannot narrow a class attr
729
820
  )
730
821
 
731
- def _get_observations_loss_batch(
732
- self, batch: PDENonStatioBatch
733
- ) -> Float[Array, " batch_size 1+dim"]:
734
- return batch.obs_batch_dict["pinn_in"]
735
-
736
- def __call__(self, *args, **kwargs):
737
- return self.evaluate(*args, **kwargs)
738
-
739
822
  def evaluate_by_terms(
740
823
  self, params: Params[Array], batch: PDENonStatioBatch
741
824
  ) -> tuple[
@@ -757,76 +840,75 @@ class LossPDENonStatio(LossPDEStatio):
757
840
  metamodeling) and an optional additional batch of observed
758
841
  inputs/outputs/parameters
759
842
  """
760
- omega_batch = batch.initial_batch
761
- assert omega_batch is not None
843
+ omega_initial_batch = batch.initial_batch
844
+ assert omega_initial_batch is not None
762
845
 
763
846
  # Retrieve the optional eq_params_batch
764
847
  # and update eq_params with the latter
765
848
  # and update vmap_in_axes
766
849
  if batch.param_batch_dict is not None:
767
850
  # update eq_params with the batches of generated params
768
- params = _update_eq_params_dict(params, batch.param_batch_dict)
851
+ params = update_eq_params(params, batch.param_batch_dict)
769
852
 
770
853
  vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
771
854
 
772
- # For mse_dyn_loss, mse_norm_loss, mse_boundary_loss,
773
- # mse_observation_loss we use the evaluate from parent class
774
- # As well as for their gradients
775
- partial_mses, partial_grads = super().evaluate_by_terms(params, batch) # type: ignore
776
- # ignore because batch is not PDEStatioBatch. We could use typing.cast though
855
+ # dynamic part
856
+ dyn_loss_fun = self._get_dyn_loss_fun(batch, vmap_in_axes_params)
857
+
858
+ # normalization part
859
+ norm_loss_fun = self._get_norm_loss_fun(batch, vmap_in_axes_params)
860
+
861
+ # boundary part
862
+ boundary_loss_fun = self._get_boundary_loss_fun(batch)
863
+
864
+ # Observation mse
865
+ params_obs, obs_loss_fun = self._get_obs_params_and_obs_loss_fun(
866
+ batch, vmap_in_axes_params, params
867
+ )
777
868
 
778
869
  # initial condition
779
870
  if self.initial_condition_fun is not None:
780
- mse_initial_condition_fun = lambda p: initial_condition_apply(
781
- self.u,
782
- omega_batch,
783
- _set_derivatives(p, self.derivative_keys.initial_condition), # type: ignore
784
- (0,) + vmap_in_axes_params,
785
- self.initial_condition_fun, # type: ignore
786
- self.t0, # type: ignore can't get the narrowing in __post_init__
787
- )
788
- mse_initial_condition, grad_initial_condition = self.get_gradients(
789
- mse_initial_condition_fun, params
871
+ mse_initial_condition_fun: Callable[[Params[Array]], Array] | None = (
872
+ lambda p: initial_condition_apply(
873
+ self.u,
874
+ omega_initial_batch,
875
+ _set_derivatives(p, self.derivative_keys.initial_condition),
876
+ (0,) + vmap_in_axes_params,
877
+ self.initial_condition_fun, # type: ignore
878
+ self.t0,
879
+ )
790
880
  )
791
881
  else:
792
- mse_initial_condition = None
793
- grad_initial_condition = None
794
-
795
- mses = PDENonStatioComponents(
796
- partial_mses.dyn_loss,
797
- partial_mses.norm_loss,
798
- partial_mses.boundary_loss,
799
- partial_mses.observations,
800
- mse_initial_condition,
801
- )
882
+ mse_initial_condition_fun = None
802
883
 
803
- grads = PDENonStatioComponents(
804
- partial_grads.dyn_loss,
805
- partial_grads.norm_loss,
806
- partial_grads.boundary_loss,
807
- partial_grads.observations,
808
- grad_initial_condition,
884
+ # get the unweighted mses for each loss term as well as the gradients
885
+ all_funs: PDENonStatioComponents[Callable[[Params[Array]], Array] | None] = (
886
+ PDENonStatioComponents(
887
+ dyn_loss_fun,
888
+ norm_loss_fun,
889
+ boundary_loss_fun,
890
+ obs_loss_fun,
891
+ mse_initial_condition_fun,
892
+ )
893
+ )
894
+ all_params: PDENonStatioComponents[Params[Array] | None] = (
895
+ PDENonStatioComponents(params, params, params, params_obs, params)
896
+ )
897
+ mses_grads = jax.tree.map(
898
+ self.get_gradients,
899
+ all_funs,
900
+ all_params,
901
+ is_leaf=lambda x: x is None,
902
+ )
903
+ mses = jax.tree.map(
904
+ lambda leaf: leaf[0], # type: ignore
905
+ mses_grads,
906
+ is_leaf=lambda x: isinstance(x, tuple),
907
+ )
908
+ grads = jax.tree.map(
909
+ lambda leaf: leaf[1], # type: ignore
910
+ mses_grads,
911
+ is_leaf=lambda x: isinstance(x, tuple),
809
912
  )
810
913
 
811
914
  return mses, grads
812
-
813
- def evaluate(
814
- self, params: Params[Array], batch: PDENonStatioBatch
815
- ) -> tuple[Float[Array, " "], PDENonStatioComponents[Float[Array, " "] | None]]:
816
- """
817
- Evaluate the loss function at a batch of points for given parameters.
818
- We retrieve the total value itself and a PyTree with loss values for each term
819
-
820
-
821
- Parameters
822
- ---------
823
- params
824
- Parameters at which the loss is evaluated
825
- batch
826
- Composed of a batch of points in
827
- the domain, a batch of points in the domain
828
- border, a batch of time points and an optional additional batch
829
- of parameters (eg. for metamodeling) and an optional additional batch of observed
830
- inputs/outputs/parameters
831
- """
832
- return super().evaluate(params, batch) # type: ignore