jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +953 -1182
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +321 -168
  9. jinns/loss/_LossODE.py +290 -307
  10. jinns/loss/_LossPDE.py +628 -1040
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +95 -96
  13. jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +94 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +193 -45
  22. jinns/solver/_solve.py +199 -144
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -43
  25. jinns/utils/_hyperpinn.py +226 -127
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +117 -84
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +52 -144
  32. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
  33. jinns-1.0.0.dist-info/RECORD +38 -0
  34. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
  35. jinns/experimental/_sinuspinn.py +0 -135
  36. jinns/experimental/_spectralpinn.py +0 -87
  37. jinns/solver/_seq2seq.py +0 -157
  38. jinns/utils/_optim.py +0 -147
  39. jinns/utils/_utils_uspinn.py +0 -727
  40. jinns-0.8.10.dist-info/RECORD +0 -36
  41. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
  42. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py CHANGED
@@ -1,29 +1,52 @@
1
+ # pylint: disable=unsubscriptable-object, no-member
1
2
  """
2
3
  Main module to implement a PDE loss in jinns
3
4
  """
5
+ from __future__ import (
6
+ annotations,
7
+ ) # https://docs.python.org/3/library/typing.html#constant
4
8
 
9
+ import abc
10
+ from dataclasses import InitVar, fields
11
+ from typing import TYPE_CHECKING, Dict, Callable
5
12
  import warnings
6
13
  import jax
7
14
  import jax.numpy as jnp
8
- from jax.tree_util import register_pytree_node_class
9
- from jinns.loss._Losses import (
15
+ import equinox as eqx
16
+ from jaxtyping import Float, Array, Key, Int
17
+ from jinns.loss._loss_utils import (
10
18
  dynamic_loss_apply,
11
19
  boundary_condition_apply,
12
20
  normalization_loss_apply,
13
21
  observations_loss_apply,
14
- sobolev_reg_apply,
15
22
  initial_condition_apply,
16
23
  constraints_system_loss_apply,
17
24
  )
18
- from jinns.data._DataGenerators import PDEStatioBatch, PDENonStatioBatch
19
- from jinns.utils._utils import (
25
+ from jinns.data._DataGenerators import (
26
+ append_obs_batch,
27
+ )
28
+ from jinns.parameters._params import (
20
29
  _get_vmap_in_axes_params,
21
- _set_derivatives,
22
30
  _update_eq_params_dict,
23
31
  )
32
+ from jinns.parameters._derivative_keys import (
33
+ _set_derivatives,
34
+ DerivativeKeysPDEStatio,
35
+ DerivativeKeysPDENonStatio,
36
+ )
37
+ from jinns.loss._loss_weights import (
38
+ LossWeightsPDEStatio,
39
+ LossWeightsPDENonStatio,
40
+ LossWeightsPDEDict,
41
+ )
42
+ from jinns.loss._DynamicLossAbstract import PDEStatio, PDENonStatio
24
43
  from jinns.utils._pinn import PINN
25
44
  from jinns.utils._spinn import SPINN
26
- from jinns.loss._operators import _sobolev
45
+ from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
46
+
47
+
48
+ if TYPE_CHECKING:
49
+ from jinns.utils._types import *
27
50
 
28
51
  _IMPLEMENTED_BOUNDARY_CONDITIONS = [
29
52
  "dirichlet",
@@ -31,375 +54,156 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
31
54
  "vonneumann",
32
55
  ]
33
56
 
34
- _LOSS_WEIGHT_KEYS_PDESTATIO = [
35
- "sobolev",
36
- "observations",
37
- "norm_loss",
38
- "boundary_loss",
39
- "dyn_loss",
40
- ]
41
-
42
- _LOSS_WEIGHT_KEYS_PDENONSTATIO = _LOSS_WEIGHT_KEYS_PDESTATIO + ["initial_condition"]
43
57
 
44
-
45
- @register_pytree_node_class
46
- class LossPDEAbstract:
58
+ class _LossPDEAbstract(eqx.Module):
47
59
  """
48
- Super class for the actual Pinn loss classes. This class should not be
49
- used. It serves for common attributes between LossPDEStatio and
50
- LossPDENonStatio
51
-
52
-
53
- **Note:** LossPDEAbstract is jittable. Hence it implements the tree_flatten() and
54
- tree_unflatten methods.
60
+ Parameters
61
+ ----------
62
+
63
+ loss_weights : LossWeightsPDEStatio | LossWeightsPDENonStatio, default=None
64
+ The loss weights for the different terms: dynamic loss,
65
+ initial condition (if LossWeightsPDENonStatio), boundary conditions if
66
+ any, normalization loss if any and observations if any.
67
+ All fields are set to 1.0 by default.
68
+ derivative_keys : DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio, default=None
69
+ Specify which field of `params` should be differentiated for each
70
+ component of the total loss. Particularly useful for inverse problems.
71
+ Fields can be "nn_params", "eq_params" or "both". Those that should not
72
+ be updated will have a `jax.lax.stop_gradient` called on them. Default
73
+ is `"nn_params"` for each component of the loss.
74
+ omega_boundary_fun : Callable | Dict[str, Callable], default=None
75
+ The function to be matched in the border condition (can be None) or a
76
+ dictionary of such functions as values and keys as described
77
+ in `omega_boundary_condition`.
78
+ omega_boundary_condition : str | Dict[str, str], default=None
79
+ Either None (no condition, by default), or a string defining
80
+ the boundary condition (Dirichlet or Von Neumann),
81
+ or a dictionary with such strings as values. In this case,
82
+ the keys are the facets and must be in the following order:
83
+ 1D -> [“xmin”, “xmax”], 2D -> [“xmin”, “xmax”, “ymin”, “ymax”].
84
+ Note that high order boundaries are currently not implemented.
85
+ A value in the dict can be None, this means we do not enforce
86
+ a particular boundary condition on this facet.
87
+ The facet called “xmin”, resp. “xmax” etc., in 2D,
88
+ refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
89
+ omega_boundary_dim : slice | Dict[str, slice], default=None
90
+ Either None, or a slice object or a dictionary of slice objects as
91
+ values and keys as described in `omega_boundary_condition`.
92
+ `omega_boundary_dim` indicates which dimension(s) of the PINN
93
+ will be forced to match the boundary condition.
94
+ Note that it must be a slice and not an integer
95
+ (but a preprocessing of the user provided argument takes care of it)
96
+ norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
97
+ Fixed sample point in the space over which to compute the
98
+ normalization constant. Default is None.
99
+ norm_int_length : float, default=None
100
+ A float. Must be provided if `norm_samples` is provided. The domain area
101
+ (or interval length in 1D) upon which we perform the numerical
102
+ integration. Default None
103
+ obs_slice : slice, default=None
104
+ slice object specifying the beginning/ending of the PINN output
105
+ that is observed (this is then useful for multidim PINN). Default is None.
55
106
  """
56
107
 
57
- def __init__(
58
- self,
59
- u,
60
- loss_weights,
61
- derivative_keys=None,
62
- norm_key=None,
63
- norm_borders=None,
64
- norm_samples=None,
65
- ):
108
+ # NOTE static=True only for leaf attributes that are not valid JAX types
109
+ # (ie. jax.Array cannot be static) and that we do not expect to change
110
+ # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
111
+ derivative_keys: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None = (
112
+ eqx.field(kw_only=True, default=None)
113
+ )
114
+ loss_weights: LossWeightsPDEStatio | LossWeightsPDENonStatio | None = eqx.field(
115
+ kw_only=True, default=None
116
+ )
117
+ omega_boundary_fun: Callable | Dict[str, Callable] | None = eqx.field(
118
+ kw_only=True, default=None, static=True
119
+ )
120
+ omega_boundary_condition: str | Dict[str, str] | None = eqx.field(
121
+ kw_only=True, default=None, static=True
122
+ )
123
+ omega_boundary_dim: slice | Dict[str, slice] | None = eqx.field(
124
+ kw_only=True, default=None, static=True
125
+ )
126
+ norm_samples: Float[Array, "nb_norm_samples dimension"] | None = eqx.field(
127
+ kw_only=True, default=None
128
+ )
129
+ norm_int_length: float | None = eqx.field(kw_only=True, default=None)
130
+ obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
131
+
132
+ def __post_init__(self):
66
133
  """
67
- Parameters
68
- ----------
69
- u
70
- the PINN object
71
- loss_weights
72
- a dictionary with values used to ponderate each term in the loss
73
- function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss`
74
- and `observations`
75
- Note that we can have jnp.arrays with the same dimension of
76
- `u` which then ponderates each output of `u`
77
- derivative_keys
78
- A dict of lists of strings. In the dict, the key must correspond to
79
- the loss term keywords. Then each of the values must correspond to keys in the parameter
80
- dictionary (*at top level only of the parameter dictionary*).
81
- It enables selecting the set of parameters
82
- with respect to which the gradients of the dynamic
83
- loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
84
- keywords, this is what is typically
85
- done in solving forward problems, when we only estimate the
86
- equation solution with a PINN. If some loss terms keywords are
87
- missing we set their value to ["nn_params"] by default for the
88
- same reason
89
- norm_key
90
- Jax random key to draw samples in for the Monte Carlo computation
91
- of the normalization constant. Default is None
92
- norm_borders
93
- tuple of (min, max) of the boundaray values of the space over which
94
- to integrate in the computation of the normalization constant.
95
- A list of tuple for higher dimensional problems. Default None.
96
- norm_samples
97
- Fixed sample point in the space over which to compute the
98
- normalization constant. Default is None
99
-
100
- Raises
101
- ------
102
- RuntimeError
103
- When provided an invalid combination of `norm_key`, `norm_borders`
104
- and `norm_samples`. See note below.
105
-
106
- **Note:** If `norm_key` and `norm_borders` and `norm_samples` are `None`
107
- then no normalization loss in enforced.
108
- If `norm_borders` and `norm_samples` are given while
109
- `norm_samples` is `None` then samples are drawn at each loss evaluation.
110
- Otherwise, if `norm_samples` is given, those samples are used.
134
+ Note that neither __init__ nor __post_init__ is called when updating a
135
+ Module with eqx.tree_at
111
136
  """
112
-
113
- self.u = u
114
- if derivative_keys is None:
137
+ if self.derivative_keys is None:
115
138
  # be default we only take gradient wrt nn_params
116
- derivative_keys = {
117
- k: ["nn_params"]
118
- for k in [
119
- "dyn_loss",
120
- "boundary_loss",
121
- "norm_loss",
122
- "initial_condition",
123
- "observations",
124
- "sobolev",
125
- ]
126
- }
127
- if isinstance(derivative_keys, list):
128
- # if the user only provided a list, this defines the gradient taken
129
- # for all the loss entries
130
- derivative_keys = {
131
- k: derivative_keys
132
- for k in [
133
- "dyn_loss",
134
- "boundary_loss",
135
- "norm_loss",
136
- "initial_condition",
137
- "observations",
138
- "sobolev",
139
- ]
140
- }
141
-
142
- self.derivative_keys = derivative_keys
143
- self.loss_weights = loss_weights
144
- self.norm_borders = norm_borders
145
- self.norm_key = norm_key
146
- self.norm_samples = norm_samples
147
-
148
- if norm_key is None and norm_borders is None and norm_samples is None:
149
- # if there is None of the 3 above, that means we don't consider
150
- # normalization loss
151
- self.normalization_loss = None
152
- elif (
153
- norm_key is not None and norm_borders is not None and norm_samples is None
154
- ): # this ordering so that by default priority is to given mc_samples
155
- self.norm_sample_method = "generate"
156
- if not isinstance(self.norm_borders[0], tuple):
157
- self.norm_borders = (self.norm_borders,)
158
- self.norm_xmin, self.norm_xmax = [], []
159
- for i, _ in enumerate(self.norm_borders):
160
- self.norm_xmin.append(self.norm_borders[i][0])
161
- self.norm_xmax.append(self.norm_borders[i][1])
162
- self.int_length = jnp.prod(
163
- jnp.array(
164
- [
165
- self.norm_xmax[i] - self.norm_xmin[i]
166
- for i in range(len(self.norm_borders))
167
- ]
168
- )
139
+ self.derivative_keys = (
140
+ DerivativeKeysPDENonStatio()
141
+ if isinstance(self, LossPDENonStatio)
142
+ else DerivativeKeysPDEStatio()
169
143
  )
170
- self.normalization_loss = True
171
- elif norm_samples is None:
172
- raise RuntimeError(
173
- "norm_borders should always provided then either"
174
- " norm_samples (fixed norm_samples) or norm_key (random norm_samples)"
175
- " is required."
176
- )
177
- else:
178
- # ok, we are sure we have norm_samples given by the user
179
- self.norm_sample_method = "user"
180
- if not isinstance(self.norm_borders[0], tuple):
181
- self.norm_borders = (self.norm_borders,)
182
- self.norm_xmin, self.norm_xmax = [], []
183
- for i, _ in enumerate(self.norm_borders):
184
- self.norm_xmin.append(self.norm_borders[i][0])
185
- self.norm_xmax.append(self.norm_borders[i][1])
186
- self.int_length = jnp.prod(
187
- jnp.array(
188
- [
189
- self.norm_xmax[i] - self.norm_xmin[i]
190
- for i in range(len(self.norm_borders))
191
- ]
192
- )
193
- )
194
- self.normalization_loss = True
195
144
 
196
- def get_norm_samples(self):
197
- """
198
- Returns a batch of points in the domain for integration when the
199
- normalization constraint is enforced. The batch of points is either
200
- fixed (provided by the user) or regenerated at each iteration.
201
- """
202
- if self.norm_sample_method == "user":
203
- return self.norm_samples
204
- if self.norm_sample_method == "generate":
205
- ## NOTE TODO CHECK the performances of this for loop
206
- norm_samples = []
207
- for d in range(len(self.norm_borders)):
208
- self.norm_key, subkey = jax.random.split(self.norm_key)
209
- norm_samples.append(
210
- jax.random.uniform(
211
- subkey,
212
- shape=(1000, 1),
213
- minval=self.norm_xmin[d],
214
- maxval=self.norm_xmax[d],
215
- )
216
- )
217
- self.norm_samples = jnp.concatenate(norm_samples, axis=-1)
218
- return self.norm_samples
219
- raise RuntimeError("Problem with the value of self.norm_sample_method")
220
-
221
- def tree_flatten(self):
222
- children = (self.norm_key, self.norm_samples, self.loss_weights)
223
- aux_data = {
224
- "norm_borders": self.norm_borders,
225
- "derivative_keys": self.derivative_keys,
226
- "u": self.u,
227
- }
228
- return (children, aux_data)
229
-
230
- @classmethod
231
- def tree_unflatten(self, aux_data, children):
232
- (norm_key, norm_samples, loss_weights) = children
233
- pls = self(
234
- aux_data["u"],
235
- loss_weights,
236
- aux_data["derivative_keys"],
237
- norm_key,
238
- aux_data["norm_borders"],
239
- norm_samples,
240
- )
241
- return pls
242
-
243
-
244
- @register_pytree_node_class
245
- class LossPDEStatio(LossPDEAbstract):
246
- r"""Loss object for a stationary partial differential equation
247
-
248
- .. math::
249
- \mathcal{N}[u](x) = 0, \forall x \in \Omega
250
-
251
- where :math:`\mathcal{N}[\cdot]` is a differential operator and the
252
- boundary condition is :math:`u(x)=u_b(x)` The additional condition of
253
- integrating to 1 can be included, i.e. :math:`\int u(x)\mathrm{d}x=1`.
254
-
255
-
256
- **Note:** LossPDEStatio is jittable. Hence it implements the tree_flatten() and
257
- tree_unflatten methods.
258
- """
145
+ if self.loss_weights is None:
146
+ self.loss_weights = (
147
+ LossWeightsPDENonStatio()
148
+ if isinstance(self, LossPDENonStatio)
149
+ else LossWeightsPDEStatio()
150
+ )
259
151
 
260
- def __init__(
261
- self,
262
- u,
263
- loss_weights,
264
- dynamic_loss,
265
- derivative_keys=None,
266
- omega_boundary_fun=None,
267
- omega_boundary_condition=None,
268
- omega_boundary_dim=None,
269
- norm_key=None,
270
- norm_borders=None,
271
- norm_samples=None,
272
- sobolev_m=None,
273
- obs_slice=None,
274
- ):
275
- r"""
276
- Parameters
277
- ----------
278
- u
279
- the PINN object
280
- loss_weights
281
- a dictionary with values used to ponderate each term in the loss
282
- function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss`,
283
- `observations` and `sobolev`.
284
- Note that we can have jnp.arrays with the same dimension of
285
- `u` which then ponderates each output of `u`
286
- dynamic_loss
287
- the stationary PDE dynamic part of the loss, basically the differential
288
- operator :math:` \mathcal{N}[u](t)`. Should implement a method
289
- `dynamic_loss.evaluate(t, u, params)`.
290
- Can be None in order to access only some part of the evaluate call
291
- results.
292
- derivative_keys
293
- A dict of lists of strings. In the dict, the key must correspond to
294
- the loss term keywords. Then each of the values must correspond to keys in the parameter
295
- dictionary (*at top level only of the parameter dictionary*).
296
- It enables selecting the set of parameters
297
- with respect to which the gradients of the dynamic
298
- loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
299
- keywords, this is what is typically
300
- done in solving forward problems, when we only estimate the
301
- equation solution with a PINN. If some loss terms keywords are
302
- missing we set their value to ["nn_params"] by default for the same
303
- reason
304
- omega_boundary_fun
305
- The function to be matched in the border condition (can be None)
306
- or a dictionary of such function. In this case, the keys are the
307
- facets and the values are the functions. The keys must be in the
308
- following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
309
- "ymin", "ymax"]. Note that high order boundaries are currently not
310
- implemented. A value in the dict can be None, this means we do not
311
- enforce a particular boundary condition on this facet.
312
- The facet called "xmin", resp. "xmax" etc., in 2D,
313
- refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
314
- omega_boundary_condition
315
- Either None (no condition), or a string defining the boundary
316
- condition e.g. Dirichlet or Von Neumann, or a dictionary of such
317
- strings. In this case, the keys are the
318
- facets and the values are the strings. The keys must be in the
319
- following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
320
- "ymin", "ymax"]. Note that high order boundaries are currently not
321
- implemented. A value in the dict can be None, this means we do not
322
- enforce a particular boundary condition on this facet.
323
- The facet called "xmin", resp. "xmax" etc., in 2D,
324
- refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
325
- omega_boundary_dim
326
- Either None, or a jnp.s\_ or a dict of jnp.s\_ with keys following
327
- the logic of omega_boundary_fun. It indicates which dimension(s) of
328
- the PINN will be forced to match the boundary condition
329
- Note that it must be a slice and not an integer (a preprocessing of the
330
- user provided argument takes care of it)
331
- norm_key
332
- Jax random key to draw samples in for the Monte Carlo computation
333
- of the normalization constant. Default is None
334
- norm_borders
335
- tuple of (min, max) of the boundaray values of the space over which
336
- to integrate in the computation of the normalization constant.
337
- A list of tuple for higher dimensional problems. Default None.
338
- norm_samples
339
- Fixed sample point in the space over which to compute the
340
- normalization constant. Default is None
341
- sobolev_m
342
- An integer. Default is None.
343
- It corresponds to the Sobolev regularization order as proposed in
344
- *Convergence and error analysis of PINNs*,
345
- Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
346
- obs_slice
347
- slice object specifying the begininning/ending
348
- slice of u output(s) that is observed (this is then useful for
349
- multidim PINN). Default is None.
350
-
351
-
352
- Raises
353
- ------
354
- ValueError
355
- If conditions on omega_boundary_condition and omega_boundary_fun
356
- are not respected
357
- """
152
+ if self.obs_slice is None:
153
+ self.obs_slice = jnp.s_[...]
358
154
 
359
- super().__init__(
360
- u, loss_weights, derivative_keys, norm_key, norm_borders, norm_samples
361
- )
155
+ if (
156
+ isinstance(self.omega_boundary_fun, dict)
157
+ and not isinstance(self.omega_boundary_condition, dict)
158
+ ) or (
159
+ not isinstance(self.omega_boundary_fun, dict)
160
+ and isinstance(self.omega_boundary_condition, dict)
161
+ ):
162
+ raise ValueError(
163
+ "if one of self.omega_boundary_fun or "
164
+ "self.omega_boundary_condition is dict, the other should be too."
165
+ )
362
166
 
363
- if omega_boundary_condition is None or omega_boundary_fun is None:
167
+ if self.omega_boundary_condition is None or self.omega_boundary_fun is None:
364
168
  warnings.warn(
365
169
  "Missing boundary function or no boundary condition."
366
170
  "Boundary function is thus ignored."
367
171
  )
368
172
  else:
369
- if isinstance(omega_boundary_condition, dict):
370
- for _, v in omega_boundary_condition.items():
173
+ if isinstance(self.omega_boundary_condition, dict):
174
+ for _, v in self.omega_boundary_condition.items():
371
175
  if v is not None and not any(
372
176
  v.lower() in s for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
373
177
  ):
374
178
  raise NotImplementedError(
375
- f"The boundary condition {omega_boundary_condition} is not"
179
+ f"The boundary condition {self.omega_boundary_condition} is not"
376
180
  f"implemented yet. Try one of :"
377
181
  f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
378
182
  )
379
183
  else:
380
184
  if not any(
381
- omega_boundary_condition.lower() in s
185
+ self.omega_boundary_condition.lower() in s
382
186
  for s in _IMPLEMENTED_BOUNDARY_CONDITIONS
383
187
  ):
384
188
  raise NotImplementedError(
385
- f"The boundary condition {omega_boundary_condition} is not"
189
+ f"The boundary condition {self.omega_boundary_condition} is not"
386
190
  f"implemented yet. Try one of :"
387
191
  f"{_IMPLEMENTED_BOUNDARY_CONDITIONS}."
388
192
  )
389
- if isinstance(omega_boundary_fun, dict) and isinstance(
390
- omega_boundary_condition, dict
193
+ if isinstance(self.omega_boundary_fun, dict) and isinstance(
194
+ self.omega_boundary_condition, dict
391
195
  ):
392
196
  if (
393
197
  not (
394
- list(omega_boundary_fun.keys()) == ["xmin", "xmax"]
395
- and list(omega_boundary_condition.keys())
198
+ list(self.omega_boundary_fun.keys()) == ["xmin", "xmax"]
199
+ and list(self.omega_boundary_condition.keys())
396
200
  == ["xmin", "xmax"]
397
201
  )
398
202
  ) or (
399
203
  not (
400
- list(omega_boundary_fun.keys())
204
+ list(self.omega_boundary_fun.keys())
401
205
  == ["xmin", "xmax", "ymin", "ymax"]
402
- and list(omega_boundary_condition.keys())
206
+ and list(self.omega_boundary_condition.keys())
403
207
  == ["xmin", "xmax", "ymin", "ymax"]
404
208
  )
405
209
  ):
@@ -408,10 +212,6 @@ class LossPDEStatio(LossPDEAbstract):
408
212
  "boundary condition dictionaries is incorrect"
409
213
  )
410
214
 
411
- self.omega_boundary_fun = omega_boundary_fun
412
- self.omega_boundary_condition = omega_boundary_condition
413
-
414
- self.omega_boundary_dim = omega_boundary_dim
415
215
  if isinstance(self.omega_boundary_fun, dict):
416
216
  if self.omega_boundary_dim is None:
417
217
  self.omega_boundary_dim = {
@@ -440,44 +240,139 @@ class LossPDEStatio(LossPDEAbstract):
440
240
  self.omega_boundary_dim : self.omega_boundary_dim + 1
441
241
  ]
442
242
  if not isinstance(self.omega_boundary_dim, slice):
443
- raise ValueError("self.omega_boundary_dim must be a jnp.s_" " object")
243
+ raise ValueError("self.omega_boundary_dim must be a jnp.s_ object")
444
244
 
445
- self.dynamic_loss = dynamic_loss
245
+ if self.norm_samples is not None and self.norm_int_length is None:
246
+ raise ValueError("self.norm_samples and norm_int_length must be provided")
446
247
 
447
- self.sobolev_m = sobolev_m
448
- if self.sobolev_m is not None:
449
- self.sobolev_reg = _sobolev(
450
- self.u, self.sobolev_m
451
- ) # we return a function, that way
452
- # the order of sobolev_m is static and the conditional in the recursive
453
- # function is properly set
454
- else:
455
- self.sobolev_reg = None
248
+ @abc.abstractmethod
249
+ def evaluate(
250
+ self: eqx.Module,
251
+ params: Params,
252
+ batch: PDEStatioBatch | PDENonStatioBatch,
253
+ ) -> tuple[Float, dict]:
254
+ raise NotImplementedError
456
255
 
457
- for k in _LOSS_WEIGHT_KEYS_PDESTATIO:
458
- if k not in self.loss_weights.keys():
459
- self.loss_weights[k] = 0
460
256
 
461
- if (
462
- isinstance(self.omega_boundary_fun, dict)
463
- and not isinstance(self.omega_boundary_condition, dict)
464
- ) or (
465
- not isinstance(self.omega_boundary_fun, dict)
466
- and isinstance(self.omega_boundary_condition, dict)
467
- ):
468
- raise ValueError(
469
- "if one of self.omega_boundary_fun or "
470
- "self.omega_boundary_condition is dict, the other should be too."
471
- )
257
+ class LossPDEStatio(_LossPDEAbstract):
258
+ r"""Loss object for a stationary partial differential equation
472
259
 
473
- self.obs_slice = obs_slice
474
- if self.obs_slice is None:
475
- self.obs_slice = jnp.s_[...]
260
+ $$
261
+ \mathcal{N}[u](x) = 0, \forall x \in \Omega
262
+ $$
263
+
264
+ where $\mathcal{N}[\cdot]$ is a differential operator and the
265
+ boundary condition is $u(x)=u_b(x)$ The additional condition of
266
+ integrating to 1 can be included, i.e. $\int u(x)\mathrm{d}x=1$.
267
+
268
+ Parameters
269
+ ----------
270
+ u : eqx.Module
271
+ the PINN
272
+ dynamic_loss : DynamicLoss
273
+ the stationary PDE dynamic part of the loss, basically the differential
274
+ operator $\mathcal{N}[u](x)$. Should implement a method
275
+ `dynamic_loss.evaluate(x, u, params)`.
276
+ Can be None in order to access only some part of the evaluate call
277
+ results.
278
+ key : Key
279
+ A JAX PRNG Key for the loss class treated as an attribute. Default is
280
+ None. This field is provided for future developments and additional
281
+ losses that might need some randomness. Note that special care must be
282
+ taken when splitting the key because in-place updates are forbidden in
283
+ eqx.Modules.
284
+ loss_weights : LossWeightsPDEStatio, default=None
285
+ The loss weights for the different terms: dynamic loss,
286
+ boundary conditions if any, normalization loss if any and
287
+ observations if any.
288
+ All fields are set to 1.0 by default.
289
+ derivative_keys : DerivativeKeysPDEStatio, default=None
290
+ Specify which field of `params` should be differentiated for each
291
+ composant of the total loss. Particularily useful for inverse problems.
292
+ Fields can be "nn_params", "eq_params" or "both". Those that should not
293
+ be updated will have a `jax.lax.stop_gradient` called on them. Default
294
+ is `"nn_params"` for each component of the loss.
295
+ omega_boundary_fun : Callable | Dict[str, Callable], default=None
296
+ The function to be matched in the border condition (can be None) or a
297
+ dictionary of such functions as values and keys as described
298
+ in `omega_boundary_condition`.
299
+ omega_boundary_condition : str | Dict[str, str], default=None
300
+ Either None (no condition, by default), or a string defining
301
+ the boundary condition (Dirichlet or Von Neumann),
302
+ or a dictionary with such strings as values. In this case,
303
+ the keys are the facets and must be in the following order:
304
+ 1D -> [“xmin”, “xmax”], 2D -> [“xmin”, “xmax”, “ymin”, “ymax”].
305
+ Note that high order boundaries are currently not implemented.
306
+ A value in the dict can be None, this means we do not enforce
307
+ a particular boundary condition on this facet.
308
+ The facet called “xmin”, resp. “xmax” etc., in 2D,
309
+ refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
310
+ omega_boundary_dim : slice | Dict[str, slice], default=None
311
+ Either None, or a slice object or a dictionary of slice objects as
312
+ values and keys as described in `omega_boundary_condition`.
313
+ `omega_boundary_dim` indicates which dimension(s) of the PINN
314
+ will be forced to match the boundary condition.
315
+ Note that it must be a slice and not an integer
316
+ (but a preprocessing of the user provided argument takes care of it)
317
+ norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
318
+ Fixed sample point in the space over which to compute the
319
+ normalization constant. Default is None.
320
+ norm_int_length : float, default=None
321
+ A float. Must be provided if `norm_samples` is provided. The domain area
322
+ (or interval length in 1D) upon which we perform the numerical
323
+ integration. Default None
324
+ obs_slice : slice, default=None
325
+ slice object specifying the beginning/ending of the PINN output
326
+ that is observed (this is then useful for multidim PINN). Default is None.
327
+
328
+
329
+ Raises
330
+ ------
331
+ ValueError
332
+ If conditions on omega_boundary_condition and omega_boundary_fun
333
+ are not respected
334
+ """
335
+
336
+ # NOTE static=True only for leaf attributes that are not valid JAX types
337
+ # (ie. jax.Array cannot be static) and that we do not expect to change
338
+
339
+ u: eqx.Module
340
+ dynamic_loss: DynamicLoss | None
341
+ key: Key | None = eqx.field(kw_only=True, default=None)
342
+
343
+ vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
344
+
345
+ def __post_init__(self):
346
+ """
347
+ Note that neither __init__ nor __post_init__ is called when updating a
348
+ Module with eqx.tree_at!
349
+ """
350
+ super().__post_init__() # because __init__ or __post_init__ of Base
351
+ # class is not automatically called
352
+
353
+ self.vmap_in_axes = (0,) # for x only here
354
+
355
+ def _get_dynamic_loss_batch(
356
+ self, batch: PDEStatioBatch
357
+ ) -> tuple[Float[Array, "batch_size dimension"]]:
358
+ return (batch.inside_batch,)
359
+
360
+ def _get_normalization_loss_batch(
361
+ self, _
362
+ ) -> Float[Array, "nb_norm_samples dimension"]:
363
+ return (self.norm_samples,)
364
+
365
+ def _get_observations_loss_batch(
366
+ self, batch: PDEStatioBatch
367
+ ) -> Float[Array, "batch_size obs_dim"]:
368
+ return (batch.obs_batch_dict["pinn_in"],)
476
369
 
477
370
  def __call__(self, *args, **kwargs):
478
371
  return self.evaluate(*args, **kwargs)
479
372
 
480
- def evaluate(self, params, batch):
373
+ def evaluate(
374
+ self, params: Params, batch: PDEStatioBatch
375
+ ) -> tuple[Float[Array, "1"], dict[str, float]]:
481
376
  """
482
377
  Evaluate the loss function at a batch of points for given parameters.
483
378
 
@@ -485,22 +380,14 @@ class LossPDEStatio(LossPDEAbstract):
485
380
  Parameters
486
381
  ---------
487
382
  params
488
- The dictionary of parameters of the model.
489
- Typically, it is a dictionary of
490
- dictionaries: `eq_params` and `nn_params``, respectively the
491
- differential equation parameters and the neural network parameter
383
+ Parameters at which the loss is evaluated
492
384
  batch
493
- A PDEStatioBatch object.
494
- Such a named tuple is composed of a batch of points in the
385
+ Composed of a batch of points in the
495
386
  domain, a batch of points in the domain
496
387
  border and an optional additional batch of parameters (eg. for
497
388
  metamodeling) and an optional additional batch of observed
498
389
  inputs/outputs/parameters
499
390
  """
500
- omega_batch, _ = batch.inside_batch, batch.border_batch
501
-
502
- vmap_in_axes_x = (0,)
503
-
504
391
  # Retrieve the optional eq_params_batch
505
392
  # and update eq_params with the latter
506
393
  # and update vmap_in_axes
@@ -511,44 +398,41 @@ class LossPDEStatio(LossPDEAbstract):
511
398
  vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
512
399
 
513
400
  # dynamic part
514
- params_ = _set_derivatives(params, "dyn_loss", self.derivative_keys)
515
401
  if self.dynamic_loss is not None:
516
402
  mse_dyn_loss = dynamic_loss_apply(
517
403
  self.dynamic_loss.evaluate,
518
404
  self.u,
519
- (omega_batch,),
520
- params_,
521
- vmap_in_axes_x + vmap_in_axes_params,
522
- self.loss_weights["dyn_loss"],
405
+ self._get_dynamic_loss_batch(batch),
406
+ _set_derivatives(params, self.derivative_keys.dyn_loss),
407
+ self.vmap_in_axes + vmap_in_axes_params,
408
+ self.loss_weights.dyn_loss,
523
409
  )
524
410
  else:
525
411
  mse_dyn_loss = jnp.array(0.0)
526
412
 
527
413
  # normalization part
528
- params_ = _set_derivatives(params, "norm_loss", self.derivative_keys)
529
- if self.normalization_loss is not None:
414
+ if self.norm_samples is not None:
530
415
  mse_norm_loss = normalization_loss_apply(
531
416
  self.u,
532
- (self.get_norm_samples(),),
533
- params_,
534
- vmap_in_axes_x + vmap_in_axes_params,
535
- self.int_length,
536
- self.loss_weights["norm_loss"],
417
+ self._get_normalization_loss_batch(batch),
418
+ _set_derivatives(params, self.derivative_keys.norm_loss),
419
+ self.vmap_in_axes + vmap_in_axes_params,
420
+ self.norm_int_length,
421
+ self.loss_weights.norm_loss,
537
422
  )
538
423
  else:
539
424
  mse_norm_loss = jnp.array(0.0)
540
425
 
541
426
  # boundary part
542
- params_ = _set_derivatives(params, "boundary_loss", self.derivative_keys)
543
427
  if self.omega_boundary_condition is not None:
544
428
  mse_boundary_loss = boundary_condition_apply(
545
429
  self.u,
546
430
  batch,
547
- params_,
431
+ _set_derivatives(params, self.derivative_keys.boundary_loss),
548
432
  self.omega_boundary_fun,
549
433
  self.omega_boundary_condition,
550
434
  self.omega_boundary_dim,
551
- self.loss_weights["boundary_loss"],
435
+ self.loss_weights.boundary_loss,
552
436
  )
553
437
  else:
554
438
  mse_boundary_loss = jnp.array(0.0)
@@ -558,40 +442,21 @@ class LossPDEStatio(LossPDEAbstract):
558
442
  # update params with the batches of observed params
559
443
  params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])
560
444
 
561
- params_ = _set_derivatives(params, "observations", self.derivative_keys)
562
445
  mse_observation_loss = observations_loss_apply(
563
446
  self.u,
564
- (batch.obs_batch_dict["pinn_in"],),
565
- params_,
566
- vmap_in_axes_x + vmap_in_axes_params,
447
+ self._get_observations_loss_batch(batch),
448
+ _set_derivatives(params, self.derivative_keys.observations),
449
+ self.vmap_in_axes + vmap_in_axes_params,
567
450
  batch.obs_batch_dict["val"],
568
- self.loss_weights["observations"],
451
+ self.loss_weights.observations,
569
452
  self.obs_slice,
570
453
  )
571
454
  else:
572
455
  mse_observation_loss = jnp.array(0.0)
573
456
 
574
- # Sobolev regularization
575
- params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
576
- if self.sobolev_reg is not None:
577
- mse_sobolev_loss = sobolev_reg_apply(
578
- self.u,
579
- (omega_batch,),
580
- params_,
581
- vmap_in_axes_x + vmap_in_axes_params,
582
- self.sobolev_reg,
583
- self.loss_weights["sobolev"],
584
- )
585
- else:
586
- mse_sobolev_loss = jnp.array(0.0)
587
-
588
457
  # total loss
589
458
  total_loss = (
590
- mse_dyn_loss
591
- + mse_norm_loss
592
- + mse_boundary_loss
593
- + mse_observation_loss
594
- + mse_sobolev_loss
459
+ mse_dyn_loss + mse_norm_loss + mse_boundary_loss + mse_observation_loss
595
460
  )
596
461
  return total_loss, (
597
462
  {
@@ -599,205 +464,143 @@ class LossPDEStatio(LossPDEAbstract):
599
464
  "norm_loss": mse_norm_loss,
600
465
  "boundary_loss": mse_boundary_loss,
601
466
  "observations": mse_observation_loss,
602
- "sobolev": mse_sobolev_loss,
603
467
  "initial_condition": jnp.array(0.0), # for compatibility in the
604
468
  # tree_map of SystemLoss
605
469
  }
606
470
  )
607
471
 
608
- def tree_flatten(self):
609
- children = (self.norm_key, self.norm_samples, self.loss_weights)
610
- aux_data = {
611
- "u": self.u,
612
- "dynamic_loss": self.dynamic_loss,
613
- "derivative_keys": self.derivative_keys,
614
- "omega_boundary_fun": self.omega_boundary_fun,
615
- "omega_boundary_condition": self.omega_boundary_condition,
616
- "omega_boundary_dim": self.omega_boundary_dim,
617
- "norm_borders": self.norm_borders,
618
- "sobolev_m": self.sobolev_m,
619
- "obs_slice": self.obs_slice,
620
- }
621
- return (children, aux_data)
622
-
623
- @classmethod
624
- def tree_unflatten(cls, aux_data, children):
625
- (norm_key, norm_samples, loss_weights) = children
626
- pls = cls(
627
- aux_data["u"],
628
- loss_weights,
629
- aux_data["dynamic_loss"],
630
- aux_data["derivative_keys"],
631
- aux_data["omega_boundary_fun"],
632
- aux_data["omega_boundary_condition"],
633
- aux_data["omega_boundary_dim"],
634
- norm_key,
635
- aux_data["norm_borders"],
636
- norm_samples,
637
- aux_data["sobolev_m"],
638
- aux_data["obs_slice"],
639
- )
640
- return pls
641
472
 
642
-
643
- @register_pytree_node_class
644
473
  class LossPDENonStatio(LossPDEStatio):
645
474
  r"""Loss object for a stationary partial differential equation
646
475
 
647
- .. math::
476
+ $$
648
477
  \mathcal{N}[u](t, x) = 0, \forall t \in I, \forall x \in \Omega
478
+ $$
649
479
 
650
- where :math:`\mathcal{N}[\cdot]` is a differential operator.
651
- The boundary condition is :math:`u(t, x)=u_b(t, x),\forall
652
- x\in\delta\Omega, \forall t`.
653
- The initial condition is :math:`u(0, x)=u_0(x), \forall x\in\Omega`
480
+ where $\mathcal{N}[\cdot]$ is a differential operator.
481
+ The boundary condition is $u(t, x)=u_b(t, x),\forall
482
+ x\in\delta\Omega, \forall t$.
483
+ The initial condition is $u(0, x)=u_0(x), \forall x\in\Omega$
654
484
  The additional condition of
655
- integrating to 1 can be included, i.e., :math:`\int u(t, x)\mathrm{d}x=1`.
656
-
485
+ integrating to 1 can be included, i.e., $\int u(t, x)\mathrm{d}x=1$.
486
+
487
+ Parameters
488
+ ----------
489
+ u : eqx.Module
490
+ the PINN
491
+ dynamic_loss : DynamicLoss
492
+ the non stationary PDE dynamic part of the loss, basically the differential
493
+ operator $\mathcal{N}[u](t, x)$. Should implement a method
494
+ `dynamic_loss.evaluate(t, x, u, params)`.
495
+ Can be None in order to access only some part of the evaluate call
496
+ results.
497
+ key : Key
498
+ A JAX PRNG Key for the loss class treated as an attribute. Default is
499
+ None. This field is provided for future developments and additional
500
+ losses that might need some randomness. Note that special care must be
501
+ taken when splitting the key because in-place updates are forbidden in
502
+ eqx.Modules.
503
+ reason
504
+ loss_weights : LossWeightsPDENonStatio, default=None
505
+ The loss weights for the differents term : dynamic loss,
506
+ boundary conditions if any, initial condition, normalization loss if any and
507
+ observations if any.
508
+ All fields are set to 1.0 by default.
509
+ derivative_keys : DerivativeKeysPDENonStatio, default=None
510
+ Specify which field of `params` should be differentiated for each
511
+ composant of the total loss. Particularily useful for inverse problems.
512
+ Fields can be "nn_params", "eq_params" or "both". Those that should not
513
+ be updated will have a `jax.lax.stop_gradient` called on them. Default
514
+ is `"nn_params"` for each composant of the loss.
515
+ omega_boundary_fun : Callable | Dict[str, Callable], default=None
516
+ The function to be matched in the border condition (can be None) or a
517
+ dictionary of such functions as values and keys as described
518
+ in `omega_boundary_condition`.
519
+ omega_boundary_condition : str | Dict[str, str], default=None
520
+ Either None (no condition, by default), or a string defining
521
+ the boundary condition (Dirichlet or Von Neumann),
522
+ or a dictionary with such strings as values. In this case,
523
+ the keys are the facets and must be in the following order:
524
+ 1D -> [“xmin”, “xmax”], 2D -> [“xmin”, “xmax”, “ymin”, “ymax”].
525
+ Note that high order boundaries are currently not implemented.
526
+ A value in the dict can be None, this means we do not enforce
527
+ a particular boundary condition on this facet.
528
+ The facet called “xmin”, resp. “xmax” etc., in 2D,
529
+ refers to the set of 2D points with fixed “xmin”, resp. “xmax”, etc.
530
+ omega_boundary_dim : slice | Dict[str, slice], default=None
531
+ Either None, or a slice object or a dictionary of slice objects as
532
+ values and keys as described in `omega_boundary_condition`.
533
+ `omega_boundary_dim` indicates which dimension(s) of the PINN
534
+ will be forced to match the boundary condition.
535
+ Note that it must be a slice and not an integer
536
+ (but a preprocessing of the user provided argument takes care of it)
537
+ norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
538
+ Fixed sample point in the space over which to compute the
539
+ normalization constant. Default is None.
540
+ norm_int_length : float, default=None
541
+ A float. Must be provided if `norm_samples` is provided. The domain area
542
+ (or interval length in 1D) upon which we perform the numerical
543
+ integration. Default None
544
+ obs_slice : slice, default=None
545
+ slice object specifying the begininning/ending of the PINN output
546
+ that is observed (this is then useful for multidim PINN). Default is None.
547
+ initial_condition_fun : Callable, default=None
548
+ A function representing the temporal initial condition. If None
549
+ (default) then no initial condition is applied
657
550
 
658
- **Note:** LossPDENonStatio is jittable. Hence it implements the tree_flatten() and
659
- tree_unflatten methods.
660
551
  """
661
552
 
662
- def __init__(
663
- self,
664
- u,
665
- loss_weights,
666
- dynamic_loss,
667
- derivative_keys=None,
668
- omega_boundary_fun=None,
669
- omega_boundary_condition=None,
670
- omega_boundary_dim=None,
671
- initial_condition_fun=None,
672
- norm_key=None,
673
- norm_borders=None,
674
- norm_samples=None,
675
- sobolev_m=None,
676
- obs_slice=None,
677
- ):
678
- r"""
679
- Parameters
680
- ----------
681
- u
682
- the PINN object
683
- loss_weights
684
- dictionary of values for loss term ponderation
685
- Note that we can have jnp.arrays with the same dimension of
686
- `u` which then ponderates each output of `u`
687
- dynamic_loss
688
- A Dynamic loss object whose evaluate method corresponds to the
689
- dynamic term in the loss
690
- Can be None in order to access only some part of the evaluate call
691
- results.
692
- derivative_keys
693
- A dict of lists of strings. In the dict, the key must correspond to
694
- the loss term keywords. Then each of the values must correspond to keys in the parameter
695
- dictionary (*at top level only of the parameter dictionary*).
696
- It enables selecting the set of parameters
697
- with respect to which the gradients of the dynamic
698
- loss are computed. If nothing is provided, we set ["nn_params"] for all loss term
699
- keywords, this is what is typically
700
- done in solving forward problems, when we only estimate the
701
- equation solution with a PINN. If some loss terms keywords are
702
- missing we set their value to ["nn_params"] by default for the same
703
- reason
704
- omega_boundary_fun
705
- The function to be matched in the border condition (can be None)
706
- or a dictionary of such function. In this case, the keys are the
707
- facets and the values are the functions. The keys must be in the
708
- following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
709
- "ymin", "ymax"]. Note that high order boundaries are currently not
710
- implemented. A value in the dict can be None, this means we do not
711
- enforce a particular boundary condition on this facet.
712
- The facet called "xmin", resp. "xmax" etc., in 2D,
713
- refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
714
- omega_boundary_condition
715
- Either None (no condition), or a string defining the boundary
716
- condition e.g. Dirichlet or Von Neumann, or a dictionary of such
717
- strings. In this case, the keys are the
718
- facets and the values are the strings. The keys must be in the
719
- following order: 1D -> ["xmin", "xmax"], 2D -> ["xmin", "xmax",
720
- "ymin", "ymax"]. Note that high order boundaries are currently not
721
- implemented. A value in the dict can be None, this means we do not
722
- enforce a particular boundary condition on this facet.
723
- The facet called "xmin", resp. "xmax" etc., in 2D,
724
- refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
725
- omega_boundary_dim
726
- Either None, or a jnp.s\_ or a dict of jnp.s\_ with keys following
727
- the logic of omega_boundary_fun. It indicates which dimension(s) of
728
- the PINN will be forced to match the boundary condition
729
- Note that it must be a slice and not an integer (a preprocessing of the
730
- user provided argument takes care of it)
731
- initial_condition_fun
732
- A function representing the temporal initial condition. If None
733
- (default) then no initial condition is applied
734
- norm_key
735
- Jax random key to draw samples in for the Monte Carlo computation
736
- of the normalization constant. Default is None
737
- norm_borders
738
- tuple of (min, max) of the boundaray values of the space over which
739
- to integrate in the computation of the normalization constant.
740
- A list of tuple for higher dimensional problems. Default None.
741
- norm_samples
742
- Fixed sample point in the space over which to compute the
743
- normalization constant. Default is None
744
- sobolev_m
745
- An integer. Default is None.
746
- It corresponds to the Sobolev regularization order as proposed in
747
- *Convergence and error analysis of PINNs*,
748
- Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
749
- obs_slice
750
- slice object specifying the begininning/ending
751
- slice of u output(s) that is observed (this is then useful for
752
- multidim PINN). Default is None.
753
-
553
+ # NOTE static=True only for leaf attributes that are not valid JAX types
554
+ # (ie. jax.Array cannot be static) and that we do not expect to change
555
+ initial_condition_fun: Callable | None = eqx.field(
556
+ kw_only=True, default=None, static=True
557
+ )
754
558
 
559
+ def __post_init__(self):
560
+ """
561
+ Note that neither __init__ or __post_init__ are called when udating a
562
+ Module with eqx.tree_at!
755
563
  """
564
+ super().__post_init__() # because __init__ or __post_init__ of Base
565
+ # class is not automatically called
756
566
 
757
- super().__init__(
758
- u,
759
- loss_weights,
760
- dynamic_loss,
761
- derivative_keys,
762
- omega_boundary_fun,
763
- omega_boundary_condition,
764
- omega_boundary_dim,
765
- norm_key,
766
- norm_borders,
767
- norm_samples,
768
- sobolev_m=sobolev_m,
769
- obs_slice=obs_slice,
770
- )
771
- if initial_condition_fun is None:
567
+ self.vmap_in_axes = (0, 0) # for t and x
568
+
569
+ if self.initial_condition_fun is None:
772
570
  warnings.warn(
773
571
  "Initial condition wasn't provided. Be sure to cover for that"
774
572
  "case (e.g by. hardcoding it into the PINN output)."
775
573
  )
776
- self.initial_condition_fun = initial_condition_fun
777
-
778
- self.sobolev_m = sobolev_m
779
- if self.sobolev_m is not None:
780
- # This overwrite the wrongly initialized self.sobolev_reg with
781
- # statio=True in the LossPDEStatio init
782
- self.sobolev_reg = _sobolev(self.u, self.sobolev_m, statio=False)
783
- # we return a function, that way
784
- # the order of sobolev_m is static and the conditional in the recursive
785
- # function is properly set
786
- else:
787
- self.sobolev_reg = None
788
574
 
789
- for k in _LOSS_WEIGHT_KEYS_PDENONSTATIO:
790
- if k not in self.loss_weights.keys():
791
- self.loss_weights[k] = 0
575
+ def _get_dynamic_loss_batch(
576
+ self, batch: PDENonStatioBatch
577
+ ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
578
+ times_batch = batch.times_x_inside_batch[:, 0:1]
579
+ omega_batch = batch.times_x_inside_batch[:, 1:]
580
+ return (times_batch, omega_batch)
581
+
582
+ def _get_normalization_loss_batch(
583
+ self, batch: PDENonStatioBatch
584
+ ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "nb_norm_samples dimension"]]:
585
+ return (
586
+ batch.times_x_inside_batch[:, 0:1],
587
+ self.norm_samples,
588
+ )
589
+
590
+ def _get_observations_loss_batch(
591
+ self, batch: PDENonStatioBatch
592
+ ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
593
+ return (
594
+ batch.obs_batch_dict["pinn_in"][:, 0:1],
595
+ batch.obs_batch_dict["pinn_in"][:, 1:],
596
+ )
792
597
 
793
598
  def __call__(self, *args, **kwargs):
794
599
  return self.evaluate(*args, **kwargs)
795
600
 
796
601
  def evaluate(
797
- self,
798
- params,
799
- batch,
800
- ):
602
+ self, params: Params, batch: PDENonStatioBatch
603
+ ) -> tuple[Float[Array, "1"], dict[str, float]]:
801
604
  """
802
605
  Evaluate the loss function at a batch of points for given parameters.
803
606
 
@@ -805,203 +608,55 @@ class LossPDENonStatio(LossPDEStatio):
805
608
  Parameters
806
609
  ---------
807
610
  params
808
- The dictionary of parameters of the model.
809
- Typically, it is a dictionary of
810
- dictionaries: `eq_params` and `nn_params`, respectively the
811
- differential equation parameters and the neural network parameter
611
+ Parameters at which the loss is evaluated
812
612
  batch
813
- A PDENonStatioBatch object.
814
- Such a named tuple is composed of a batch of points in
613
+ Composed of a batch of points in
815
614
  the domain, a batch of points in the domain
816
615
  border, a batch of time points and an optional additional batch
817
616
  of parameters (eg. for metamodeling) and an optional additional batch of observed
818
617
  inputs/outputs/parameters
819
618
  """
820
619
 
821
- omega_batch, omega_border_batch, times_batch = (
822
- batch.inside_batch,
823
- batch.border_batch,
824
- batch.temporal_batch,
825
- )
826
- n = omega_batch.shape[0]
827
- nt = times_batch.shape[0]
828
- times_batch = times_batch.reshape(nt, 1)
829
-
830
- def rep_times(k):
831
- return jnp.repeat(times_batch, k, axis=0)
832
-
833
- vmap_in_axes_x_t = (0, 0)
620
+ omega_batch = batch.times_x_inside_batch[:, 1:]
834
621
 
835
622
  # Retrieve the optional eq_params_batch
836
623
  # and update eq_params with the latter
837
624
  # and update vmap_in_axes
838
625
  if batch.param_batch_dict is not None:
839
- eq_params_batch_dict = batch.param_batch_dict
840
-
841
- # feed the eq_params with the batch
842
- for k in eq_params_batch_dict.keys():
843
- params["eq_params"][k] = eq_params_batch_dict[k]
626
+ # update eq_params with the batches of generated params
627
+ params = _update_eq_params_dict(params, batch.param_batch_dict)
844
628
 
845
629
  vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
846
630
 
847
- if isinstance(self.u, PINN):
848
- omega_batch = jnp.tile(omega_batch, reps=(nt, 1)) # it is tiled
849
- times_batch = rep_times(n) # it is repeated
850
-
851
- # dynamic part
852
- params_ = _set_derivatives(params, "dyn_loss", self.derivative_keys)
853
- if self.dynamic_loss is not None:
854
- mse_dyn_loss = dynamic_loss_apply(
855
- self.dynamic_loss.evaluate,
856
- self.u,
857
- (times_batch, omega_batch),
858
- params_,
859
- vmap_in_axes_x_t + vmap_in_axes_params,
860
- self.loss_weights["dyn_loss"],
861
- )
862
- else:
863
- mse_dyn_loss = jnp.array(0.0)
864
-
865
- # normalization part
866
- params_ = _set_derivatives(params, "norm_loss", self.derivative_keys)
867
- if self.normalization_loss is not None:
868
- mse_norm_loss = normalization_loss_apply(
869
- self.u,
870
- (times_batch, self.get_norm_samples()),
871
- params_,
872
- vmap_in_axes_x_t + vmap_in_axes_params,
873
- self.int_length,
874
- self.loss_weights["norm_loss"],
875
- )
876
- else:
877
- mse_norm_loss = jnp.array(0.0)
878
-
879
- # boundary part
880
- params_ = _set_derivatives(params, "boundary_loss", self.derivative_keys)
881
- if self.omega_boundary_fun is not None:
882
- mse_boundary_loss = boundary_condition_apply(
883
- self.u,
884
- batch,
885
- params_,
886
- self.omega_boundary_fun,
887
- self.omega_boundary_condition,
888
- self.omega_boundary_dim,
889
- self.loss_weights["boundary_loss"],
890
- )
891
- else:
892
- mse_boundary_loss = jnp.array(0.0)
631
+ # For mse_dyn_loss, mse_norm_loss, mse_boundary_loss,
632
+ # mse_observation_loss we use the evaluate from parent class
633
+ partial_mse, partial_mse_terms = super().evaluate(params, batch)
893
634
 
894
635
  # initial condition
895
- params_ = _set_derivatives(params, "initial_condition", self.derivative_keys)
896
636
  if self.initial_condition_fun is not None:
897
637
  mse_initial_condition = initial_condition_apply(
898
638
  self.u,
899
639
  omega_batch,
900
- params_,
640
+ _set_derivatives(params, self.derivative_keys.initial_condition),
901
641
  (0,) + vmap_in_axes_params,
902
642
  self.initial_condition_fun,
903
- n,
904
- self.loss_weights["initial_condition"],
643
+ omega_batch.shape[0],
644
+ self.loss_weights.initial_condition,
905
645
  )
906
646
  else:
907
647
  mse_initial_condition = jnp.array(0.0)
908
648
 
909
- # Observation mse
910
- if batch.obs_batch_dict is not None:
911
- # update params with the batches of observed params
912
- params = _update_eq_params_dict(params, batch.obs_batch_dict["eq_params"])
913
-
914
- params_ = _set_derivatives(params, "observations", self.derivative_keys)
915
- mse_observation_loss = observations_loss_apply(
916
- self.u,
917
- (
918
- batch.obs_batch_dict["pinn_in"][:, 0:1],
919
- batch.obs_batch_dict["pinn_in"][:, 1:],
920
- ),
921
- params_,
922
- vmap_in_axes_x_t + vmap_in_axes_params,
923
- batch.obs_batch_dict["val"],
924
- self.loss_weights["observations"],
925
- self.obs_slice,
926
- )
927
- else:
928
- mse_observation_loss = jnp.array(0.0)
929
-
930
- # Sobolev regularization
931
- params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
932
- if self.sobolev_reg is not None:
933
- mse_sobolev_loss = sobolev_reg_apply(
934
- self.u,
935
- (omega_batch, times_batch),
936
- params_,
937
- vmap_in_axes_x_t + vmap_in_axes_params,
938
- self.sobolev_reg,
939
- self.loss_weights["sobolev"],
940
- )
941
- else:
942
- mse_sobolev_loss = jnp.array(0.0)
943
-
944
649
  # total loss
945
- total_loss = (
946
- mse_dyn_loss
947
- + mse_norm_loss
948
- + mse_boundary_loss
949
- + mse_initial_condition
950
- + mse_observation_loss
951
- + mse_sobolev_loss
952
- )
953
-
954
- return total_loss, (
955
- {
956
- "dyn_loss": mse_dyn_loss,
957
- "norm_loss": mse_norm_loss,
958
- "boundary_loss": mse_boundary_loss,
959
- "initial_condition": mse_initial_condition,
960
- "observations": mse_observation_loss,
961
- "sobolev": mse_sobolev_loss,
962
- }
963
- )
650
+ total_loss = partial_mse + mse_initial_condition
964
651
 
965
- def tree_flatten(self):
966
- children = (self.norm_key, self.norm_samples, self.loss_weights)
967
- aux_data = {
968
- "u": self.u,
969
- "dynamic_loss": self.dynamic_loss,
970
- "derivative_keys": self.derivative_keys,
971
- "omega_boundary_fun": self.omega_boundary_fun,
972
- "omega_boundary_condition": self.omega_boundary_condition,
973
- "omega_boundary_dim": self.omega_boundary_dim,
974
- "initial_condition_fun": self.initial_condition_fun,
975
- "norm_borders": self.norm_borders,
976
- "sobolev_m": self.sobolev_m,
977
- "obs_slice": self.obs_slice,
652
+ return total_loss, {
653
+ **partial_mse_terms,
654
+ "initial_condition": mse_initial_condition,
978
655
  }
979
- return (children, aux_data)
980
-
981
- @classmethod
982
- def tree_unflatten(cls, aux_data, children):
983
- (norm_key, norm_samples, loss_weights) = children
984
- pls = cls(
985
- aux_data["u"],
986
- loss_weights,
987
- aux_data["dynamic_loss"],
988
- aux_data["derivative_keys"],
989
- aux_data["omega_boundary_fun"],
990
- aux_data["omega_boundary_condition"],
991
- aux_data["omega_boundary_dim"],
992
- aux_data["initial_condition_fun"],
993
- norm_key,
994
- aux_data["norm_borders"],
995
- norm_samples,
996
- aux_data["sobolev_m"],
997
- aux_data["obs_slice"],
998
- )
999
- return pls
1000
656
 
1001
657
 
1002
- @register_pytree_node_class
1003
- class SystemLossPDE:
1004
- """
658
+ class SystemLossPDE(eqx.Module):
659
+ r"""
1005
660
  Class to implement a system of PDEs.
1006
661
  The goal is to give maximum freedom to the user. The class is created with
1007
662
  a dict of dynamic loss, and dictionaries of all the objects that are used
@@ -1015,190 +670,186 @@ class SystemLossPDE:
1015
670
  Indeed, these dictionaries (except `dynamic_loss_dict`) are tied to one
1016
671
  solution.
1017
672
 
1018
- **Note:** SystemLossPDE is jittable. Hence it implements the tree_flatten() and
1019
- tree_unflatten methods.
673
+ Parameters
674
+ ----------
675
+ u_dict : Dict[str, eqx.Module]
676
+ dict of PINNs
677
+ loss_weights : LossWeightsPDEDict
678
+ A dictionary of LossWeightsODE
679
+ derivative_keys_dict : Dict[str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio], default=None
680
+ A dictionnary of DerivativeKeysPDEStatio or DerivativeKeysPDENonStatio
681
+ specifying what field of `params`
682
+ should be used during gradient computations for each of the terms of
683
+ the total loss, for each of the loss in the system. Default is
684
+ `"nn_params`" everywhere.
685
+ dynamic_loss_dict : Dict[str, PDEStatio | PDENonStatio]
686
+ A dict of dynamic part of the loss, basically the differential
687
+ operator $\mathcal{N}[u](t, x)$ or $\mathcal{N}[u](x)$.
688
+ key_dict : Dict[str, Key], default=None
689
+ A dictionary of JAX PRNG keys. The dictionary keys of key_dict must
690
+ match that of u_dict. See LossPDEStatio or LossPDENonStatio for
691
+ more details.
692
+ omega_boundary_fun_dict : Dict[str, Callable | Dict[str, Callable] | None], default=None
693
+ A dict of of function or of dict of functions or of None
694
+ (see doc for `omega_boundary_fun` in
695
+ LossPDEStatio or LossPDENonStatio). Default is None.
696
+ Must share the keys of `u_dict`.
697
+ omega_boundary_condition_dict : Dict[str, str | Dict[str, str] | None], default=None
698
+ A dict of strings or of dict of strings or of None
699
+ (see doc for `omega_boundary_condition_dict` in
700
+ LossPDEStatio or LossPDENonStatio). Default is None.
701
+ Must share the keys of `u_dict`
702
+ omega_boundary_dim_dict : Dict[str, slice | Dict[str, slice] | None], default=None
703
+ A dict of slices or of dict of slices or of None
704
+ (see doc for `omega_boundary_dim` in
705
+ LossPDEStatio or LossPDENonStatio). Default is None.
706
+ Must share the keys of `u_dict`
707
+ initial_condition_fun_dict : Dict[str, Callable | None], default=None
708
+ A dict of functions representing the temporal initial condition (None
709
+ value is possible). If None
710
+ (default) then no temporal boundary condition is applied
711
+ Must share the keys of `u_dict`
712
+ norm_samples_dict : Dict[str, Float[Array, "nb_norm_samples dimension"] | None, default=None
713
+ A dict of fixed sample point in the space over which to compute the
714
+ normalization constant. Default is None
715
+ Must share the keys of `u_dict`
716
+ norm_int_length_dict : Dict[str, float | None] | None, default=None
717
+ A dict of Float. The domain area
718
+ (or interval length in 1D) upon which we perform the numerical
719
+ integration for each element of u_dict.
720
+ Default is None
721
+ Must share the keys of `u_dict`
722
+ obs_slice_dict : Dict[str, slice | None] | None, default=None
723
+ dict of obs_slice, with keys from `u_dict` to designate the
724
+ output(s) channels that are forced to observed values, for each
725
+ PINNs. Default is None. But if a value is given, all the entries of
726
+ `u_dict` must be represented here with default value `jnp.s_[...]`
727
+ if no particular slice is to be given
728
+
1020
729
  """
1021
730
 
1022
- def __init__(
1023
- self,
1024
- u_dict,
1025
- loss_weights,
1026
- dynamic_loss_dict,
1027
- nn_type_dict,
1028
- derivative_keys_dict=None,
1029
- omega_boundary_fun_dict=None,
1030
- omega_boundary_condition_dict=None,
1031
- omega_boundary_dim_dict=None,
1032
- initial_condition_fun_dict=None,
1033
- norm_key_dict=None,
1034
- norm_borders_dict=None,
1035
- norm_samples_dict=None,
1036
- sobolev_m_dict=None,
1037
- obs_slice_dict=None,
1038
- ):
1039
- r"""
1040
- Parameters
1041
- ----------
1042
- u_dict
1043
- A dict of PINNs
1044
- loss_weights
1045
- A dictionary of dictionaries with values used to
1046
- ponderate each term in the loss
1047
- function. The keys of the nested
1048
- dictionaries must share the keys of `u_dict`. Note that the values
1049
- at the leaf level can have jnp.arrays with the same dimension of
1050
- `u` which then ponderates each output of `u`
1051
- dynamic_loss_dict
1052
- A dict of dynamic part of the loss, basically the differential
1053
- operator :math:`\mathcal{N}[u](t)`.
1054
- nn_type_dict
1055
- A dict whose keys are that of u_dict whose value is either
1056
- `nn_statio` or `nn_nonstatio` which signifies either the PINN has a
1057
- time component in input or not.
1058
- derivative_keys_dict
1059
- A dict of derivative keys as defined in LossODE. The key of this
1060
- dict must be that of `dynamic_loss_dict` at least and specify how
1061
- to compute gradient for the `dyn_loss` loss term at least (see the
1062
- check at the beginning of the present `__init__` function.
1063
- Other keys of this dict might be that of `u_dict` to specify how to
1064
- compute gradients for all the different constraints. If those keys
1065
- are not specified then the default behaviour for `derivative_keys`
1066
- of LossODE is used
1067
- omega_boundary_fun_dict
1068
- A dict of dict of functions (see doc for `omega_boundary_fun` in
1069
- LossPDEStatio or LossPDENonStatio). Default is None.
1070
- Must share the keys of `u_dict`.
1071
- omega_boundary_condition_dict
1072
- A dict of dict of strings (see doc for
1073
- `omega_boundary_condition_dict` in
1074
- LossPDEStatio or LossPDENonStatio). Default is None.
1075
- Must share the keys of `u_dict`
1076
- omega_boundary_dim_dict
1077
- A dict of dict of slices (see doc for `omega_boundary_dim` in
1078
- LossPDEStatio or LossPDENonStatio). Default is None.
1079
- Must share the keys of `u_dict`
1080
- initial_condition_fun_dict
1081
- A dict of functions representing the temporal initial condition. If None
1082
- (default) then no temporal boundary condition is applied
1083
- Must share the keys of `u_dict`
1084
- norm_key_dict
1085
- A dict of Jax random keys to draw samples in for the Monte Carlo computation
1086
- of the normalization constant. Default is None
1087
- Must share the keys of `u_dict`
1088
- norm_borders_dict
1089
- A dict of tuples of (min, max) of the boundaray values of the space over which
1090
- to integrate in the computation of the normalization constant.
1091
- A list of tuple for higher dimensional problems. Default None.
1092
- Must share the keys of `u_dict`
1093
- norm_samples_dict
1094
- A dict of fixed sample point in the space over which to compute the
1095
- normalization constant. Default is None
1096
- Must share the keys of `u_dict`
1097
- sobolev_m
1098
- Default is None. A dictionary of integers, one per key which must
1099
- match `u_dict`.
1100
- It corresponds to the Sobolev regularization order as proposed in
1101
- *Convergence and error analysis of PINNs*,
1102
- Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
1103
- obs_slice_dict
1104
- dict of obs_slice, with keys from `u_dict` to designate the
1105
- output(s) channels that are forced to observed values, for each
1106
- PINNs. Default is None. But if a value is given, all the entries of
1107
- `u_dict` must be represented here with default value `jnp.s_[...]`
1108
- if no particular slice is to be given
1109
-
1110
-
1111
- Raises
1112
- ------
1113
- ValueError
1114
- if initial condition is not a dict of tuple
1115
- ValueError
1116
- if the dictionaries that should share the keys of u_dict do not
1117
- """
731
+ # NOTE static=True only for leaf attributes that are not valid JAX types
732
+ # (ie. jax.Array cannot be static) and that we do not expect to change
733
+ u_dict: Dict[str, eqx.Module]
734
+ dynamic_loss_dict: Dict[str, PDEStatio | PDENonStatio]
735
+ key_dict: Dict[str, Key] | None = eqx.field(kw_only=True, default=None)
736
+ derivative_keys_dict: Dict[
737
+ str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio | None
738
+ ] = eqx.field(kw_only=True, default=None)
739
+ omega_boundary_fun_dict: Dict[str, Callable | Dict[str, Callable] | None] | None = (
740
+ eqx.field(kw_only=True, default=None, static=True)
741
+ )
742
+ omega_boundary_condition_dict: Dict[str, str | Dict[str, str] | None] | None = (
743
+ eqx.field(kw_only=True, default=None, static=True)
744
+ )
745
+ omega_boundary_dim_dict: Dict[str, slice | Dict[str, slice] | None] | None = (
746
+ eqx.field(kw_only=True, default=None, static=True)
747
+ )
748
+ initial_condition_fun_dict: Dict[str, Callable | None] | None = eqx.field(
749
+ kw_only=True, default=None, static=True
750
+ )
751
+ norm_samples_dict: Dict[str, Float[Array, "nb_norm_samples dimension"]] | None = (
752
+ eqx.field(kw_only=True, default=None)
753
+ )
754
+ norm_int_length_dict: Dict[str, float | None] | None = eqx.field(
755
+ kw_only=True, default=None
756
+ )
757
+ obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
758
+ kw_only=True, default=None, static=True
759
+ )
760
+
761
+ # For the user loss_weights are passed as a LossWeightsPDEDict (with internal
762
+ # dictionary having keys in u_dict and / or dynamic_loss_dict)
763
+ loss_weights: InitVar[LossWeightsPDEDict | None] = eqx.field(
764
+ kw_only=True, default=None
765
+ )
766
+
767
+ # following have init=False and are set in the __post_init__
768
+ u_constraints_dict: Dict[str, LossPDEStatio | LossPDENonStatio] = eqx.field(
769
+ init=False
770
+ )
771
+ derivative_keys_u_dict: Dict[
772
+ str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio
773
+ ] = eqx.field(init=False)
774
+ derivative_keys_dyn_loss_dict: Dict[
775
+ str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio
776
+ ] = eqx.field(init=False)
777
+ u_dict_with_none: Dict[str, None] = eqx.field(init=False)
778
+ # internally the loss weights are handled with a dictionary
779
+ _loss_weights: Dict[str, dict] = eqx.field(init=False)
780
+
781
+ def __post_init__(self, loss_weights):
1118
782
  # a dictionary that will be useful at different places
1119
- self.u_dict_with_none = {k: None for k in u_dict.keys()}
783
+ self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
1120
784
  # First, for all the optional dict,
1121
785
  # if the user did not provide at all this optional argument,
1122
786
  # we make sure there is a null ponderating loss_weight and we
1123
787
  # create a dummy dict with the required keys and all the values to
1124
788
  # None
1125
- if omega_boundary_fun_dict is None:
789
+ if self.key_dict is None:
790
+ self.key_dict = self.u_dict_with_none
791
+ if self.omega_boundary_fun_dict is None:
1126
792
  self.omega_boundary_fun_dict = self.u_dict_with_none
1127
- else:
1128
- self.omega_boundary_fun_dict = omega_boundary_fun_dict
1129
- if omega_boundary_condition_dict is None:
793
+ if self.omega_boundary_condition_dict is None:
1130
794
  self.omega_boundary_condition_dict = self.u_dict_with_none
1131
- else:
1132
- self.omega_boundary_condition_dict = omega_boundary_condition_dict
1133
- if omega_boundary_dim_dict is None:
795
+ if self.omega_boundary_dim_dict is None:
1134
796
  self.omega_boundary_dim_dict = self.u_dict_with_none
1135
- else:
1136
- self.omega_boundary_dim_dict = omega_boundary_dim_dict
1137
- if initial_condition_fun_dict is None:
797
+ if self.initial_condition_fun_dict is None:
1138
798
  self.initial_condition_fun_dict = self.u_dict_with_none
1139
- else:
1140
- self.initial_condition_fun_dict = initial_condition_fun_dict
1141
- if norm_key_dict is None:
1142
- self.norm_key_dict = self.u_dict_with_none
1143
- else:
1144
- self.norm_key_dict = norm_key_dict
1145
- if norm_borders_dict is None:
1146
- self.norm_borders_dict = self.u_dict_with_none
1147
- else:
1148
- self.norm_borders_dict = norm_borders_dict
1149
- if norm_samples_dict is None:
799
+ if self.norm_samples_dict is None:
1150
800
  self.norm_samples_dict = self.u_dict_with_none
1151
- else:
1152
- self.norm_samples_dict = norm_samples_dict
1153
- if sobolev_m_dict is None:
1154
- self.sobolev_m_dict = self.u_dict_with_none
1155
- else:
1156
- self.sobolev_m_dict = sobolev_m_dict
1157
- if obs_slice_dict is None:
1158
- self.obs_slice_dict = {k: jnp.s_[...] for k in u_dict.keys()}
1159
- else:
1160
- self.obs_slice_dict = obs_slice_dict
1161
- if u_dict.keys() != obs_slice_dict.keys():
801
+ if self.norm_int_length_dict is None:
802
+ self.norm_int_length_dict = self.u_dict_with_none
803
+ if self.obs_slice_dict is None:
804
+ self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
805
+ if self.u_dict.keys() != self.obs_slice_dict.keys():
1162
806
  raise ValueError("obs_slice_dict should have same keys as u_dict")
1163
- if derivative_keys_dict is None:
807
+ if self.derivative_keys_dict is None:
1164
808
  self.derivative_keys_dict = {
1165
809
  k: None
1166
- for k in set(list(dynamic_loss_dict.keys()) + list(u_dict.keys()))
810
+ for k in set(
811
+ list(self.dynamic_loss_dict.keys()) + list(self.u_dict.keys())
812
+ )
1167
813
  }
1168
814
  # set() because we can have duplicate entries and in this case we
1169
815
  # say it corresponds to the same derivative_keys_dict entry
1170
- else:
1171
- self.derivative_keys_dict = derivative_keys_dict
816
+ # we need both because the constraints (all but dyn_loss) will be
817
+ # done by iterating on u_dict while the dyn_loss will be by
818
+ # iterating on dynamic_loss_dict. So each time we will require dome
819
+ # derivative_keys_dict
820
+
1172
821
  # but then if the user did not provide anything, we must at least have
1173
822
  # a default value for the dynamic_loss_dict keys entries in
1174
823
  # self.derivative_keys_dict since the computation of dynamic losses is
1175
- # made without create a lossODE object that would provide the
824
+ # made without create a loss object that would provide the
1176
825
  # default values
1177
- for k in dynamic_loss_dict.keys():
826
+ for k in self.dynamic_loss_dict.keys():
1178
827
  if self.derivative_keys_dict[k] is None:
1179
- self.derivative_keys_dict[k] = {"dyn_loss": ["nn_params"]}
828
+ try:
829
+ if self.u_dict[k].eq_type == "statio_PDE":
830
+ self.derivative_keys_dict[k] = DerivativeKeysPDEStatio()
831
+ else:
832
+ self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio()
833
+ except KeyError: # We are in a key that is not in u_dict but in
834
+ # dynamic_loss_dict
835
+ if isinstance(self.dynamic_loss_dict[k], PDEStatio):
836
+ self.derivative_keys_dict[k] = DerivativeKeysPDEStatio()
837
+ else:
838
+ self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio()
1180
839
 
1181
840
  # Second we make sure that all the dicts (except dynamic_loss_dict) have the same keys
1182
841
  if (
1183
- u_dict.keys() != nn_type_dict.keys()
1184
- or u_dict.keys() != self.omega_boundary_fun_dict.keys()
1185
- or u_dict.keys() != self.omega_boundary_condition_dict.keys()
1186
- or u_dict.keys() != self.omega_boundary_dim_dict.keys()
1187
- or u_dict.keys() != self.initial_condition_fun_dict.keys()
1188
- or u_dict.keys() != self.norm_key_dict.keys()
1189
- or u_dict.keys() != self.norm_borders_dict.keys()
1190
- or u_dict.keys() != self.norm_samples_dict.keys()
1191
- or u_dict.keys() != self.sobolev_m_dict.keys()
842
+ self.u_dict.keys() != self.key_dict.keys()
843
+ or self.u_dict.keys() != self.omega_boundary_fun_dict.keys()
844
+ or self.u_dict.keys() != self.omega_boundary_condition_dict.keys()
845
+ or self.u_dict.keys() != self.omega_boundary_dim_dict.keys()
846
+ or self.u_dict.keys() != self.initial_condition_fun_dict.keys()
847
+ or self.u_dict.keys() != self.norm_samples_dict.keys()
848
+ or self.u_dict.keys() != self.norm_int_length_dict.keys()
1192
849
  ):
1193
850
  raise ValueError("All the dicts concerning the PINNs should have same keys")
1194
851
 
1195
- self.dynamic_loss_dict = dynamic_loss_dict
1196
- self.u_dict = u_dict
1197
- # TODO nn_type should become a class attribute now that we have PINN
1198
- # class and SPINNs class
1199
- self.nn_type_dict = nn_type_dict
1200
-
1201
- self.loss_weights = loss_weights # This calls the setter
852
+ self._loss_weights = self.set_loss_weights(loss_weights)
1202
853
 
1203
854
  # Third, in order not to benefit from LossPDEStatio and
1204
855
  # LossPDENonStatio and in order to factorize code, we create internally
@@ -1206,52 +857,51 @@ class SystemLossPDE:
1206
857
  # We will not use the dynamic loss term
1207
858
  self.u_constraints_dict = {}
1208
859
  for i in self.u_dict.keys():
1209
- if self.nn_type_dict[i] == "nn_statio":
860
+ if self.u_dict[i].eq_type == "statio_PDE":
1210
861
  self.u_constraints_dict[i] = LossPDEStatio(
1211
- u=u_dict[i],
1212
- loss_weights={
1213
- "dyn_loss": 0.0,
1214
- "norm_loss": 1.0,
1215
- "boundary_loss": 1.0,
1216
- "observations": 1.0,
1217
- "sobolev": 1.0,
1218
- },
862
+ u=self.u_dict[i],
863
+ loss_weights=LossWeightsPDENonStatio(
864
+ dyn_loss=0.0,
865
+ norm_loss=1.0,
866
+ boundary_loss=1.0,
867
+ observations=1.0,
868
+ initial_condition=1.0,
869
+ ),
1219
870
  dynamic_loss=None,
871
+ key=self.key_dict[i],
1220
872
  derivative_keys=self.derivative_keys_dict[i],
1221
873
  omega_boundary_fun=self.omega_boundary_fun_dict[i],
1222
874
  omega_boundary_condition=self.omega_boundary_condition_dict[i],
1223
875
  omega_boundary_dim=self.omega_boundary_dim_dict[i],
1224
- norm_key=self.norm_key_dict[i],
1225
- norm_borders=self.norm_borders_dict[i],
1226
876
  norm_samples=self.norm_samples_dict[i],
1227
- sobolev_m=self.sobolev_m_dict[i],
877
+ norm_int_length=self.norm_int_length_dict[i],
1228
878
  obs_slice=self.obs_slice_dict[i],
1229
879
  )
1230
- elif self.nn_type_dict[i] == "nn_nonstatio":
880
+ elif self.u_dict[i].eq_type == "nonstatio_PDE":
1231
881
  self.u_constraints_dict[i] = LossPDENonStatio(
1232
- u=u_dict[i],
1233
- loss_weights={
1234
- "dyn_loss": 0.0,
1235
- "norm_loss": 1.0,
1236
- "boundary_loss": 1.0,
1237
- "observations": 1.0,
1238
- "initial_condition": 1.0,
1239
- "sobolev": 1.0,
1240
- },
882
+ u=self.u_dict[i],
883
+ loss_weights=LossWeightsPDENonStatio(
884
+ dyn_loss=0.0,
885
+ norm_loss=1.0,
886
+ boundary_loss=1.0,
887
+ observations=1.0,
888
+ initial_condition=1.0,
889
+ ),
1241
890
  dynamic_loss=None,
891
+ key=self.key_dict[i],
1242
892
  derivative_keys=self.derivative_keys_dict[i],
1243
893
  omega_boundary_fun=self.omega_boundary_fun_dict[i],
1244
894
  omega_boundary_condition=self.omega_boundary_condition_dict[i],
1245
895
  omega_boundary_dim=self.omega_boundary_dim_dict[i],
1246
896
  initial_condition_fun=self.initial_condition_fun_dict[i],
1247
- norm_key=self.norm_key_dict[i],
1248
- norm_borders=self.norm_borders_dict[i],
1249
897
  norm_samples=self.norm_samples_dict[i],
1250
- sobolev_m=self.sobolev_m_dict[i],
898
+ norm_int_length=self.norm_int_length_dict[i],
899
+ obs_slice=self.obs_slice_dict[i],
1251
900
  )
1252
901
  else:
1253
902
  raise ValueError(
1254
- f"Wrong value for nn_type_dict[i], got {nn_type_dict[i]}"
903
+ "Wrong value for self.u_dict[i].eq_type[i], "
904
+ f"got {self.u_dict[i].eq_type[i]}"
1255
905
  )
1256
906
 
1257
907
  # for convenience in the tree_map of evaluate,
@@ -1267,34 +917,38 @@ class SystemLossPDE:
1267
917
 
1268
918
  # also make sure we only have PINNs or SPINNs
1269
919
  if not (
1270
- all(isinstance(value, PINN) for value in u_dict.values())
1271
- or all(isinstance(value, SPINN) for value in u_dict.values())
920
+ all(isinstance(value, PINN) for value in self.u_dict.values())
921
+ or all(isinstance(value, SPINN) for value in self.u_dict.values())
1272
922
  ):
1273
923
  raise ValueError(
1274
924
  "We only accept dictionary of PINNs or dictionary of SPINNs"
1275
925
  )
1276
926
 
1277
- @property
1278
- def loss_weights(self):
1279
- return self._loss_weights
1280
-
1281
- @loss_weights.setter
1282
- def loss_weights(self, value):
1283
- self._loss_weights = {}
1284
- for k, v in value.items():
927
+ def set_loss_weights(
928
+ self, loss_weights_init: LossWeightsPDEDict
929
+ ) -> dict[str, dict]:
930
+ """
931
+ This rather complex function enables the user to specify a simple
932
+ loss_weights=LossWeightsPDEDict(dyn_loss=1., initial_condition=Tmax)
933
+ for ponderating values being applied to all the equations of the
934
+ system... So all the transformations are handled here
935
+ """
936
+ _loss_weights = {}
937
+ for k in fields(loss_weights_init):
938
+ v = getattr(loss_weights_init, k.name)
1285
939
  if isinstance(v, dict):
1286
- for kk, vv in v.items():
940
+ for vv in v.keys():
1287
941
  if not isinstance(vv, (int, float)) and not (
1288
- isinstance(vv, jnp.ndarray)
942
+ isinstance(vv, Array)
1289
943
  and ((vv.shape == (1,) or len(vv.shape) == 0))
1290
944
  ):
1291
945
  # TODO improve that
1292
946
  raise ValueError(
1293
947
  f"loss values cannot be vectorial here, got {vv}"
1294
948
  )
1295
- if k == "dyn_loss":
949
+ if k.name == "dyn_loss":
1296
950
  if v.keys() == self.dynamic_loss_dict.keys():
1297
- self._loss_weights[k] = v
951
+ _loss_weights[k.name] = v
1298
952
  else:
1299
953
  raise ValueError(
1300
954
  "Keys in nested dictionary of loss_weights"
@@ -1302,51 +956,36 @@ class SystemLossPDE:
1302
956
  )
1303
957
  else:
1304
958
  if v.keys() == self.u_dict.keys():
1305
- self._loss_weights[k] = v
959
+ _loss_weights[k.name] = v
1306
960
  else:
1307
961
  raise ValueError(
1308
962
  "Keys in nested dictionary of loss_weights"
1309
963
  " do not match u_dict keys"
1310
964
  )
965
+ if v is None:
966
+ _loss_weights[k.name] = {kk: 0 for kk in self.u_dict.keys()}
1311
967
  else:
1312
968
  if not isinstance(v, (int, float)) and not (
1313
- isinstance(v, jnp.ndarray)
1314
- and ((v.shape == (1,) or len(v.shape) == 0))
969
+ isinstance(v, Array) and ((v.shape == (1,) or len(v.shape) == 0))
1315
970
  ):
1316
971
  # TODO improve that
1317
972
  raise ValueError(f"loss values cannot be vectorial here, got {v}")
1318
- if k == "dyn_loss":
1319
- self._loss_weights[k] = {
973
+ if k.name == "dyn_loss":
974
+ _loss_weights[k.name] = {
1320
975
  kk: v for kk in self.dynamic_loss_dict.keys()
1321
976
  }
1322
977
  else:
1323
- self._loss_weights[k] = {kk: v for kk in self.u_dict.keys()}
1324
- # Some special checks below
1325
- if all(v is None for k, v in self.sobolev_m_dict.items()):
1326
- self._loss_weights["sobolev"] = {k: 0 for k in self.u_dict.keys()}
1327
- if "observations" not in value.keys():
1328
- self._loss_weights["observations"] = {k: 0 for k in self.u_dict.keys()}
1329
- if all(v is None for k, v in self.omega_boundary_fun_dict.items()) or all(
1330
- v is None for k, v in self.omega_boundary_condition_dict.items()
1331
- ):
1332
- self._loss_weights["boundary_loss"] = {k: 0 for k in self.u_dict.keys()}
1333
- if (
1334
- all(v is None for k, v in self.norm_key_dict.items())
1335
- or all(v is None for k, v in self.norm_borders_dict.items())
1336
- or all(v is None for k, v in self.norm_samples_dict.items())
1337
- ):
1338
- self._loss_weights["norm_loss"] = {k: 0 for k in self.u_dict.keys()}
1339
- if all(v is None for k, v in self.initial_condition_fun_dict.items()):
1340
- self._loss_weights["initial_condition"] = {k: 0 for k in self.u_dict.keys()}
978
+ _loss_weights[k.name] = {kk: v for kk in self.u_dict.keys()}
979
+ return _loss_weights
1341
980
 
1342
981
  def __call__(self, *args, **kwargs):
1343
982
  return self.evaluate(*args, **kwargs)
1344
983
 
1345
984
  def evaluate(
1346
985
  self,
1347
- params_dict,
1348
- batch,
1349
- ):
986
+ params_dict: ParamsDict,
987
+ batch: PDEStatioBatch | PDENonStatioBatch,
988
+ ) -> tuple[Float[Array, "1"], dict[str, float]]:
1350
989
  """
1351
990
  Evaluate the loss function at a batch of points for given parameters.
1352
991
 
@@ -1354,12 +993,8 @@ class SystemLossPDE:
1354
993
  Parameters
1355
994
  ---------
1356
995
  params_dict
1357
- A dictionary of dictionaries of parameters of the model.
1358
- Typically, it is a dictionary of dictionaries of
1359
- dictionaries: `eq_params` and `nn_params``, respectively the
1360
- differential equation parameters and the neural network parameter
996
+ Parameters at which the losses of the system are evaluated
1361
997
  batch
1362
- A PDEStatioBatch or PDENonStatioBatch object.
1363
998
  Such named tuples are composed of batch of points in the
1364
999
  domain, a batch of points in the domain
1365
1000
  border, (a batch of time points a for PDENonStatioBatch) and an
@@ -1367,32 +1002,17 @@ class SystemLossPDE:
1367
1002
  and an optional additional batch of observed
1368
1003
  inputs/outputs/parameters
1369
1004
  """
1370
- if self.u_dict.keys() != params_dict["nn_params"].keys():
1005
+ if self.u_dict.keys() != params_dict.nn_params.keys():
1371
1006
  raise ValueError("u_dict and params_dict[nn_params] should have same keys ")
1372
1007
 
1373
1008
  if isinstance(batch, PDEStatioBatch):
1374
1009
  omega_batch, _ = batch.inside_batch, batch.border_batch
1375
- n = omega_batch.shape[0]
1376
1010
  vmap_in_axes_x_or_x_t = (0,)
1377
1011
 
1378
1012
  batches = (omega_batch,)
1379
1013
  elif isinstance(batch, PDENonStatioBatch):
1380
- omega_batch, _, times_batch = (
1381
- batch.inside_batch,
1382
- batch.border_batch,
1383
- batch.temporal_batch,
1384
- )
1385
- n = omega_batch.shape[0]
1386
- nt = times_batch.shape[0]
1387
- times_batch = times_batch.reshape(nt, 1)
1388
-
1389
- def rep_times(k):
1390
- return jnp.repeat(times_batch, k, axis=0)
1391
-
1392
- # Moreover...
1393
- if isinstance(list(self.u_dict.values())[0], PINN):
1394
- omega_batch = jnp.tile(omega_batch, reps=(nt, 1)) # it is tiled
1395
- times_batch = rep_times(n) # it is repeated
1014
+ times_batch = batch.times_x_inside_batch[:, 0:1]
1015
+ omega_batch = batch.times_x_inside_batch[:, 1:]
1396
1016
 
1397
1017
  batches = (omega_batch, times_batch)
1398
1018
  vmap_in_axes_x_or_x_t = (0, 0)
@@ -1405,9 +1025,10 @@ class SystemLossPDE:
1405
1025
  if batch.param_batch_dict is not None:
1406
1026
  eq_params_batch_dict = batch.param_batch_dict
1407
1027
 
1028
+ # TODO
1408
1029
  # feed the eq_params with the batch
1409
1030
  for k in eq_params_batch_dict.keys():
1410
- params_dict["eq_params"][k] = eq_params_batch_dict[k]
1031
+ params_dict.eq_params[k] = eq_params_batch_dict[k]
1411
1032
 
1412
1033
  vmap_in_axes_params = _get_vmap_in_axes_params(
1413
1034
  batch.param_batch_dict, params_dict
@@ -1415,12 +1036,11 @@ class SystemLossPDE:
1415
1036
 
1416
1037
  def dyn_loss_for_one_key(dyn_loss, derivative_key, loss_weight):
1417
1038
  """The function used in tree_map"""
1418
- params_dict_ = _set_derivatives(params_dict, "dyn_loss", derivative_key)
1419
1039
  return dynamic_loss_apply(
1420
1040
  dyn_loss.evaluate,
1421
1041
  self.u_dict,
1422
1042
  batches,
1423
- params_dict_,
1043
+ _set_derivatives(params_dict, derivative_key.dyn_loss),
1424
1044
  vmap_in_axes_x_or_x_t + vmap_in_axes_params,
1425
1045
  loss_weight,
1426
1046
  u_type=type(list(self.u_dict.values())[0]),
@@ -1431,6 +1051,13 @@ class SystemLossPDE:
1431
1051
  self.dynamic_loss_dict,
1432
1052
  self.derivative_keys_dyn_loss_dict,
1433
1053
  self._loss_weights["dyn_loss"],
1054
+ is_leaf=lambda x: isinstance(
1055
+ x, (PDEStatio, PDENonStatio)
1056
+ ), # before when dynamic losses
1057
+ # where plain (unregister pytree) node classes, we could not traverse
1058
+ # this level. Now that dynamic losses are eqx.Module they can be
1059
+ # traversed by tree map recursion. Hence we need to specify to that
1060
+ # we want to stop at this level
1434
1061
  )
1435
1062
  mse_dyn_loss = jax.tree_util.tree_reduce(
1436
1063
  lambda x, y: x + y, jax.tree_util.tree_leaves(dyn_loss_mse_dict)
@@ -1445,11 +1072,10 @@ class SystemLossPDE:
1445
1072
  "boundary_loss": "*",
1446
1073
  "observations": "*",
1447
1074
  "initial_condition": "*",
1448
- "sobolev": "*",
1449
1075
  }
1450
1076
  # we need to do the following for the tree_mapping to work
1451
1077
  if batch.obs_batch_dict is None:
1452
- batch = batch._replace(obs_batch_dict=self.u_dict_with_none)
1078
+ batch = append_obs_batch(batch, self.u_dict_with_none)
1453
1079
  total_loss, res_dict = constraints_system_loss_apply(
1454
1080
  self.u_constraints_dict,
1455
1081
  batch,
@@ -1462,41 +1088,3 @@ class SystemLossPDE:
1462
1088
  total_loss += mse_dyn_loss
1463
1089
  res_dict["dyn_loss"] += mse_dyn_loss
1464
1090
  return total_loss, res_dict
1465
-
1466
- def tree_flatten(self):
1467
- children = (
1468
- self.norm_key_dict,
1469
- self.norm_samples_dict,
1470
- self.initial_condition_fun_dict,
1471
- self._loss_weights,
1472
- )
1473
- aux_data = {
1474
- "u_dict": self.u_dict,
1475
- "dynamic_loss_dict": self.dynamic_loss_dict,
1476
- "norm_borders_dict": self.norm_borders_dict,
1477
- "omega_boundary_fun_dict": self.omega_boundary_fun_dict,
1478
- "omega_boundary_condition_dict": self.omega_boundary_condition_dict,
1479
- "nn_type_dict": self.nn_type_dict,
1480
- "sobolev_m_dict": self.sobolev_m_dict,
1481
- "derivative_keys_dict": self.derivative_keys_dict,
1482
- "obs_slice_dict": self.obs_slice_dict,
1483
- }
1484
- return (children, aux_data)
1485
-
1486
- @classmethod
1487
- def tree_unflatten(cls, aux_data, children):
1488
- (
1489
- norm_key_dict,
1490
- norm_samples_dict,
1491
- initial_condition_fun_dict,
1492
- loss_weights,
1493
- ) = children
1494
- loss_ode = cls(
1495
- loss_weights=loss_weights,
1496
- norm_key_dict=norm_key_dict,
1497
- norm_samples_dict=norm_samples_dict,
1498
- initial_condition_fun_dict=initial_condition_fun_dict,
1499
- **aux_data,
1500
- )
1501
-
1502
- return loss_ode