jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +322 -167
  9. jinns/loss/_LossODE.py +324 -322
  10. jinns/loss/_LossPDE.py +652 -1027
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +521 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. jinns-1.1.0.dist-info/AUTHORS +2 -0
  33. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
  34. jinns-1.1.0.dist-info/RECORD +39 -0
  35. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
  36. jinns/experimental/_sinuspinn.py +0 -135
  37. jinns/experimental/_spectralpinn.py +0 -87
  38. jinns/solver/_seq2seq.py +0 -157
  39. jinns/utils/_optim.py +0 -147
  40. jinns/utils/_utils_uspinn.py +0 -727
  41. jinns-0.9.0.dist-info/RECORD +0 -36
  42. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
  43. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py CHANGED
@@ -1,29 +1,52 @@
1
+ # pylint: disable=unsubscriptable-object, no-member
1
2
  """
2
3
  Main module to implement a PDE loss in jinns
3
4
  """
5
+ from __future__ import (
6
+ annotations,
7
+ ) # https://docs.python.org/3/library/typing.html#constant
4
8
 
9
+ import abc
10
+ from dataclasses import InitVar, fields
11
+ from typing import TYPE_CHECKING, Dict, Callable
5
12
  import warnings
6
13
  import jax
7
14
  import jax.numpy as jnp
8
- from jax.tree_util import register_pytree_node_class
9
- from jinns.loss._Losses import (
15
+ import equinox as eqx
16
+ from jaxtyping import Float, Array, Key, Int
17
+ from jinns.loss._loss_utils import (
10
18
  dynamic_loss_apply,
11
19
  boundary_condition_apply,
12
20
  normalization_loss_apply,
13
21
  observations_loss_apply,
14
- sobolev_reg_apply,
15
22
  initial_condition_apply,
16
23
  constraints_system_loss_apply,
17
24
  )
18
- from jinns.data._DataGenerators import PDEStatioBatch, PDENonStatioBatch
19
- from jinns.utils._utils import (
25
+ from jinns.data._DataGenerators import (
26
+ append_obs_batch,
27
+ )
28
+ from jinns.parameters._params import (
20
29
  _get_vmap_in_axes_params,
21
- _set_derivatives,
22
30
  _update_eq_params_dict,
23
31
  )
32
+ from jinns.parameters._derivative_keys import (
33
+ _set_derivatives,
34
+ DerivativeKeysPDEStatio,
35
+ DerivativeKeysPDENonStatio,
36
+ )
37
+ from jinns.loss._loss_weights import (
38
+ LossWeightsPDEStatio,
39
+ LossWeightsPDENonStatio,
40
+ LossWeightsPDEDict,
41
+ )
42
+ from jinns.loss._DynamicLossAbstract import PDEStatio, PDENonStatio
24
43
  from jinns.utils._pinn import PINN
25
44
  from jinns.utils._spinn import SPINN
26
- from jinns.loss._operators import _sobolev
45
+ from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
46
+
47
+
48
+ if TYPE_CHECKING:
49
+ from jinns.utils._types import *
27
50
 
28
51
  _IMPLEMENTED_BOUNDARY_CONDITIONS = [
29
52
  "dirichlet",
@@ -31,375 +54,167 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
31
54
  "vonneumann",
32
55
  ]
33
56
 
34
- _LOSS_WEIGHT_KEYS_PDESTATIO = [
35
- "sobolev",
36
- "observations",
37
- "norm_loss",
38
- "boundary_loss",
39
- "dyn_loss",
40
- ]
41
-
42
- _LOSS_WEIGHT_KEYS_PDENONSTATIO = _LOSS_WEIGHT_KEYS_PDESTATIO + ["initial_condition"]
43
57
 
44
-
45
- @register_pytree_node_class
46
- class LossPDEAbstract:
58
+ class _LossPDEAbstract(eqx.Module):
47
59
  """
48
- Super class for the actual Pinn loss classes. This class should not be
49
- used. It serves for common attributes between LossPDEStatio and
50
- LossPDENonStatio
51
-
52
-
53
- **Note:** LossPDEAbstract is jittable. Hence it implements the tree_flatten() and
54
- tree_unflatten methods.
60
+ Parameters
61
+ ----------
62
+
63
+ loss_weights : LossWeightsPDEStatio | LossWeightsPDENonStatio, default=None
64
+ The loss weights for the different terms: dynamic loss,
65
+ initial condition (if LossWeightsPDENonStatio), boundary conditions if
66
+ any, normalization loss if any and observations if any.
67
+ All fields are set to 1.0 by default.
68
+ derivative_keys : DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio, default=None
69
+ Specify which field of `params` should be differentiated for each
70
+ component of the total loss. Particularly useful for inverse problems.
71
+ Fields can be "nn_params", "eq_params" or "both". Those that should not
72
+ be updated will have a `jax.lax.stop_gradient` called on them. Default
73
+ is `"nn_params"` for each component of the loss.
74
+ omega_boundary_fun : Callable | Dict[str, Callable], default=None
75
+ The function to be matched in the border condition (can be None) or a
76
+ dictionary of such functions as values and keys as described
77
+ in `omega_boundary_condition`.
78
+ omega_boundary_condition : str | Dict[str, str], default=None
79
+ Either None (no condition, by default), or a string defining
80
+ the boundary condition (Dirichlet or Von Neumann),
81
+ or a dictionary with such strings as values. In this case,
82
+ the keys are the facets and must be in the following order:
83
+ 1D -> [“xmin”, “xmax”], 2D -> [“xmin”, “xmax”, “ymin”, “ymax”].
84
+ Note that high order boundaries are currently not implemented.
85
+ A value in the dict can be None, this means we do not enforce
86
+ a particular boundary condition on this facet.
87
+ The facet called “xmin”, resp. “xmax” etc., in 2D,
88
+ refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
89
+ omega_boundary_dim : slice | Dict[str, slice], default=None
90
+ Either None, or a slice object or a dictionary of slice objects as
91
+ values and keys as described in `omega_boundary_condition`.
92
+ `omega_boundary_dim` indicates which dimension(s) of the PINN
93
+ will be forced to match the boundary condition.
94
+ Note that it must be a slice and not an integer
95
+ (but a preprocessing of the user provided argument takes care of it)
96
+ norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
97
+ Fixed sample point in the space over which to compute the
98
+ normalization constant. Default is None.
99
+ norm_int_length : float, default=None
100
+ A float. Must be provided if `norm_samples` is provided. The domain area
101
+ (or interval length in 1D) upon which we perform the numerical
102
+ integration. Default None
103
+ obs_slice : slice, default=None
104
+ slice object specifying the beginning/ending of the PINN output
105
+ that is observed (this is then useful for multidim PINN). Default is None.
106
+ params : InitVar[Params], default=None
107
+ The main Params object of the problem needed to instantiate the
108
+ DerivativeKeysPDEStatio or DerivativeKeysPDENonStatio if the latter is not specified.
55
109
  """
56
110
 
57
- def __init__(
58
- self,
59
- u,
60
- loss_weights,
61
- derivative_keys=None,
62
- norm_key=None,
63
- norm_borders=None,
64
- norm_samples=None,
65
- ):
111
+ # NOTE static=True only for leaf attributes that are not valid JAX types
112
+ # (ie. jax.Array cannot be static) and that we do not expect to change
113
+ # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
114
+ derivative_keys: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None = (
115
+ eqx.field(kw_only=True, default=None)
116
+ )
117
+ loss_weights: LossWeightsPDEStatio | LossWeightsPDENonStatio | None = eqx.field(
118
+ kw_only=True, default=None
119
+ )
120
+ omega_boundary_fun: Callable | Dict[str, Callable] | None = eqx.field(
121
+ kw_only=True, default=None, static=True
122
+ )
123
+ omega_boundary_condition: str | Dict[str, str] | None = eqx.field(
124
+ kw_only=True, default=None, static=True
125
+ )
126
+ omega_boundary_dim: slice | Dict[str, slice] | None = eqx.field(
127
+ kw_only=True, default=None, static=True
128
+ )
129
+ norm_samples: Float[Array, "nb_norm_samples dimension"] | None = eqx.field(
130
+ kw_only=True, default=None
131
+ )
132
+ norm_int_length: float | None = eqx.field(kw_only=True, default=None)
133
+ obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
134
+
135
+ params: InitVar[Params] = eqx.field(kw_only=True, default=None)
136
+
137
+ def __post_init__(self, params=None):
66
138
  """
67
- Parameters
68
- ----------
69
- u
70
- the PINN object
71
- loss_weights
72
- a dictionary with values used to ponderate each term in the loss
73
- function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss`
74
- and `observations`
75
- Note that we can have jnp.arrays with the same dimension of
76
- `u` which then ponderates each output of `u`
77
- derivative_keys
78
- A dict of lists of strings. In the dict, the key must correspond to
79
- the loss term keywords. Then each of the values must correspond to keys in the parameter
80
- dictionary (*at top level only of the parameter dictionary*).
81
- It enables selecting the set of parameters
82
- with respect to which the gradients of the dynamic
83
- loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
84
- keywords, this is what is typically
85
- done in solving forward problems, when we only estimate the
86
- equation solution with a PINN. If some loss terms keywords are
87
- missing we set their value to ["nn_params"] by default for the
88
- same reason
89
- norm_key
90
- Jax random key to draw samples in for the Monte Carlo computation
91
- of the normalization constant. Default is None
92
- norm_borders
93
- tuple of (min, max) of the boundaray values of the space over which
94
- to integrate in the computation of the normalization constant.
95
- A list of tuple for higher dimensional problems. Default None.
96
- norm_samples
97
- Fixed sample point in the space over which to compute the
98
- normalization constant. Default is None
99
-
100
- Raises
101
- ------
102
- RuntimeError
103
- When provided an invalid combination of `norm_key`, `norm_borders`
104
- and `norm_samples`. See note below.
105
-
106
- **Note:** If `norm_key` and `norm_borders` and `norm_samples` are `None`
107
- then no normalization loss in enforced.
108
- If `norm_borders` and `norm_samples` are given while
109
- `norm_samples` is `None` then samples are drawn at each loss evaluation.
110
- Otherwise, if `norm_samples` is given, those samples are used.
139
+ Note that neither __init__ nor __post_init__ is called when updating a
140
+ Module with eqx.tree_at
111
141
  """
112
-
113
- self.u = u
114
- if derivative_keys is None:
142
+ if self.derivative_keys is None:
115
143
  # by default we only take gradient wrt nn_params
116
- derivative_keys = {
117
- k: ["nn_params"]
118
- for k in [
119
- "dyn_loss",
120
- "boundary_loss",
121
- "norm_loss",
122
- "initial_condition",
123
- "observations",
124
- "sobolev",
125
- ]
126
- }
127
- if isinstance(derivative_keys, list):
128
- # if the user only provided a list, this defines the gradient taken
129
- # for all the loss entries
130
- derivative_keys = {
131
- k: derivative_keys
132
- for k in [
133
- "dyn_loss",
134
- "boundary_loss",
135
- "norm_loss",
136
- "initial_condition",
137
- "observations",
138
- "sobolev",
139
- ]
140
- }
141
-
142
- self.derivative_keys = derivative_keys
143
- self.loss_weights = loss_weights
144
- self.norm_borders = norm_borders
145
- self.norm_key = norm_key
146
- self.norm_samples = norm_samples
147
-
148
- if norm_key is None and norm_borders is None and norm_samples is None:
149
- # if there is None of the 3 above, that means we don't consider
150
- # normalization loss
151
- self.normalization_loss = None
152
- elif (
153
- norm_key is not None and norm_borders is not None and norm_samples is None
154
- ): # this ordering so that by default priority is to given mc_samples
155
- self.norm_sample_method = "generate"
156
- if not isinstance(self.norm_borders[0], tuple):
157
- self.norm_borders = (self.norm_borders,)
158
- self.norm_xmin, self.norm_xmax = [], []
159
- for i, _ in enumerate(self.norm_borders):
160
- self.norm_xmin.append(self.norm_borders[i][0])
161
- self.norm_xmax.append(self.norm_borders[i][1])
162
- self.int_length = jnp.prod(
163
- jnp.array(
164
- [
165
- self.norm_xmax[i] - self.norm_xmin[i]
166
- for i in range(len(self.norm_borders))
167
- ]
168
- )
169
- )
170
- self.normalization_loss = True
171
- elif norm_samples is None:
172
- raise RuntimeError(
173
- "norm_borders should always provided then either"
174
- " norm_samples (fixed norm_samples) or norm_key (random norm_samples)"
175
- " is required."
176
- )
177
- else:
178
- # ok, we are sure we have norm_samples given by the user
179
- self.norm_sample_method = "user"
180
- if not isinstance(self.norm_borders[0], tuple):
181
- self.norm_borders = (self.norm_borders,)
182
- self.norm_xmin, self.norm_xmax = [], []
183
- for i, _ in enumerate(self.norm_borders):
184
- self.norm_xmin.append(self.norm_borders[i][0])
185
- self.norm_xmax.append(self.norm_borders[i][1])
186
- self.int_length = jnp.prod(
187
- jnp.array(
188
- [
189
- self.norm_xmax[i] - self.norm_xmin[i]
190
- for i in range(len(self.norm_borders))
191
- ]
144
+ try:
145
+ self.derivative_keys = (
146
+ DerivativeKeysPDENonStatio(params=params)
147
+ if isinstance(self, LossPDENonStatio)
148
+ else DerivativeKeysPDEStatio(params=params)
192
149
  )
150
+ except ValueError as exc:
151
+ raise ValueError(
152
+ "Problem at self.derivative_keys initialization "
153
+ f"received {self.derivative_keys=} and {params=}"
154
+ ) from exc
155
+
156
+ if self.loss_weights is None:
157
+ self.loss_weights = (
158
+ LossWeightsPDENonStatio()
159
+ if isinstance(self, LossPDENonStatio)
160
+ else LossWeightsPDEStatio()
193
161
  )
194
- self.normalization_loss = True
195
162
 
196
- def get_norm_samples(self):
197
- """
198
- Returns a batch of points in the domain for integration when the
199
- normalization constraint is enforced. The batch of points is either
200
- fixed (provided by the user) or regenerated at each iteration.
201
- """
202
- if self.norm_sample_method == "user":
203
- return self.norm_samples
204
- if self.norm_sample_method == "generate":
205
- ## NOTE TODO CHECK the performances of this for loop
206
- norm_samples = []
207
- for d in range(len(self.norm_borders)):
208
- self.norm_key, subkey = jax.random.split(self.norm_key)
209
- norm_samples.append(
210
- jax.random.uniform(
211
- subkey,
212
- shape=(1000, 1),
213
- minval=self.norm_xmin[d],
214
- maxval=self.norm_xmax[d],
215
- )
216
- )
217
- self.norm_samples = jnp.concatenate(norm_samples, axis=-1)
218
- return self.norm_samples
219
- raise RuntimeError("Problem with the value of self.norm_sample_method")
220
-
221
- def tree_flatten(self):
222
- children = (self.norm_key, self.norm_samples, self.loss_weights)
223
- aux_data = {
224
- "norm_borders": self.norm_borders,
225
- "derivative_keys": self.derivative_keys,
226
- "u": self.u,
227
- }
228
- return (children, aux_data)
229
-
230
- @classmethod
231
- def tree_unflatten(self, aux_data, children):
232
- (norm_key, norm_samples, loss_weights) = children
233
- pls = self(
234
- aux_data["u"],
235
- loss_weights,
236
- aux_data["derivative_keys"],
237
- norm_key,
238
- aux_data["norm_borders"],
239
- norm_samples,
240
- )
241
- return pls
242
-
243
-
244
- @register_pytree_node_class
245
- class LossPDEStatio(LossPDEAbstract):
246
- r"""Loss object for a stationary partial differential equation
247
-
248
- .. math::
249
- \mathcal{N}[u](x) = 0, \forall x \in \Omega
250
-
251
- where :math:`\mathcal{N}[\cdot]` is a differential operator and the
252
- boundary condition is :math:`u(x)=u_b(x)` The additional condition of
253
- integrating to 1 can be included, i.e. :math:`\int u(x)\mathrm{d}x=1`.
254
-
255
-
256
- **Note:** LossPDEStatio is jittable. Hence it implements the tree_flatten() and
257
- tree_unflatten methods.
258
- """
259
-
260
- def __init__(
261
- self,
262
- u,
263
- loss_weights,
264
- dynamic_loss,
265
- derivative_keys=None,
266
- omega_boundary_fun=None,
267
- omega_boundary_condition=None,
268
- omega_boundary_dim=None,
269
- norm_key=None,
270
- norm_borders=None,
271
- norm_samples=None,
272
- sobolev_m=None,
273
- obs_slice=None,
274
- ):
275
- r"""
276
- Parameters
277
- ----------
278
- u
279
- the PINN object
280
- loss_weights
281
- a dictionary with values used to ponderate each term in the loss
282
- function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss`,
283
- `observations` and `sobolev`.
284
- Note that we can have jnp.arrays with the same dimension of
285
- `u` which then ponderates each output of `u`
286
- dynamic_loss
287
- the stationary PDE dynamic part of the loss, basically the differential
288
- operator :math:` \mathcal{N}[u](t)`. Should implement a method
289
- `dynamic_loss.evaluate(t, u, params)`.
290
- Can be None in order to access only some part of the evaluate call
291
- results.
292
- derivative_keys
293
- A dict of lists of strings. In the dict, the key must correspond to
294
- the loss term keywords. Then each of the values must correspond to keys in the parameter
295
- dictionary (*at top level only of the parameter dictionary*).
296
- It enables selecting the set of parameters
297
- with respect to which the gradients of the dynamic
298
- loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
299
- keywords, this is what is typically
300
- done in solving forward problems, when we only estimate the
301
- equation solution with a PINN. If some loss terms keywords are
302
- missing we set their value to ["nn_params"] by default for the same
303
- reason
304
- omega_boundary_fun
305
- The function to be matched in the border condition (can be None)
306
- or a dictionary of such function. In this case, the keys are the
307
- facets and the values are the functions. The keys must be in the
308
- following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
309
- "ymin", "ymax"]. Note that high order boundaries are currently not
310
- implemented. A value in the dict can be None, this means we do not
311
- enforce a particular boundary condition on this facet.
312
- The facet called "xmin", resp. "xmax" etc., in 2D,
313
- refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
314
- omega_boundary_condition
315
- Either None (no condition), or a string defining the boundary
316
- condition e.g. Dirichlet or Von Neumann, or a dictionary of such
317
- strings. In this case, the keys are the
318
- facets and the values are the strings. The keys must be in the
319
- following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
320
- "ymin", "ymax"]. Note that high order boundaries are currently not
321
- implemented. A value in the dict can be None, this means we do not
322
- enforce a particular boundary condition on this facet.
323
- The facet called "xmin", resp. "xmax" etc., in 2D,
324
- refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
325
- omega_boundary_dim
326
- Either None, or a jnp.s\_ or a dict of jnp.s\_ with keys following
327
- the logic of omega_boundary_fun. It indicates which dimension(s) of
328
- the PINN will be forced to match the boundary condition
329
- Note that it must be a slice and not an integer (a preprocessing of the
330
- user provided argument takes care of it)
331
- norm_key
332
- Jax random key to draw samples in for the Monte Carlo computation
333
- of the normalization constant. Default is None
334
- norm_borders
335
- tuple of (min, max) of the boundaray values of the space over which
336
- to integrate in the computation of the normalization constant.
337
- A list of tuple for higher dimensional problems. Default None.
338
- norm_samples
339
- Fixed sample point in the space over which to compute the
340
- normalization constant. Default is None
341
- sobolev_m
342
- An integer. Default is None.
343
- It corresponds to the Sobolev regularization order as proposed in
344
- *Convergence and error analysis of PINNs*,
345
- Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
346
- obs_slice
347
- slice object specifying the begininning/ending
348
- slice of u output(s) that is observed (this is then useful for
349
- multidim PINN). Default is None.
350
-
351
-
352
- Raises
353
- ------
354
- ValueError
355
- If conditions on omega_boundary_condition and omega_boundary_fun
356
- are not respected
357
- """
163
+ if self.obs_slice is None:
164
+ self.obs_slice = jnp.s_[...]
358
165
 
359
- super().__init__(
360
- u, loss_weights, derivative_keys, norm_key, norm_borders, norm_samples
361
- )
166
+ if (
167
+ isinstance(self.omega_boundary_fun, dict)
168
+ and not isinstance(self.omega_boundary_condition, dict)
169
+ ) or (
170
+ not isinstance(self.omega_boundary_fun, dict)
171
+ and isinstance(self.omega_boundary_condition, dict)
172
+ ):
173
+ raise ValueError(
174
+ "if one of self.omega_boundary_fun or "
175
+ "self.omega_boundary_condition is dict, the other should be too."
176
+ )
362
177
 
363
- if omega_boundary_condition is None or omega_boundary_fun is None:
178
+ if self.omega_boundary_condition is None or self.omega_boundary_fun is None:
364
179
  warnings.warn(
365
180
  "Missing boundary function or no boundary condition."
366
181
  "Boundary function is thus ignored."
367
182
  )
368
183
  else:
369
- if isinstance(omega_boundary_condition, dict):
370
- for _, v in omega_boundary_condition.items():
184
+ if isinstance(self.omega_boundary_condition, dict):
185
+ for _, v in self.omega_boundary_condition.items():
371
186
  if v is not None and not any(
372
187
  v.lower() in s for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
373
188
  ):
374
189
  raise NotImplementedError(
375
- f"The boundary condition {omega_boundary_condition} is not"
190
+ f"The boundary condition {self.omega_boundary_condition} is not"
376
191
  f"implemented yet. Try one of :"
377
192
  f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
378
193
  )
379
194
  else:
380
195
  if not any(
381
- omega_boundary_condition.lower() in s
196
+ self.omega_boundary_condition.lower() in s
382
197
  for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
383
198
  ):
384
199
  raise NotImplementedError(
385
- f"The boundary condition {omega_boundary_condition} is not"
200
+ f"The boundary condition {self.omega_boundary_condition} is not"
386
201
  f"implemented yet. Try one of :"
387
202
  f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
388
203
  )
389
- if isinstance(omega_boundary_fun, dict) and isinstance(
390
- omega_boundary_condition, dict
204
+ if isinstance(self.omega_boundary_fun, dict) and isinstance(
205
+ self.omega_boundary_condition, dict
391
206
  ):
392
207
  if (
393
208
  not (
394
- list(omega_boundary_fun.keys()) == ["xmin", "xmax"]
395
- and list(omega_boundary_condition.keys())
209
+ list(self.omega_boundary_fun.keys()) == ["xmin", "xmax"]
210
+ and list(self.omega_boundary_condition.keys())
396
211
  == ["xmin", "xmax"]
397
212
  )
398
213
  ) or (
399
214
  not (
400
- list(omega_boundary_fun.keys())
215
+ list(self.omega_boundary_fun.keys())
401
216
  == ["xmin", "xmax", "ymin", "ymax"]
402
- and list(omega_boundary_condition.keys())
217
+ and list(self.omega_boundary_condition.keys())
403
218
  == ["xmin", "xmax", "ymin", "ymax"]
404
219
  )
405
220
  ):
@@ -408,10 +223,6 @@ class LossPDEStatio(LossPDEAbstract):
408
223
  "boundary condition dictionaries is incorrect"
409
224
  )
410
225
 
411
- self.omega_boundary_fun = omega_boundary_fun
412
- self.omega_boundary_condition = omega_boundary_condition
413
-
414
- self.omega_boundary_dim = omega_boundary_dim
415
226
  if isinstance(self.omega_boundary_fun, dict):
416
227
  if self.omega_boundary_dim is None:
417
228
  self.omega_boundary_dim = {
@@ -440,44 +251,144 @@ class LossPDEStatio(LossPDEAbstract):
440
251
  self.omega_boundary_dim : self.omega_boundary_dim + 1
441
252
  ]
442
253
  if not isinstance(self.omega_boundary_dim, slice):
443
- raise ValueError("self.omega_boundary_dim must be a jnp.s_" " object")
254
+ raise ValueError("self.omega_boundary_dim must be a jnp.s_ object")
444
255
 
445
- self.dynamic_loss = dynamic_loss
256
+ if self.norm_samples is not None and self.norm_int_length is None:
257
+ raise ValueError("self.norm_samples and norm_int_length must be provided")
446
258
 
447
- self.sobolev_m = sobolev_m
448
- if self.sobolev_m is not None:
449
- self.sobolev_reg = _sobolev(
450
- self.u, self.sobolev_m
451
- ) # we return a function, that way
452
- # the order of sobolev_m is static and the conditional in the recursive
453
- # function is properly set
454
- else:
455
- self.sobolev_reg = None
259
+ @abc.abstractmethod
260
+ def evaluate(
261
+ self: eqx.Module,
262
+ params: Params,
263
+ batch: PDEStatioBatch | PDENonStatioBatch,
264
+ ) -> tuple[Float, dict]:
265
+ raise NotImplementedError
456
266
 
457
- for k in _LOSS_WEIGHT_KEYS_PDESTATIO:
458
- if k not in self.loss_weights.keys():
459
- self.loss_weights[k] = 0
460
267
 
461
- if (
462
- isinstance(self.omega_boundary_fun, dict)
463
- and not isinstance(self.omega_boundary_condition, dict)
464
- ) or (
465
- not isinstance(self.omega_boundary_fun, dict)
466
- and isinstance(self.omega_boundary_condition, dict)
467
- ):
468
- raise ValueError(
469
- "if one of self.omega_boundary_fun or "
470
- "self.omega_boundary_condition is dict, the other should be too."
471
- )
268
+ class LossPDEStatio(_LossPDEAbstract):
269
+ r"""Loss object for a stationary partial differential equation
472
270
 
473
- self.obs_slice = obs_slice
474
- if self.obs_slice is None:
475
- self.obs_slice = jnp.s_[...]
271
+ $$
272
+ \mathcal{N}[u](x) = 0, \forall x \in \Omega
273
+ $$
274
+
275
+ where $\mathcal{N}[\cdot]$ is a differential operator and the
276
+ boundary condition is $u(x)=u_b(x)$. The additional condition of
277
+ integrating to 1 can be included, i.e. $\int u(x)\mathrm{d}x=1$.
278
+
279
+ Parameters
280
+ ----------
281
+ u : eqx.Module
282
+ the PINN
283
+ dynamic_loss : DynamicLoss
284
+ the stationary PDE dynamic part of the loss, basically the differential
285
+ operator $\mathcal{N}[u](x)$. Should implement a method
286
+ `dynamic_loss.evaluate(x, u, params)`.
287
+ Can be None in order to access only some part of the evaluate call
288
+ results.
289
+ key : Key
290
+ A JAX PRNG Key for the loss class treated as an attribute. Default is
291
+ None. This field is provided for future developments and additional
292
+ losses that might need some randomness. Note that special care must be
293
+ taken when splitting the key because in-place updates are forbidden in
294
+ eqx.Modules.
295
+ loss_weights : LossWeightsPDEStatio, default=None
296
+ The loss weights for the different terms: dynamic loss,
297
+ boundary conditions if any, normalization loss if any and
298
+ observations if any.
299
+ All fields are set to 1.0 by default.
300
+ derivative_keys : DerivativeKeysPDEStatio, default=None
301
+ Specify which field of `params` should be differentiated for each
302
+ component of the total loss. Particularly useful for inverse problems.
303
+ Fields can be "nn_params", "eq_params" or "both". Those that should not
304
+ be updated will have a `jax.lax.stop_gradient` called on them. Default
305
+ is `"nn_params"` for each component of the loss.
306
+ omega_boundary_fun : Callable | Dict[str, Callable], default=None
307
+ The function to be matched in the border condition (can be None) or a
308
+ dictionary of such functions as values and keys as described
309
+ in `omega_boundary_condition`.
310
+ omega_boundary_condition : str | Dict[str, str], default=None
311
+ Either None (no condition, by default), or a string defining
312
+ the boundary condition (Dirichlet or Von Neumann),
313
+ or a dictionary with such strings as values. In this case,
314
+ the keys are the facets and must be in the following order:
315
+ 1D -> [“xmin”, “xmax”], 2D -> [“xmin”, “xmax”, “ymin”, “ymax”].
316
+ Note that high order boundaries are currently not implemented.
317
+ A value in the dict can be None, this means we do not enforce
318
+ a particular boundary condition on this facet.
319
+ The facet called “xmin”, resp. “xmax” etc., in 2D,
320
+ refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
321
+ omega_boundary_dim : slice | Dict[str, slice], default=None
322
+ Either None, or a slice object or a dictionary of slice objects as
323
+ values and keys as described in `omega_boundary_condition`.
324
+ `omega_boundary_dim` indicates which dimension(s) of the PINN
325
+ will be forced to match the boundary condition.
326
+ Note that it must be a slice and not an integer
327
+ (but a preprocessing of the user provided argument takes care of it)
328
+ norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
329
+ Fixed sample point in the space over which to compute the
330
+ normalization constant. Default is None.
331
+ norm_int_length : float, default=None
332
+ A float. Must be provided if `norm_samples` is provided. The domain area
333
+ (or interval length in 1D) upon which we perform the numerical
334
+ integration. Default None
335
+ obs_slice : slice, default=None
336
+ slice object specifying the beginning/ending of the PINN output
337
+ that is observed (this is then useful for multidim PINN). Default is None.
338
+ params : InitVar[Params], default=None
339
+ The main Params object of the problem needed to instantiate the
340
+ DerivativeKeysPDEStatio if the latter is not specified.
341
+
342
+
343
+ Raises
344
+ ------
345
+ ValueError
346
+ If conditions on omega_boundary_condition and omega_boundary_fun
347
+ are not respected
348
+ """
349
+
350
+ # NOTE static=True only for leaf attributes that are not valid JAX types
351
+ # (ie. jax.Array cannot be static) and that we do not expect to change
352
+
353
+ u: eqx.Module
354
+ dynamic_loss: DynamicLoss | None
355
+ key: Key | None = eqx.field(kw_only=True, default=None)
356
+
357
+ vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
358
+
359
+ def __post_init__(self, params=None):
360
+ """
361
+ Note that neither __init__ nor __post_init__ is called when updating a
362
+ Module with eqx.tree_at!
363
+ """
364
+ super().__post_init__(
365
+ params=params
366
+ ) # because __init__ or __post_init__ of Base
367
+ # class is not automatically called
368
+
369
+ self.vmap_in_axes = (0,) # for x only here
370
+
371
+ def _get_dynamic_loss_batch(
372
+ self, batch: PDEStatioBatch
373
+ ) -> tuple[Float[Array, "batch_size dimension"]]:
374
+ return (batch.inside_batch,)
375
+
376
+ def _get_normalization_loss_batch(
377
+ self, _
378
+ ) -> Float[Array, "nb_norm_samples dimension"]:
379
+ return (self.norm_samples,)
380
+
381
+ def _get_observations_loss_batch(
382
+ self, batch: PDEStatioBatch
383
+ ) -> Float[Array, "batch_size obs_dim"]:
384
+ return (batch.obs_batch_dict["pinn_in"],)
476
385
 
477
386
  def __call__(self, *args, **kwargs):
478
387
  return self.evaluate(*args, **kwargs)
479
388
 
480
- def evaluate(self, params, batch):
389
+ def evaluate(
390
+ self, params: Params, batch: PDEStatioBatch
391
+ ) -> tuple[Float[Array, "1"], dict[str, float]]:
481
392
  """
482
393
  Evaluate the loss function at a batch of points for given parameters.
483
394
 
@@ -485,22 +396,14 @@ class LossPDEStatio(LossPDEAbstract):
485
396
  Parameters
486
397
  ---------
487
398
  params
488
- The dictionary of parameters of the model.
489
- Typically, it is a dictionary of
490
- dictionaries: `eq_params` and `nn_params``, respectively the
491
- differential equation parameters and the neural network parameter
399
+ Parameters at which the loss is evaluated
492
400
  batch
493
- A PDEStatioBatch object.
494
- Such a named tuple is composed of a batch of points in the
401
+ Composed of a batch of points in the
495
402
  domain, a batch of points in the domain
496
403
  border and an optional additional batch of parameters (eg. for
497
404
  metamodeling) and an optional additional batch of observed
498
405
  inputs/outputs/parameters
499
406
  """
500
- omega_batch, _ = batch.inside_batch, batch.border_batch
501
-
502
- vmap_in_axes_x = (0,)
503
-
504
407
  # Retrieve the optional eq_params_batch
505
408
  # and update eq_params with the latter
506
409
  # and update vmap_in_axes
@@ -511,44 +414,41 @@ class LossPDEStatio(LossPDEAbstract):
511
414
  vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
512
415
 
513
416
  # dynamic part
514
- params_ = _set_derivatives(params, "dyn_loss", self.derivative_keys)
515
417
  if self.dynamic_loss is not None:
516
418
  mse_dyn_loss = dynamic_loss_apply(
517
419
  self.dynamic_loss.evaluate,
518
420
  self.u,
519
- (omega_batch,),
520
- params_,
521
- vmap_in_axes_x + vmap_in_axes_params,
522
- self.loss_weights["dyn_loss"],
421
+ self._get_dynamic_loss_batch(batch),
422
+ _set_derivatives(params, self.derivative_keys.dyn_loss),
423
+ self.vmap_in_axes + vmap_in_axes_params,
424
+ self.loss_weights.dyn_loss,
523
425
  )
524
426
  else:
525
427
  mse_dyn_loss = jnp.array(0.0)
526
428
 
527
429
  # normalization part
528
- params_ = _set_derivatives(params, "norm_loss", self.derivative_keys)
529
- if self.normalization_loss is not None:
430
+ if self.norm_samples is not None:
530
431
  mse_norm_loss = normalization_loss_apply(
531
432
  self.u,
532
- (self.get_norm_samples(),),
533
- params_,
534
- vmap_in_axes_x + vmap_in_axes_params,
535
- self.int_length,
536
- self.loss_weights["norm_loss"],
433
+ self._get_normalization_loss_batch(batch),
434
+ _set_derivatives(params, self.derivative_keys.norm_loss),
435
+ self.vmap_in_axes + vmap_in_axes_params,
436
+ self.norm_int_length,
437
+ self.loss_weights.norm_loss,
537
438
  )
538
439
  else:
539
440
  mse_norm_loss = jnp.array(0.0)
540
441
 
541
442
  # boundary part
542
- params_ = _set_derivatives(params, "boundary_loss", self.derivative_keys)
543
443
  if self.omega_boundary_condition is not None:
544
444
  mse_boundary_loss = boundary_condition_apply(
545
445
  self.u,
546
446
  batch,
547
- params_,
447
+ _set_derivatives(params, self.derivative_keys.boundary_loss),
548
448
  self.omega_boundary_fun,
549
449
  self.omega_boundary_condition,
550
450
  self.omega_boundary_dim,
551
- self.loss_weights["boundary_loss"],
451
+ self.loss_weights.boundary_loss,
552
452
  )
553
453
  else:
554
454
  mse_boundary_loss = jnp.array(0.0)
@@ -558,40 +458,21 @@ class LossPDEStatio(LossPDEAbstract):
558
458
  # update params with the batches of observed params
559
459
  params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])
560
460
 
561
- params_ = _set_derivatives(params, "observations", self.derivative_keys)
562
461
  mse_observation_loss = observations_loss_apply(
563
462
  self.u,
564
- (batch.obs_batch_dict["pinn_in"],),
565
- params_,
566
- vmap_in_axes_x + vmap_in_axes_params,
463
+ self._get_observations_loss_batch(batch),
464
+ _set_derivatives(params, self.derivative_keys.observations),
465
+ self.vmap_in_axes + vmap_in_axes_params,
567
466
  batch.obs_batch_dict["val"],
568
- self.loss_weights["observations"],
467
+ self.loss_weights.observations,
569
468
  self.obs_slice,
570
469
  )
571
470
  else:
572
471
  mse_observation_loss = jnp.array(0.0)
573
472
 
574
- # Sobolev regularization
575
- params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
576
- if self.sobolev_reg is not None:
577
- mse_sobolev_loss = sobolev_reg_apply(
578
- self.u,
579
- (omega_batch,),
580
- params_,
581
- vmap_in_axes_x + vmap_in_axes_params,
582
- self.sobolev_reg,
583
- self.loss_weights["sobolev"],
584
- )
585
- else:
586
- mse_sobolev_loss = jnp.array(0.0)
587
-
588
473
  # total loss
589
474
  total_loss = (
590
- mse_dyn_loss
591
- + mse_norm_loss
592
- + mse_boundary_loss
593
- + mse_observation_loss
594
- + mse_sobolev_loss
475
+ mse_dyn_loss + mse_norm_loss + mse_boundary_loss + mse_observation_loss
595
476
  )
596
477
  return total_loss, (
597
478
  {
@@ -599,205 +480,148 @@ class LossPDEStatio(LossPDEAbstract):
599
480
  "norm_loss": mse_norm_loss,
600
481
  "boundary_loss": mse_boundary_loss,
601
482
  "observations": mse_observation_loss,
602
- "sobolev": mse_sobolev_loss,
603
483
  "initial_condition": jnp.array(0.0), # for compatibility in the
604
484
  # tree_map of SystemLoss
605
485
  }
606
486
  )
607
487
 
608
- def tree_flatten(self):
609
- children = (self.norm_key, self.norm_samples, self.loss_weights)
610
- aux_data = {
611
- "u": self.u,
612
- "dynamic_loss": self.dynamic_loss,
613
- "derivative_keys": self.derivative_keys,
614
- "omega_boundary_fun": self.omega_boundary_fun,
615
- "omega_boundary_condition": self.omega_boundary_condition,
616
- "omega_boundary_dim": self.omega_boundary_dim,
617
- "norm_borders": self.norm_borders,
618
- "sobolev_m": self.sobolev_m,
619
- "obs_slice": self.obs_slice,
620
- }
621
- return (children, aux_data)
622
-
623
- @classmethod
624
- def tree_unflatten(cls, aux_data, children):
625
- (norm_key, norm_samples, loss_weights) = children
626
- pls = cls(
627
- aux_data["u"],
628
- loss_weights,
629
- aux_data["dynamic_loss"],
630
- aux_data["derivative_keys"],
631
- aux_data["omega_boundary_fun"],
632
- aux_data["omega_boundary_condition"],
633
- aux_data["omega_boundary_dim"],
634
- norm_key,
635
- aux_data["norm_borders"],
636
- norm_samples,
637
- aux_data["sobolev_m"],
638
- aux_data["obs_slice"],
639
- )
640
- return pls
641
488
 
642
-
643
- @register_pytree_node_class
644
489
  class LossPDENonStatio(LossPDEStatio):
645
490
  r"""Loss object for a non-stationary partial differential equation
646
491
 
647
- .. math::
492
+ $$
648
493
  \mathcal{N}[u](t, x) = 0, \forall t \in I, \forall x \in \Omega
494
+ $$
649
495
 
650
- where :math:`\mathcal{N}[\cdot]` is a differential operator.
651
- The boundary condition is :math:`u(t, x)=u_b(t, x),\forall
652
- x\in\delta\Omega, \forall t`.
653
- The initial condition is :math:`u(0, x)=u_0(x), \forall x\in\Omega`
496
+ where $\mathcal{N}[\cdot]$ is a differential operator.
497
+ The boundary condition is $u(t, x)=u_b(t, x),\forall
498
+ x\in\delta\Omega, \forall t$.
499
+ The initial condition is $u(0, x)=u_0(x), \forall x\in\Omega$
654
500
  The additional condition of
655
- integrating to 1 can be included, i.e., :math:`\int u(t, x)\mathrm{d}x=1`.
656
-
501
+ integrating to 1 can be included, i.e., $\int u(t, x)\mathrm{d}x=1$.
502
+
503
+ Parameters
504
+ ----------
505
+ u : eqx.Module
506
+ the PINN
507
+ dynamic_loss : DynamicLoss
508
+ the non stationary PDE dynamic part of the loss, basically the differential
509
+ operator $\mathcal{N}[u](t, x)$. Should implement a method
510
+ `dynamic_loss.evaluate(t, x, u, params)`.
511
+ Can be None in order to access only some part of the evaluate call
512
+ results.
513
+ key : Key
514
+ A JAX PRNG Key for the loss class treated as an attribute. Default is
515
+ None. This field is provided for future developments and additional
516
+ losses that might need some randomness. Note that special care must be
517
+ taken when splitting the key because in-place updates are forbidden in
518
+ eqx.Modules.
519
+ reason
520
+ loss_weights : LossWeightsPDENonStatio, default=None
521
+ The loss weights for the different terms: dynamic loss,
522
+ boundary conditions if any, initial condition, normalization loss if any and
523
+ observations if any.
524
+ All fields are set to 1.0 by default.
525
+ derivative_keys : DerivativeKeysPDENonStatio, default=None
526
+ Specify which field of `params` should be differentiated for each
527
+ component of the total loss. Particularly useful for inverse problems.
528
+ Fields can be "nn_params", "eq_params" or "both". Those that should not
529
+ be updated will have a `jax.lax.stop_gradient` called on them. Default
530
+ is `"nn_params"` for each component of the loss.
531
+ omega_boundary_fun : Callable | Dict[str, Callable], default=None
532
+ The function to be matched in the border condition (can be None) or a
533
+ dictionary of such functions as values and keys as described
534
+ in `omega_boundary_condition`.
535
+ omega_boundary_condition : str | Dict[str, str], default=None
536
+ Either None (no condition, by default), or a string defining
537
+ the boundary condition (Dirichlet or Von Neumann),
538
+ or a dictionary with such strings as values. In this case,
539
+ the keys are the facets and must be in the following order:
540
+ 1D -> [“xmin”, “xmax”], 2D -> [“xmin”, “xmax”, “ymin”, “ymax”].
541
+ Note that high order boundaries are currently not implemented.
542
+ A value in the dict can be None, this means we do not enforce
543
+ a particular boundary condition on this facet.
544
+ The facet called “xmin”, resp. “xmax” etc., in 2D,
545
+ refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
546
+ omega_boundary_dim : slice | Dict[str, slice], default=None
547
+ Either None, or a slice object or a dictionary of slice objects as
548
+ values and keys as described in `omega_boundary_condition`.
549
+ `omega_boundary_dim` indicates which dimension(s) of the PINN
550
+ will be forced to match the boundary condition.
551
+ Note that it must be a slice and not an integer
552
+ (but a preprocessing of the user provided argument takes care of it)
553
+ norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
554
+ Fixed sample point in the space over which to compute the
555
+ normalization constant. Default is None.
556
+ norm_int_length : float, default=None
557
+ A float. Must be provided if `norm_samples` is provided. The domain area
558
+ (or interval length in 1D) upon which we perform the numerical
559
+ integration. Default None
560
+ obs_slice : slice, default=None
561
+ slice object specifying the beginning/ending of the PINN output
562
+ that is observed (this is then useful for multidim PINN). Default is None.
563
+ initial_condition_fun : Callable, default=None
564
+ A function representing the temporal initial condition. If None
565
+ (default) then no initial condition is applied
566
+ params : InitVar[Params], default=None
567
+ The main Params object of the problem needed to instantiate the
568
+ DerivativeKeysPDENonStatio if the latter is not specified.
657
569
 
658
- **Note:** LossPDENonStatio is jittable. Hence it implements the tree_flatten() and
659
- tree_unflatten methods.
660
570
  """
661
571
 
662
- def __init__(
663
- self,
664
- u,
665
- loss_weights,
666
- dynamic_loss,
667
- derivative_keys=None,
668
- omega_boundary_fun=None,
669
- omega_boundary_condition=None,
670
- omega_boundary_dim=None,
671
- initial_condition_fun=None,
672
- norm_key=None,
673
- norm_borders=None,
674
- norm_samples=None,
675
- sobolev_m=None,
676
- obs_slice=None,
677
- ):
678
- r"""
679
- Parameters
680
- ----------
681
- u
682
- the PINN object
683
- loss_weights
684
- dictionary of values for loss term ponderation
685
- Note that we can have jnp.arrays with the same dimension of
686
- `u` which then ponderates each output of `u`
687
- dynamic_loss
688
- A Dynamic loss object whose evaluate method corresponds to the
689
- dynamic term in the loss
690
- Can be None in order to access only some part of the evaluate call
691
- results.
692
- derivative_keys
693
- A dict of lists of strings. In the dict, the key must correspond to
694
- the loss term keywords. Then each of the values must correspond to keys in the parameter
695
- dictionary (*at top level only of the parameter dictionary*).
696
- It enables selecting the set of parameters
697
- with respect to which the gradients of the dynamic
698
- loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
699
- keywords, this is what is typically
700
- done in solving forward problems, when we only estimate the
701
- equation solution with a PINN. If some loss terms keywords are
702
- missing we set their value to ["nn_params"] by default for the same
703
- reason
704
- omega_boundary_fun
705
- The function to be matched in the border condition (can be None)
706
- or a dictionary of such function. In this case, the keys are the
707
- facets and the values are the functions. The keys must be in the
708
- following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
709
- "ymin", "ymax"]. Note that high order boundaries are currently not
710
- implemented. A value in the dict can be None, this means we do not
711
- enforce a particular boundary condition on this facet.
712
- The facet called "xmin", resp. "xmax" etc., in 2D,
713
- refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
714
- omega_boundary_condition
715
- Either None (no condition), or a string defining the boundary
716
- condition e.g. Dirichlet or Von Neumann, or a dictionary of such
717
- strings. In this case, the keys are the
718
- facets and the values are the strings. The keys must be in the
719
- following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
720
- "ymin", "ymax"]. Note that high order boundaries are currently not
721
- implemented. A value in the dict can be None, this means we do not
722
- enforce a particular boundary condition on this facet.
723
- The facet called "xmin", resp. "xmax" etc., in 2D,
724
- refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
725
- omega_boundary_dim
726
- Either None, or a jnp.s\_ or a dict of jnp.s\_ with keys following
727
- the logic of omega_boundary_fun. It indicates which dimension(s) of
728
- the PINN will be forced to match the boundary condition
729
- Note that it must be a slice and not an integer (a preprocessing of the
730
- user provided argument takes care of it)
731
- initial_condition_fun
732
- A function representing the temporal initial condition. If None
733
- (default) then no initial condition is applied
734
- norm_key
735
- Jax random key to draw samples in for the Monte Carlo computation
736
- of the normalization constant. Default is None
737
- norm_borders
738
- tuple of (min, max) of the boundaray values of the space over which
739
- to integrate in the computation of the normalization constant.
740
- A list of tuple for higher dimensional problems. Default None.
741
- norm_samples
742
- Fixed sample point in the space over which to compute the
743
- normalization constant. Default is None
744
- sobolev_m
745
- An integer. Default is None.
746
- It corresponds to the Sobolev regularization order as proposed in
747
- *Convergence and error analysis of PINNs*,
748
- Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
749
- obs_slice
750
- slice object specifying the begininning/ending
751
- slice of u output(s) that is observed (this is then useful for
752
- multidim PINN). Default is None.
753
-
572
+ # NOTE static=True only for leaf attributes that are not valid JAX types
573
+ # (ie. jax.Array cannot be static) and that we do not expect to change
574
+ initial_condition_fun: Callable | None = eqx.field(
575
+ kw_only=True, default=None, static=True
576
+ )
754
577
 
578
+ def __post_init__(self, params=None):
579
+ """
580
+ Note that neither __init__ nor __post_init__ is called when updating a
581
+ Module with eqx.tree_at!
755
582
  """
583
+ super().__post_init__(
584
+ params=params
585
+ ) # because __init__ or __post_init__ of Base
586
+ # class is not automatically called
756
587
 
757
- super().__init__(
758
- u,
759
- loss_weights,
760
- dynamic_loss,
761
- derivative_keys,
762
- omega_boundary_fun,
763
- omega_boundary_condition,
764
- omega_boundary_dim,
765
- norm_key,
766
- norm_borders,
767
- norm_samples,
768
- sobolev_m=sobolev_m,
769
- obs_slice=obs_slice,
770
- )
771
- if initial_condition_fun is None:
588
+ self.vmap_in_axes = (0, 0) # for t and x
589
+
590
+ if self.initial_condition_fun is None:
772
591
  warnings.warn(
773
592
  "Initial condition wasn't provided. Be sure to cover for that"
774
593
  "case (e.g by. hardcoding it into the PINN output)."
775
594
  )
776
- self.initial_condition_fun = initial_condition_fun
777
-
778
- self.sobolev_m = sobolev_m
779
- if self.sobolev_m is not None:
780
- # This overwrite the wrongly initialized self.sobolev_reg with
781
- # statio=True in the LossPDEStatio init
782
- self.sobolev_reg = _sobolev(self.u, self.sobolev_m, statio=False)
783
- # we return a function, that way
784
- # the order of sobolev_m is static and the conditional in the recursive
785
- # function is properly set
786
- else:
787
- self.sobolev_reg = None
788
595
 
789
- for k in _LOSS_WEIGHT_KEYS_PDENONSTATIO:
790
- if k not in self.loss_weights.keys():
791
- self.loss_weights[k] = 0
596
+ def _get_dynamic_loss_batch(
597
+ self, batch: PDENonStatioBatch
598
+ ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
599
+ times_batch = batch.times_x_inside_batch[:, 0:1]
600
+ omega_batch = batch.times_x_inside_batch[:, 1:]
601
+ return (times_batch, omega_batch)
602
+
603
+ def _get_normalization_loss_batch(
604
+ self, batch: PDENonStatioBatch
605
+ ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "nb_norm_samples dimension"]]:
606
+ return (
607
+ batch.times_x_inside_batch[:, 0:1],
608
+ self.norm_samples,
609
+ )
610
+
611
+ def _get_observations_loss_batch(
612
+ self, batch: PDENonStatioBatch
613
+ ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
614
+ return (
615
+ batch.obs_batch_dict["pinn_in"][:, 0:1],
616
+ batch.obs_batch_dict["pinn_in"][:, 1:],
617
+ )
792
618
 
793
619
  def __call__(self, *args, **kwargs):
794
620
  return self.evaluate(*args, **kwargs)
795
621
 
796
622
  def evaluate(
797
- self,
798
- params,
799
- batch,
800
- ):
623
+ self, params: Params, batch: PDENonStatioBatch
624
+ ) -> tuple[Float[Array, "1"], dict[str, float]]:
801
625
  """
802
626
  Evaluate the loss function at a batch of points for given parameters.
803
627
 
@@ -805,191 +629,54 @@ class LossPDENonStatio(LossPDEStatio):
805
629
  Parameters
806
630
  ---------
807
631
  params
808
- The dictionary of parameters of the model.
809
- Typically, it is a dictionary of
810
- dictionaries: `eq_params` and `nn_params`, respectively the
811
- differential equation parameters and the neural network parameter
632
+ Parameters at which the loss is evaluated
812
633
  batch
813
- A PDENonStatioBatch object.
814
- Such a named tuple is composed of a batch of points in
634
+ Composed of a batch of points in
815
635
  the domain, a batch of points in the domain
816
636
  border, a batch of time points and an optional additional batch
817
637
  of parameters (eg. for metamodeling) and an optional additional batch of observed
818
638
  inputs/outputs/parameters
819
639
  """
820
-
821
- times_batch = batch.times_x_inside_batch[:, 0:1]
822
640
  omega_batch = batch.times_x_inside_batch[:, 1:]
823
- n = omega_batch.shape[0]
824
-
825
- vmap_in_axes_x_t = (0, 0)
826
641
 
827
642
  # Retrieve the optional eq_params_batch
828
643
  # and update eq_params with the latter
829
644
  # and update vmap_in_axes
830
645
  if batch.param_batch_dict is not None:
831
- eq_params_batch_dict = batch.param_batch_dict
832
-
833
- # feed the eq_params with the batch
834
- for k in eq_params_batch_dict.keys():
835
- params["eq_params"][k] = eq_params_batch_dict[k]
646
+ # update eq_params with the batches of generated params
647
+ params = _update_eq_params_dict(params, batch.param_batch_dict)
836
648
 
837
649
  vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
838
650
 
839
- # dynamic part
840
- params_ = _set_derivatives(params, "dyn_loss", self.derivative_keys)
841
- if self.dynamic_loss is not None:
842
- mse_dyn_loss = dynamic_loss_apply(
843
- self.dynamic_loss.evaluate,
844
- self.u,
845
- (times_batch, omega_batch),
846
- params_,
847
- vmap_in_axes_x_t + vmap_in_axes_params,
848
- self.loss_weights["dyn_loss"],
849
- )
850
- else:
851
- mse_dyn_loss = jnp.array(0.0)
852
-
853
- # normalization part
854
- params_ = _set_derivatives(params, "norm_loss", self.derivative_keys)
855
- if self.normalization_loss is not None:
856
- mse_norm_loss = normalization_loss_apply(
857
- self.u,
858
- (times_batch, self.get_norm_samples()),
859
- params_,
860
- vmap_in_axes_x_t + vmap_in_axes_params,
861
- self.int_length,
862
- self.loss_weights["norm_loss"],
863
- )
864
- else:
865
- mse_norm_loss = jnp.array(0.0)
866
-
867
- # boundary part
868
- params_ = _set_derivatives(params, "boundary_loss", self.derivative_keys)
869
- if self.omega_boundary_fun is not None:
870
- mse_boundary_loss = boundary_condition_apply(
871
- self.u,
872
- batch,
873
- params_,
874
- self.omega_boundary_fun,
875
- self.omega_boundary_condition,
876
- self.omega_boundary_dim,
877
- self.loss_weights["boundary_loss"],
878
- )
879
- else:
880
- mse_boundary_loss = jnp.array(0.0)
651
+ # For mse_dyn_loss, mse_norm_loss, mse_boundary_loss,
652
+ # mse_observation_loss we use the evaluate from parent class
653
+ partial_mse, partial_mse_terms = super().evaluate(params, batch)
881
654
 
882
655
  # initial condition
883
- params_ = _set_derivatives(params, "initial_condition", self.derivative_keys)
884
656
  if self.initial_condition_fun is not None:
885
657
  mse_initial_condition = initial_condition_apply(
886
658
  self.u,
887
659
  omega_batch,
888
- params_,
660
+ _set_derivatives(params, self.derivative_keys.initial_condition),
889
661
  (0,) + vmap_in_axes_params,
890
662
  self.initial_condition_fun,
891
- n,
892
- self.loss_weights["initial_condition"],
663
+ omega_batch.shape[0],
664
+ self.loss_weights.initial_condition,
893
665
  )
894
666
  else:
895
667
  mse_initial_condition = jnp.array(0.0)
896
668
 
897
- # Observation mse
898
- if batch.obs_batch_dict is not None:
899
- # update params with the batches of observed params
900
- params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])
901
-
902
- params_ = _set_derivatives(params, "observations", self.derivative_keys)
903
- mse_observation_loss = observations_loss_apply(
904
- self.u,
905
- (
906
- batch.obs_batch_dict["pinn_in"][:, 0:1],
907
- batch.obs_batch_dict["pinn_in"][:, 1:],
908
- ),
909
- params_,
910
- vmap_in_axes_x_t + vmap_in_axes_params,
911
- batch.obs_batch_dict["val"],
912
- self.loss_weights["observations"],
913
- self.obs_slice,
914
- )
915
- else:
916
- mse_observation_loss = jnp.array(0.0)
917
-
918
- # Sobolev regularization
919
- params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
920
- if self.sobolev_reg is not None:
921
- mse_sobolev_loss = sobolev_reg_apply(
922
- self.u,
923
- (omega_batch, times_batch),
924
- params_,
925
- vmap_in_axes_x_t + vmap_in_axes_params,
926
- self.sobolev_reg,
927
- self.loss_weights["sobolev"],
928
- )
929
- else:
930
- mse_sobolev_loss = jnp.array(0.0)
931
-
932
669
  # total loss
933
- total_loss = (
934
- mse_dyn_loss
935
- + mse_norm_loss
936
- + mse_boundary_loss
937
- + mse_initial_condition
938
- + mse_observation_loss
939
- + mse_sobolev_loss
940
- )
941
-
942
- return total_loss, (
943
- {
944
- "dyn_loss": mse_dyn_loss,
945
- "norm_loss": mse_norm_loss,
946
- "boundary_loss": mse_boundary_loss,
947
- "initial_condition": mse_initial_condition,
948
- "observations": mse_observation_loss,
949
- "sobolev": mse_sobolev_loss,
950
- }
951
- )
670
+ total_loss = partial_mse + mse_initial_condition
952
671
 
953
- def tree_flatten(self):
954
- children = (self.norm_key, self.norm_samples, self.loss_weights)
955
- aux_data = {
956
- "u": self.u,
957
- "dynamic_loss": self.dynamic_loss,
958
- "derivative_keys": self.derivative_keys,
959
- "omega_boundary_fun": self.omega_boundary_fun,
960
- "omega_boundary_condition": self.omega_boundary_condition,
961
- "omega_boundary_dim": self.omega_boundary_dim,
962
- "initial_condition_fun": self.initial_condition_fun,
963
- "norm_borders": self.norm_borders,
964
- "sobolev_m": self.sobolev_m,
965
- "obs_slice": self.obs_slice,
672
+ return total_loss, {
673
+ **partial_mse_terms,
674
+ "initial_condition": mse_initial_condition,
966
675
  }
967
- return (children, aux_data)
968
-
969
- @classmethod
970
- def tree_unflatten(cls, aux_data, children):
971
- (norm_key, norm_samples, loss_weights) = children
972
- pls = cls(
973
- aux_data["u"],
974
- loss_weights,
975
- aux_data["dynamic_loss"],
976
- aux_data["derivative_keys"],
977
- aux_data["omega_boundary_fun"],
978
- aux_data["omega_boundary_condition"],
979
- aux_data["omega_boundary_dim"],
980
- aux_data["initial_condition_fun"],
981
- norm_key,
982
- aux_data["norm_borders"],
983
- norm_samples,
984
- aux_data["sobolev_m"],
985
- aux_data["obs_slice"],
986
- )
987
- return pls
988
676
 
989
677
 
990
- @register_pytree_node_class
991
- class SystemLossPDE:
992
- """
678
+ class SystemLossPDE(eqx.Module):
679
+ r"""
993
680
  Class to implement a system of PDEs.
994
681
  The goal is to give maximum freedom to the user. The class is created with
995
682
  a dict of dynamic loss, and dictionaries of all the objects that are used
@@ -1003,190 +690,182 @@ class SystemLossPDE:
1003
690
  Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
1004
691
  solution.
1005
692
 
1006
- **Note:** SystemLossPDE is jittable. Hence it implements the tree_flatten() and
1007
- tree_unflatten methods.
693
+ Parameters
694
+ ----------
695
+ u_dict : Dict[str, eqx.Module]
696
+ dict of PINNs
697
+ loss_weights : LossWeightsPDEDict
698
+ A dictionary of LossWeightsODE
699
+ derivative_keys_dict : Dict[str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio], default=None
700
+ A dictionary of DerivativeKeysPDEStatio or DerivativeKeysPDENonStatio
701
+ specifying what field of `params`
702
+ should be used during gradient computations for each of the terms of
703
+ the total loss, for each of the loss in the system. Default is
704
+ `"nn_params"` everywhere.
705
+ dynamic_loss_dict : Dict[str, PDEStatio | PDENonStatio]
706
+ A dict of dynamic part of the loss, basically the differential
707
+ operator $\mathcal{N}[u](t, x)$ or $\mathcal{N}[u](x)$.
708
+ key_dict : Dict[str, Key], default=None
709
+ A dictionary of JAX PRNG keys. The dictionary keys of key_dict must
710
+ match that of u_dict. See LossPDEStatio or LossPDENonStatio for
711
+ more details.
712
+ omega_boundary_fun_dict : Dict[str, Callable | Dict[str, Callable] | None], default=None
713
+ A dict of functions or of dicts of functions or of None
714
+ (see doc for `omega_boundary_fun` in
715
+ LossPDEStatio or LossPDENonStatio). Default is None.
716
+ Must share the keys of `u_dict`.
717
+ omega_boundary_condition_dict : Dict[str, str | Dict[str, str] | None], default=None
718
+ A dict of strings or of dict of strings or of None
719
+ (see doc for `omega_boundary_condition_dict` in
720
+ LossPDEStatio or LossPDENonStatio). Default is None.
721
+ Must share the keys of `u_dict`
722
+ omega_boundary_dim_dict : Dict[str, slice | Dict[str, slice] | None], default=None
723
+ A dict of slices or of dict of slices or of None
724
+ (see doc for `omega_boundary_dim` in
725
+ LossPDEStatio or LossPDENonStatio). Default is None.
726
+ Must share the keys of `u_dict`
727
+ initial_condition_fun_dict : Dict[str, Callable | None], default=None
728
+ A dict of functions representing the temporal initial condition (None
729
+ value is possible). If None
730
+ (default) then no temporal boundary condition is applied
731
+ Must share the keys of `u_dict`
732
+ norm_samples_dict : Dict[str, Float[Array, "nb_norm_samples dimension"]] | None, default=None
733
+ A dict of fixed sample point in the space over which to compute the
734
+ normalization constant. Default is None
735
+ Must share the keys of `u_dict`
736
+ norm_int_length_dict : Dict[str, float | None] | None, default=None
737
+ A dict of Float. The domain area
738
+ (or interval length in 1D) upon which we perform the numerical
739
+ integration for each element of u_dict.
740
+ Default is None
741
+ Must share the keys of `u_dict`
742
+ obs_slice_dict : Dict[str, slice | None] | None, default=None
743
+ dict of obs_slice, with keys from `u_dict` to designate the
744
+ output(s) channels that are forced to observed values, for each
745
+ PINNs. Default is None. But if a value is given, all the entries of
746
+ `u_dict` must be represented here with default value `jnp.s_[...]`
747
+ if no particular slice is to be given
748
+ params_dict : InitVar[ParamsDict], default=None
749
+ The main Params object of the problem needed to instantiate the
750
+ DerivativeKeysPDEStatio / DerivativeKeysPDENonStatio if the latter are not specified.
751
+
1008
752
  """
1009
753
 
1010
- def __init__(
1011
- self,
1012
- u_dict,
1013
- loss_weights,
1014
- dynamic_loss_dict,
1015
- nn_type_dict,
1016
- derivative_keys_dict=None,
1017
- omega_boundary_fun_dict=None,
1018
- omega_boundary_condition_dict=None,
1019
- omega_boundary_dim_dict=None,
1020
- initial_condition_fun_dict=None,
1021
- norm_key_dict=None,
1022
- norm_borders_dict=None,
1023
- norm_samples_dict=None,
1024
- sobolev_m_dict=None,
1025
- obs_slice_dict=None,
1026
- ):
1027
- r"""
1028
- Parameters
1029
- ----------
1030
- u_dict
1031
- A dict of PINNs
1032
- loss_weights
1033
- A dictionary of dictionaries with values used to
1034
- ponderate each term in the loss
1035
- function. The keys of the nested
1036
- dictionaries must share the keys of `u_dict`. Note that the values
1037
- at the leaf level can have jnp.arrays with the same dimension of
1038
- `u` which then ponderates each output of `u`
1039
- dynamic_loss_dict
1040
- A dict of dynamic part of the loss, basically the differential
1041
- operator :math:`\mathcal{N}[u](t)`.
1042
- nn_type_dict
1043
- A dict whose keys are that of u_dict whose value is either
1044
- `nn_statio` or `nn_nonstatio` which signifies either the PINN has a
1045
- time component in input or not.
1046
- derivative_keys_dict
1047
- A dict of derivative keys as defined in LossODE. The key of this
1048
- dict must be that of `dynamic_loss_dict` at least and specify how
1049
- to compute gradient for the `dyn_loss` loss term at least (see the
1050
- check at the beginning of the present `__init__` function.
1051
- Other keys of this dict might be that of `u_dict` to specify how to
1052
- compute gradients for all the different constraints. If those keys
1053
- are not specified then the default behaviour for `derivative_keys`
1054
- of LossODE is used
1055
- omega_boundary_fun_dict
1056
- A dict of dict of functions (see doc for `omega_boundary_fun` in
1057
- LossPDEStatio or LossPDENonStatio). Default is None.
1058
- Must share the keys of `u_dict`.
1059
- omega_boundary_condition_dict
1060
- A dict of dict of strings (see doc for
1061
- `omega_boundary_condition_dict` in
1062
- LossPDEStatio or LossPDENonStatio). Default is None.
1063
- Must share the keys of `u_dict`
1064
- omega_boundary_dim_dict
1065
- A dict of dict of slices (see doc for `omega_boundary_dim` in
1066
- LossPDEStatio or LossPDENonStatio). Default is None.
1067
- Must share the keys of `u_dict`
1068
- initial_condition_fun_dict
1069
- A dict of functions representing the temporal initial condition. If None
1070
- (default) then no temporal boundary condition is applied
1071
- Must share the keys of `u_dict`
1072
- norm_key_dict
1073
- A dict of Jax random keys to draw samples in for the Monte Carlo computation
1074
- of the normalization constant. Default is None
1075
- Must share the keys of `u_dict`
1076
- norm_borders_dict
1077
- A dict of tuples of (min, max) of the boundaray values of the space over which
1078
- to integrate in the computation of the normalization constant.
1079
- A list of tuple for higher dimensional problems. Default None.
1080
- Must share the keys of `u_dict`
1081
- norm_samples_dict
1082
- A dict of fixed sample point in the space over which to compute the
1083
- normalization constant. Default is None
1084
- Must share the keys of `u_dict`
1085
- sobolev_m
1086
- Default is None. A dictionary of integers, one per key which must
1087
- match `u_dict`.
1088
- It corresponds to the Sobolev regularization order as proposed in
1089
- *Convergence and error analysis of PINNs*,
1090
- Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
1091
- obs_slice_dict
1092
- dict of obs_slice, with keys from `u_dict` to designate the
1093
- output(s) channels that are forced to observed values, for each
1094
- PINNs. Default is None. But if a value is given, all the entries of
1095
- `u_dict` must be represented here with default value `jnp.s_[...]`
1096
- if no particular slice is to be given
1097
-
1098
-
1099
- Raises
1100
- ------
1101
- ValueError
1102
- if initial condition is not a dict of tuple
1103
- ValueError
1104
- if the dictionaries that should share the keys of u_dict do not
1105
- """
754
+ # NOTE static=True only for leaf attributes that are not valid JAX types
755
+ # (ie. jax.Array cannot be static) and that we do not expect to change
756
+ u_dict: Dict[str, eqx.Module]
757
+ dynamic_loss_dict: Dict[str, PDEStatio | PDENonStatio]
758
+ key_dict: Dict[str, Key] | None = eqx.field(kw_only=True, default=None)
759
+ derivative_keys_dict: Dict[
760
+ str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None
761
+ ] = eqx.field(kw_only=True, default=None)
762
+ omega_boundary_fun_dict: Dict[str, Callable | Dict[str, Callable] | None] | None = (
763
+ eqx.field(kw_only=True, default=None, static=True)
764
+ )
765
+ omega_boundary_condition_dict: Dict[str, str | Dict[str, str] | None] | None = (
766
+ eqx.field(kw_only=True, default=None, static=True)
767
+ )
768
+ omega_boundary_dim_dict: Dict[str, slice | Dict[str, slice] | None] | None = (
769
+ eqx.field(kw_only=True, default=None, static=True)
770
+ )
771
+ initial_condition_fun_dict: Dict[str, Callable | None] | None = eqx.field(
772
+ kw_only=True, default=None, static=True
773
+ )
774
+ norm_samples_dict: Dict[str, Float[Array, "nb_norm_samples dimension"]] | None = (
775
+ eqx.field(kw_only=True, default=None)
776
+ )
777
+ norm_int_length_dict: Dict[str, float | None] | None = eqx.field(
778
+ kw_only=True, default=None
779
+ )
780
+ obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
781
+ kw_only=True, default=None, static=True
782
+ )
783
+
784
+ # For the user loss_weights are passed as a LossWeightsPDEDict (with internal
785
+ # dictionary having keys in u_dict and / or dynamic_loss_dict)
786
+ loss_weights: InitVar[LossWeightsPDEDict | None] = eqx.field(
787
+ kw_only=True, default=None
788
+ )
789
+ params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
790
+
791
+ # following have init=False and are set in the __post_init__
792
+ u_constraints_dict: Dict[str, LossPDEStatio | LossPDENonStatio] = eqx.field(
793
+ init=False
794
+ )
795
+ derivative_keys_dyn_loss: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio = (
796
+ eqx.field(init=False)
797
+ )
798
+ u_dict_with_none: Dict[str, None] = eqx.field(init=False)
799
+ # internally the loss weights are handled with a dictionary
800
+ _loss_weights: Dict[str, dict] = eqx.field(init=False)
801
+
802
+ def __post_init__(self, loss_weights=None, params_dict=None):
1106
803
  # a dictionary that will be useful at different places
1107
- self.u_dict_with_none = {k: None for k in u_dict.keys()}
804
+ self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
1108
805
  # First, for all the optional dict,
1109
806
  # if the user did not provide at all this optional argument,
1110
807
  # we make sure there is a null ponderating loss_weight and we
1111
808
  # create a dummy dict with the required keys and all the values to
1112
809
  # None
1113
- if omega_boundary_fun_dict is None:
810
+ if self.key_dict is None:
811
+ self.key_dict = self.u_dict_with_none
812
+ if self.omega_boundary_fun_dict is None:
1114
813
  self.omega_boundary_fun_dict = self.u_dict_with_none
1115
- else:
1116
- self.omega_boundary_fun_dict = omega_boundary_fun_dict
1117
- if omega_boundary_condition_dict is None:
814
+ if self.omega_boundary_condition_dict is None:
1118
815
  self.omega_boundary_condition_dict = self.u_dict_with_none
1119
- else:
1120
- self.omega_boundary_condition_dict = omega_boundary_condition_dict
1121
- if omega_boundary_dim_dict is None:
816
+ if self.omega_boundary_dim_dict is None:
1122
817
  self.omega_boundary_dim_dict = self.u_dict_with_none
1123
- else:
1124
- self.omega_boundary_dim_dict = omega_boundary_dim_dict
1125
- if initial_condition_fun_dict is None:
818
+ if self.initial_condition_fun_dict is None:
1126
819
  self.initial_condition_fun_dict = self.u_dict_with_none
1127
- else:
1128
- self.initial_condition_fun_dict = initial_condition_fun_dict
1129
- if norm_key_dict is None:
1130
- self.norm_key_dict = self.u_dict_with_none
1131
- else:
1132
- self.norm_key_dict = norm_key_dict
1133
- if norm_borders_dict is None:
1134
- self.norm_borders_dict = self.u_dict_with_none
1135
- else:
1136
- self.norm_borders_dict = norm_borders_dict
1137
- if norm_samples_dict is None:
820
+ if self.norm_samples_dict is None:
1138
821
  self.norm_samples_dict = self.u_dict_with_none
1139
- else:
1140
- self.norm_samples_dict = norm_samples_dict
1141
- if sobolev_m_dict is None:
1142
- self.sobolev_m_dict = self.u_dict_with_none
1143
- else:
1144
- self.sobolev_m_dict = sobolev_m_dict
1145
- if obs_slice_dict is None:
1146
- self.obs_slice_dict = {k: jnp.s_[...] for k in u_dict.keys()}
1147
- else:
1148
- self.obs_slice_dict = obs_slice_dict
1149
- if u_dict.keys() != obs_slice_dict.keys():
822
+ if self.norm_int_length_dict is None:
823
+ self.norm_int_length_dict = self.u_dict_with_none
824
+ if self.obs_slice_dict is None:
825
+ self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
826
+ if self.u_dict.keys() != self.obs_slice_dict.keys():
1150
827
  raise ValueError("obs_slice_dict should have same keys as u_dict")
1151
- if derivative_keys_dict is None:
828
+ if self.derivative_keys_dict is None:
1152
829
  self.derivative_keys_dict = {
1153
830
  k: None
1154
- for k in set(list(dynamic_loss_dict.keys()) + list(u_dict.keys()))
831
+ for k in set(
832
+ list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
833
+ )
1155
834
  }
1156
835
  # set() because we can have duplicate entries and in this case we
1157
836
  # say it corresponds to the same derivative_keys_dict entry
1158
- else:
1159
- self.derivative_keys_dict = derivative_keys_dict
1160
- # but then if the user did not provide anything, we must at least have
1161
- # a default value for the dynamic_loss_dict keys entries in
1162
- # self.derivative_keys_dict since the computation of dynamic losses is
1163
- # made without create a lossODE object that would provide the
1164
- # default values
1165
- for k in dynamic_loss_dict.keys():
837
+ # we need both because the constraints (all but dyn_loss) will be
838
+ # done by iterating on u_dict while the dyn_loss will be by
839
+ # iterating on dynamic_loss_dict. So each time we will require some
840
+ # derivative_keys_dict
841
+
842
+ # derivative keys for the u_constraints. Note that we create missing
843
+ # DerivativeKeysPDE(Non)Statio around a Params object and not ParamsDict
844
+ # this works because u_dict.keys == params_dict.nn_params.keys()
845
+ for k in self.u_dict.keys():
1166
846
  if self.derivative_keys_dict[k] is None:
1167
- self.derivative_keys_dict[k] = {"dyn_loss": ["nn_params"]}
847
+ if self.u_dict[k].eq_type == "statio_PDE":
848
+ self.derivative_keys_dict[k] = DerivativeKeysPDEStatio(
849
+ params=params_dict.extract_params(k)
850
+ )
851
+ else:
852
+ self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio(
853
+ params=params_dict.extract_params(k)
854
+ )
1168
855
 
1169
856
  # Second we make sure that all the dicts (except dynamic_loss_dict) have the same keys
1170
857
  if (
1171
- u_dict.keys() != nn_type_dict.keys()
1172
- or u_dict.keys() != self.omega_boundary_fun_dict.keys()
1173
- or u_dict.keys() != self.omega_boundary_condition_dict.keys()
1174
- or u_dict.keys() != self.omega_boundary_dim_dict.keys()
1175
- or u_dict.keys() != self.initial_condition_fun_dict.keys()
1176
- or u_dict.keys() != self.norm_key_dict.keys()
1177
- or u_dict.keys() != self.norm_borders_dict.keys()
1178
- or u_dict.keys() != self.norm_samples_dict.keys()
1179
- or u_dict.keys() != self.sobolev_m_dict.keys()
858
+ self.u_dict.keys() != self.key_dict.keys()
859
+ or self.u_dict.keys() != self.omega_boundary_fun_dict.keys()
860
+ or self.u_dict.keys() != self.omega_boundary_condition_dict.keys()
861
+ or self.u_dict.keys() != self.omega_boundary_dim_dict.keys()
862
+ or self.u_dict.keys() != self.initial_condition_fun_dict.keys()
863
+ or self.u_dict.keys() != self.norm_samples_dict.keys()
864
+ or self.u_dict.keys() != self.norm_int_length_dict.keys()
1180
865
  ):
1181
866
  raise ValueError("All the dicts concerning the PINNs should have same keys")
1182
867
 
1183
- self.dynamic_loss_dict = dynamic_loss_dict
1184
- self.u_dict = u_dict
1185
- # TODO nn_type should become a class attribute now that we have PINN
1186
- # class and SPINNs class
1187
- self.nn_type_dict = nn_type_dict
1188
-
1189
- self.loss_weights = loss_weights # This calls the setter
868
+ self._loss_weights = self.set_loss_weights(loss_weights)
1190
869
 
1191
870
  # Third, in order not to benefit from LossPDEStatio and
1192
871
  # LossPDENonStatio and in order to factorize code, we create internally
@@ -1194,95 +873,93 @@ class SystemLossPDE:
1194
873
  # We will not use the dynamic loss term
1195
874
  self.u_constraints_dict = {}
1196
875
  for i in self.u_dict.keys():
1197
- if self.nn_type_dict[i] == "nn_statio":
876
+ if self.u_dict[i].eq_type == "statio_PDE":
1198
877
  self.u_constraints_dict[i] = LossPDEStatio(
1199
- u=u_dict[i],
1200
- loss_weights={
1201
- "dyn_loss": 0.0,
1202
- "norm_loss": 1.0,
1203
- "boundary_loss": 1.0,
1204
- "observations": 1.0,
1205
- "sobolev": 1.0,
1206
- },
878
+ u=self.u_dict[i],
879
+ loss_weights=LossWeightsPDENonStatio(
880
+ dyn_loss=0.0,
881
+ norm_loss=1.0,
882
+ boundary_loss=1.0,
883
+ observations=1.0,
884
+ initial_condition=1.0,
885
+ ),
1207
886
  dynamic_loss=None,
887
+ key=self.key_dict[i],
1208
888
  derivative_keys=self.derivative_keys_dict[i],
1209
889
  omega_boundary_fun=self.omega_boundary_fun_dict[i],
1210
890
  omega_boundary_condition=self.omega_boundary_condition_dict[i],
1211
891
  omega_boundary_dim=self.omega_boundary_dim_dict[i],
1212
- norm_key=self.norm_key_dict[i],
1213
- norm_borders=self.norm_borders_dict[i],
1214
892
  norm_samples=self.norm_samples_dict[i],
1215
- sobolev_m=self.sobolev_m_dict[i],
893
+ norm_int_length=self.norm_int_length_dict[i],
1216
894
  obs_slice=self.obs_slice_dict[i],
1217
895
  )
1218
- elif self.nn_type_dict[i] == "nn_nonstatio":
896
+ elif self.u_dict[i].eq_type == "nonstatio_PDE":
1219
897
  self.u_constraints_dict[i] = LossPDENonStatio(
1220
- u=u_dict[i],
1221
- loss_weights={
1222
- "dyn_loss": 0.0,
1223
- "norm_loss": 1.0,
1224
- "boundary_loss": 1.0,
1225
- "observations": 1.0,
1226
- "initial_condition": 1.0,
1227
- "sobolev": 1.0,
1228
- },
898
+ u=self.u_dict[i],
899
+ loss_weights=LossWeightsPDENonStatio(
900
+ dyn_loss=0.0,
901
+ norm_loss=1.0,
902
+ boundary_loss=1.0,
903
+ observations=1.0,
904
+ initial_condition=1.0,
905
+ ),
1229
906
  dynamic_loss=None,
907
+ key=self.key_dict[i],
1230
908
  derivative_keys=self.derivative_keys_dict[i],
1231
909
  omega_boundary_fun=self.omega_boundary_fun_dict[i],
1232
910
  omega_boundary_condition=self.omega_boundary_condition_dict[i],
1233
911
  omega_boundary_dim=self.omega_boundary_dim_dict[i],
1234
912
  initial_condition_fun=self.initial_condition_fun_dict[i],
1235
- norm_key=self.norm_key_dict[i],
1236
- norm_borders=self.norm_borders_dict[i],
1237
913
  norm_samples=self.norm_samples_dict[i],
1238
- sobolev_m=self.sobolev_m_dict[i],
914
+ norm_int_length=self.norm_int_length_dict[i],
915
+ obs_slice=self.obs_slice_dict[i],
1239
916
  )
1240
917
  else:
1241
918
  raise ValueError(
1242
- f"Wrong value for nn_type_dict[i], got {nn_type_dict[i]}"
919
+ "Wrong value for self.u_dict[i].eq_type[i], "
920
+ f"got {self.u_dict[i].eq_type[i]}"
1243
921
  )
1244
922
 
1245
- # for convenience in the tree_map of evaluate,
1246
- # we separate the two derivative keys dict
1247
- self.derivative_keys_dyn_loss_dict = {
1248
- k: self.derivative_keys_dict[k]
1249
- for k in self.dynamic_loss_dict.keys() & self.derivative_keys_dict.keys()
1250
- }
1251
- self.derivative_keys_u_dict = {
1252
- k: self.derivative_keys_dict[k]
1253
- for k in self.u_dict.keys() & self.derivative_keys_dict.keys()
1254
- }
923
+ # derivative keys for the dynamic loss. Note that we create a
924
+ # DerivativeKeysPDENonStatio around a ParamsDict object because a whole
925
+ # params_dict is fed to DynamicLoss.evaluate functions (extract_params
926
+ # happen inside it)
927
+ self.derivative_keys_dyn_loss = DerivativeKeysPDENonStatio(params=params_dict)
1255
928
 
1256
929
  # also make sure we only have PINNs or SPINNs
1257
930
  if not (
1258
- all(isinstance(value, PINN) for value in u_dict.values())
1259
- or all(isinstance(value, SPINN) for value in u_dict.values())
931
+ all(isinstance(value, PINN) for value in self.u_dict.values())
932
+ or all(isinstance(value, SPINN) for value in self.u_dict.values())
1260
933
  ):
1261
934
  raise ValueError(
1262
935
  "We only accept dictionary of PINNs or dictionary of SPINNs"
1263
936
  )
1264
937
 
1265
- @property
1266
- def loss_weights(self):
1267
- return self._loss_weights
1268
-
1269
- @loss_weights.setter
1270
- def loss_weights(self, value):
1271
- self._loss_weights = {}
1272
- for k, v in value.items():
938
+ def set_loss_weights(
939
+ self, loss_weights_init: LossWeightsPDEDict
940
+ ) -> dict[str, dict]:
941
+ """
942
+ This rather complex function enables the user to specify a simple
943
+ loss_weights=LossWeightsPDEDict(dyn_loss=1., initial_condition=Tmax)
944
+ for ponderating values being applied to all the equations of the
945
+ system... So all the transformations are handled here
946
+ """
947
+ _loss_weights = {}
948
+ for k in fields(loss_weights_init):
949
+ v = getattr(loss_weights_init, k.name)
1273
950
  if isinstance(v, dict):
1274
- for kk, vv in v.items():
951
+ for vv in v.keys():
1275
952
  if not isinstance(vv, (int, float)) and not (
1276
- isinstance(vv, jnp.ndarray)
953
+ isinstance(vv, Array)
1277
954
  and ((vv.shape == (1,) or len(vv.shape) == 0))
1278
955
  ):
1279
956
  # TODO improve that
1280
957
  raise ValueError(
1281
958
  f"loss values cannot be vectorial here, got {vv}"
1282
959
  )
1283
- if k == "dyn_loss":
960
+ if k.name == "dyn_loss":
1284
961
  if v.keys() == self.dynamic_loss_dict.keys():
1285
- self._loss_weights[k] = v
962
+ _loss_weights[k.name] = v
1286
963
  else:
1287
964
  raise ValueError(
1288
965
  "Keys in nested dictionary of loss_weights"
@@ -1290,51 +967,36 @@ class SystemLossPDE:
1290
967
  )
1291
968
  else:
1292
969
  if v.keys() == self.u_dict.keys():
1293
- self._loss_weights[k] = v
970
+ _loss_weights[k.name] = v
1294
971
  else:
1295
972
  raise ValueError(
1296
973
  "Keys in nested dictionary of loss_weights"
1297
974
  " do not match u_dict keys"
1298
975
  )
976
+ if v is None:
977
+ _loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
1299
978
  else:
1300
979
  if not isinstance(v, (int, float)) and not (
1301
- isinstance(v, jnp.ndarray)
1302
- and ((v.shape == (1,) or len(v.shape) == 0))
980
+ isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
1303
981
  ):
1304
982
  # TODO improve that
1305
983
  raise ValueError(f"loss values cannot be vectorial here, got {v}")
1306
- if k == "dyn_loss":
1307
- self._loss_weights[k] = {
984
+ if k.name == "dyn_loss":
985
+ _loss_weights[k.name] = {
1308
986
  kk: v for kk in self.dynamic_loss_dict.keys()
1309
987
  }
1310
988
  else:
1311
- self._loss_weights[k] = {kk: v for kk in self.u_dict.keys()}
1312
- # Some special checks below
1313
- if all(v is None for k, v in self.sobolev_m_dict.items()):
1314
- self._loss_weights["sobolev"] = {k: 0 for k in self.u_dict.keys()}
1315
- if "observations" not in value.keys():
1316
- self._loss_weights["observations"] = {k: 0 for k in self.u_dict.keys()}
1317
- if all(v is None for k, v in self.omega_boundary_fun_dict.items()) or all(
1318
- v is None for k, v in self.omega_boundary_condition_dict.items()
1319
- ):
1320
- self._loss_weights["boundary_loss"] = {k: 0 for k in self.u_dict.keys()}
1321
- if (
1322
- all(v is None for k, v in self.norm_key_dict.items())
1323
- or all(v is None for k, v in self.norm_borders_dict.items())
1324
- or all(v is None for k, v in self.norm_samples_dict.items())
1325
- ):
1326
- self._loss_weights["norm_loss"] = {k: 0 for k in self.u_dict.keys()}
1327
- if all(v is None for k, v in self.initial_condition_fun_dict.items()):
1328
- self._loss_weights["initial_condition"] = {k: 0 for k in self.u_dict.keys()}
989
+ _loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
990
+ return _loss_weights
1329
991
 
1330
992
  def __call__(self, *args, **kwargs):
1331
993
  return self.evaluate(*args, **kwargs)
1332
994
 
1333
995
  def evaluate(
1334
996
  self,
1335
- params_dict,
1336
- batch,
1337
- ):
997
+ params_dict: ParamsDict,
998
+ batch: PDEStatioBatch | PDENonStatioBatch,
999
+ ) -> tuple[Float[Array, "1"], dict[str, float]]:
1338
1000
  """
1339
1001
  Evaluate the loss function at a batch of points for given parameters.
1340
1002
 
@@ -1342,12 +1004,8 @@ class SystemLossPDE:
1342
1004
  Parameters
1343
1005
  ---------
1344
1006
  params_dict
1345
- A dictionary of dictionaries of parameters of the model.
1346
- Typically, it is a dictionary of dictionaries of
1347
- dictionaries: `eq_params` and `nn_params``, respectively the
1348
- differential equation parameters and the neural network parameter
1007
+ Parameters at which the losses of the system are evaluated
1349
1008
  batch
1350
- A PDEStatioBatch or PDENonStatioBatch object.
1351
1009
  Such named tuples are composed of batch of points in the
1352
1010
  domain, a batch of points in the domain
1353
1011
  border, (a batch of time points a for PDENonStatioBatch) and an
@@ -1355,7 +1013,7 @@ class SystemLossPDE:
1355
1013
  and an optional additional batch of observed
1356
1014
  inputs/outputs/parameters
1357
1015
  """
1358
- if self.u_dict.keys() != params_dict["nn_params"].keys():
1016
+ if self.u_dict.keys() != params_dict.nn_params.keys():
1359
1017
  raise ValueError("u_dict and params_dict[nn_params] should have same keys ")
1360
1018
 
1361
1019
  if isinstance(batch, PDEStatioBatch):
@@ -1378,22 +1036,22 @@ class SystemLossPDE:
1378
1036
  if batch.param_batch_dict is not None:
1379
1037
  eq_params_batch_dict = batch.param_batch_dict
1380
1038
 
1039
+ # TODO
1381
1040
  # feed the eq_params with the batch
1382
1041
  for k in eq_params_batch_dict.keys():
1383
- params_dict["eq_params"][k] = eq_params_batch_dict[k]
1042
+ params_dict.eq_params[k] = eq_params_batch_dict[k]
1384
1043
 
1385
1044
  vmap_in_axes_params = _get_vmap_in_axes_params(
1386
1045
  batch.param_batch_dict, params_dict
1387
1046
  )
1388
1047
 
1389
- def dyn_loss_for_one_key(dyn_loss, derivative_key, loss_weight):
1048
+ def dyn_loss_for_one_key(dyn_loss, loss_weight):
1390
1049
  """The function used in tree_map"""
1391
- params_dict_ = _set_derivatives(params_dict, "dyn_loss", derivative_key)
1392
1050
  return dynamic_loss_apply(
1393
1051
  dyn_loss.evaluate,
1394
1052
  self.u_dict,
1395
1053
  batches,
1396
- params_dict_,
1054
+ _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
1397
1055
  vmap_in_axes_x_or_x_t + vmap_in_axes_params,
1398
1056
  loss_weight,
1399
1057
  u_type=type(list(self.u_dict.values())[0]),
@@ -1402,8 +1060,14 @@ class SystemLossPDE:
1402
1060
  dyn_loss_mse_dict = jax.tree_util.tree_map(
1403
1061
  dyn_loss_for_one_key,
1404
1062
  self.dynamic_loss_dict,
1405
- self.derivative_keys_dyn_loss_dict,
1406
1063
  self._loss_weights["dyn_loss"],
1064
+ is_leaf=lambda x: isinstance(
1065
+ x, (PDEStatio, PDENonStatio)
1066
+ ), # before when dynamic losses
1067
+ # were plain (unregistered pytree) node classes, we could not traverse
1068
+ # this level. Now that dynamic losses are eqx.Module they can be
1069
+ # traversed by tree map recursion. Hence we need to specify that
1070
+ # we want to stop at this level
1407
1071
  )
1408
1072
  mse_dyn_loss = jax.tree_util.tree_reduce(
1409
1073
  lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
@@ -1418,11 +1082,10 @@ class SystemLossPDE:
1418
1082
  "boundary_loss": "*",
1419
1083
  "observations": "*",
1420
1084
  "initial_condition": "*",
1421
- "sobolev": "*",
1422
1085
  }
1423
1086
  # we need to do the following for the tree_mapping to work
1424
1087
  if batch.obs_batch_dict is None:
1425
- batch = batch._replace(obs_batch_dict=self.u_dict_with_none)
1088
+ batch = append_obs_batch(batch, self.u_dict_with_none)
1426
1089
  total_loss, res_dict = constraints_system_loss_apply(
1427
1090
  self.u_constraints_dict,
1428
1091
  batch,
@@ -1435,41 +1098,3 @@ class SystemLossPDE:
1435
1098
  total_loss += mse_dyn_loss
1436
1099
  res_dict["dyn_loss"] += mse_dyn_loss
1437
1100
  return total_loss, res_dict
1438
-
1439
- def tree_flatten(self):
1440
- children = (
1441
- self.norm_key_dict,
1442
- self.norm_samples_dict,
1443
- self.initial_condition_fun_dict,
1444
- self._loss_weights,
1445
- )
1446
- aux_data = {
1447
- "u_dict": self.u_dict,
1448
- "dynamic_loss_dict": self.dynamic_loss_dict,
1449
- "norm_borders_dict": self.norm_borders_dict,
1450
- "omega_boundary_fun_dict": self.omega_boundary_fun_dict,
1451
- "omega_boundary_condition_dict": self.omega_boundary_condition_dict,
1452
- "nn_type_dict": self.nn_type_dict,
1453
- "sobolev_m_dict": self.sobolev_m_dict,
1454
- "derivative_keys_dict": self.derivative_keys_dict,
1455
- "obs_slice_dict": self.obs_slice_dict,
1456
- }
1457
- return (children, aux_data)
1458
-
1459
- @classmethod
1460
- def tree_unflatten(cls, aux_data, children):
1461
- (
1462
- norm_key_dict,
1463
- norm_samples_dict,
1464
- initial_condition_fun_dict,
1465
- loss_weights,
1466
- ) = children
1467
- loss_ode = cls(
1468
- loss_weights=loss_weights,
1469
- norm_key_dict=norm_key_dict,
1470
- norm_samples_dict=norm_samples_dict,
1471
- initial_condition_fun_dict=initial_condition_fun_dict,
1472
- **aux_data,
1473
- )
1474
-
1475
- return loss_ode