jinns 0.9.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +321 -168
  9. jinns/loss/_LossODE.py +292 -309
  10. jinns/loss/_LossPDE.py +625 -1010
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +95 -44
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +94 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/METADATA +4 -4
  33. jinns-1.0.0.dist-info/RECORD +38 -0
  34. {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
  35. jinns/experimental/_sinuspinn.py +0 -135
  36. jinns/experimental/_spectralpinn.py +0 -87
  37. jinns/solver/_seq2seq.py +0 -157
  38. jinns/utils/_optim.py +0 -147
  39. jinns/utils/_utils_uspinn.py +0 -727
  40. jinns-0.9.0.dist-info/RECORD +0 -36
  41. {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
  42. {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
@@ -2,12 +2,21 @@
2
2
  Implements several dynamic losses
3
3
  """
4
4
 
5
+ from __future__ import (
6
+ annotations,
7
+ ) # https://docs.python.org/3/library/typing.html#constant
8
+
9
+ from typing import TYPE_CHECKING, Dict
10
+ from jaxtyping import Float
5
11
  import jax
6
- from jax import grad, jacrev
12
+ from jax import grad
7
13
  import jax.numpy as jnp
8
- from jinns.utils._utils import _get_grid, _extract_nn_params
14
+ import equinox as eqx
15
+
9
16
  from jinns.utils._pinn import PINN
10
17
  from jinns.utils._spinn import SPINN
18
+
19
+ from jinns.utils._utils import _get_grid
11
20
  from jinns.loss._DynamicLossAbstract import ODE, PDEStatio, PDENonStatio
12
21
  from jinns.loss._operators import (
13
22
  _laplacian_rev,
@@ -19,52 +28,44 @@ from jinns.loss._operators import (
19
28
  _u_dot_nabla_times_u_fwd,
20
29
  )
21
30
 
31
+ from jaxtyping import Array, Float
32
+
33
+ if TYPE_CHECKING:
34
+ from jinns.parameters import Params, ParamsDict
35
+
22
36
 
23
37
  class FisherKPP(PDENonStatio):
24
38
  r"""
25
- Return the Fisher KPP dynamic loss term. Dimension of :math:`x` can be
39
+ Return the Fisher KPP dynamic loss term. Dimension of $x$ can be
26
40
  arbitrary
27
41
 
28
- .. math::
29
- \frac{\partial}{\partial t} u(t,x)=D\Delta u(t,x) + u(t,x)(r(x) - \gamma(x)u(t,x))
30
-
42
+ $$
43
+ \frac{\partial}{\partial t} u(t,x)=D\Delta u(t,x) + u(t,x)(r(x) - \gamma(x)u(t,x))
44
+ $$
31
45
  """
32
46
 
33
- def __init__(self, Tmax=1, eq_params_heterogeneity=None):
34
- """
35
- Parameters
36
- ----------
37
- Tmax
38
- Tmax needs to be given when the PINN time input is normalized in
39
- [0, 1], ie. we have performed renormalization of the differential
40
- equation
41
- eq_params_heterogeneity
42
- Default None. A dict with the keys being the same as in eq_params
43
- and the value being `time`, `space`, `both` or None which corresponds to
44
- the heterogeneity of a given parameter. A value can be missing, in
45
- this case there is no heterogeneity (=None). If
46
- eq_params_heterogeneity is None this means there is no
47
- heterogeneity for no parameters.
48
- """
49
- super().__init__(Tmax, eq_params_heterogeneity)
50
-
51
- @PDENonStatio.evaluate_heterogeneous_parameters
52
- def evaluate(self, t, x, u, params):
47
+ def equation(
48
+ self,
49
+ t: Float[Array, "1"],
50
+ x: Float[Array, "dim"],
51
+ u: eqx.Module,
52
+ params: Params,
53
+ ) -> Float[Array, "1"]:
53
54
  r"""
54
- Evaluate the dynamic loss at :math:`(t,x)`.
55
+ Evaluate the dynamic loss at $(t,x)$.
55
56
 
56
57
  Parameters
57
58
  ---------
58
59
  t
59
- A time point
60
+ A time point.
60
61
  x
61
- A point in :math:`\Omega`
62
+ A point in $\Omega$.
62
63
  u
63
64
  The PINN
64
65
  params
65
66
  The dictionary of parameters of the model.
66
67
  Typically, it is a dictionary of
67
- dictionaries: `eq_params` and `nn_params``, respectively the
68
+ dictionaries: `eq_params` and `nn_params`, respectively the
68
69
  differential equation parameters and the neural network parameter
69
70
  """
70
71
  if isinstance(u, PINN):
@@ -76,12 +77,9 @@ class FisherKPP(PDENonStatio):
76
77
  lap = _laplacian_rev(t, x, u, params)[..., None]
77
78
 
78
79
  return du_dt + self.Tmax * (
79
- -params["eq_params"]["D"] * lap
80
+ -params.eq_params["D"] * lap
80
81
  - u(t, x, params)
81
- * (
82
- params["eq_params"]["r"]
83
- - params["eq_params"]["g"] * u(t, x, params)
84
- )
82
+ * (params.eq_params["r"] - params.eq_params["g"] * u(t, x, params))
85
83
  )
86
84
  if isinstance(u, SPINN):
87
85
  u_tx, du_dt = jax.jvp(
@@ -91,49 +89,129 @@ class FisherKPP(PDENonStatio):
91
89
  )
92
90
  lap = _laplacian_fwd(t, x, u, params)[..., None]
93
91
  return du_dt + self.Tmax * (
94
- -params["eq_params"]["D"] * lap
92
+ -params.eq_params["D"] * lap
95
93
  - u_tx
96
- * (
97
- params["eq_params"]["r"][..., None]
98
- - params["eq_params"]["g"] * u_tx
99
- )
94
+ * (params.eq_params["r"][..., None] - params.eq_params["g"] * u_tx)
100
95
  )
101
96
  raise ValueError("u is not among the recognized types (PINN or SPINN)")
102
97
 
103
98
 
104
- class BurgerEquation(PDENonStatio):
99
+ class GeneralizedLotkaVolterra(ODE):
105
100
  r"""
106
- Return the Burger dynamic loss term (in 1 space dimension):
107
-
108
- .. math::
109
- \frac{\partial}{\partial t} u(t,x) + u(t,x)\frac{\partial}{\partial x}
110
- u(t,x) - \theta \frac{\partial^2}{\partial x^2} u(t,x) = 0
101
+ Return a dynamic loss from an equation of a Generalized Lotka Volterra
102
+ system. Say we implement the equation for population $i$
111
103
 
104
+ $$
105
+ \frac{\partial}{\partial t}u_i(t) = r_iu_i(t) - \sum_{j\neq i}\alpha_{ij}u_j(t)
106
+ -\alpha_{i,i}u_i(t) + c_iu_i(t) + \sum_{j \neq i} c_ju_j(t)
107
+ $$
108
+ with $r_i$ the growth rate parameter, $c_i$ the carrying
109
+ capacities and $\alpha_{ij}$ the interaction terms.
110
+
111
+ Parameters
112
+ ----------
113
+ key_main
114
+ The dictionary key (in the dictionaries `u` and `params` that
115
+ are arguments of the `evaluate` function) of the main population
116
+ $i$ of the particular equation of the system implemented
117
+ by this dynamic loss
118
+ keys_other
119
+ The list of dictionary keys (in the dictionaries `u` and `params` that
120
+ are arguments of the `evaluate` function) of the other
121
+ populations that appear in the equation of the system implemented
122
+ by this dynamic loss
123
+ Tmax
124
+ Tmax needs to be given when the PINN time input is normalized in
125
+ $[0, 1]$, ie. we have performed renormalization of the differential
126
+ equation.
127
+ eq_params_heterogeneity
128
+ Default None. A dict with the keys being the same as in eq_params
129
+ and the value being `time`, `space`, `both` or None which corresponds to
130
+ the heterogeneity of a given parameter. A value can be missing, in
131
+ this case there is no heterogeneity (=None). If
132
+ eq_params_heterogeneity is None this means there is no
133
+ heterogeneity for no parameters.
112
134
  """
113
135
 
114
- def __init__(
136
+ # they should be static because they are list of strings
137
+ key_main: list[str] = eqx.field(static=True)
138
+ keys_other: list[str] = eqx.field(static=True)
139
+
140
+ def equation(
115
141
  self,
116
- Tmax=1,
117
- eq_params_heterogeneity=None,
118
- ):
142
+ t: Float[Array, "1"],
143
+ u_dict: Dict[str, eqx.Module],
144
+ params_dict: ParamsDict,
145
+ ) -> Float[Array, "1"]:
119
146
  """
147
+ Evaluate the dynamic loss at `t`.
148
+ For stability we implement the dynamic loss in log space.
149
+
120
150
  Parameters
121
- ----------
122
- Tmax
123
- Tmax needs to be given when the PINN time input is normalized in
124
- [0, 1], ie. we have performed renormalization of the differential
125
- equation
126
- eq_params_heterogeneity
127
- Default None. A dict with the keys being the same as in eq_params
128
- and the value being `time`, `space`, `both` or None which corresponds to
129
- the heterogeneity of a given parameter. A value can be missing, in
130
- this case there is no heterogeneity (=None). If
131
- eq_params_heterogeneity is None this means there is no
132
- heterogeneity for no parameters.
151
+ ---------
152
+ t
153
+ A time point
154
+ u_dict
155
+ A dictionary of PINNS. Must have the same keys as `params_dict`
156
+ params_dict
157
+ The dictionary of dictionaries of parameters of the model. Keys at
158
+ top level are "nn_params" and "eq_params"
133
159
  """
134
- super().__init__(Tmax, eq_params_heterogeneity)
160
+ params_main = params_dict.extract_params(self.key_main)
161
+
162
+ u = u_dict[self.key_main]
163
+ # need to index with [0] since u output is nec (1,)
164
+ du_dt = grad(lambda t: jnp.log(u(t, params_main)[0]), 0)(t)
165
+ carrying_term = params_main.eq_params["carrying_capacity"] * u(t, params_main)
166
+ # NOTE the following assumes interaction term with oneself is at idx 0
167
+ interaction_terms = params_main.eq_params["interactions"][0] * u(t, params_main)
168
+
169
+ # TODO write this for loop with tree_util functions?
170
+ for i, k in enumerate(self.keys_other):
171
+ params_k = params_dict.extract_params(k)
172
+ carrying_term += params_main.eq_params["carrying_capacity"] * u_dict[k](
173
+ t, params_k
174
+ )
175
+ interaction_terms += params_main.eq_params["interactions"][i + 1] * u_dict[
176
+ k
177
+ ](t, params_k)
178
+
179
+ return du_dt + self.Tmax * (
180
+ -params_main.eq_params["growth_rate"] - interaction_terms + carrying_term
181
+ )
182
+
183
+
184
+ class BurgerEquation(PDENonStatio):
185
+ r"""
186
+ Return the Burger dynamic loss term (in 1 space dimension):
135
187
 
136
- def evaluate(self, t, x, u, params):
188
+ $$
189
+ \frac{\partial}{\partial t} u(t,x) + u(t,x)\frac{\partial}{\partial x}
190
+ u(t,x) - \theta \frac{\partial^2}{\partial x^2} u(t,x) = 0
191
+ $$
192
+
193
+ Parameters
194
+ ----------
195
+ Tmax
196
+ Tmax needs to be given when the PINN time input is normalized in
197
+ [0, 1], ie. we have performed renormalization of the differential
198
+ equation
199
+ eq_params_heterogeneity
200
+ Default None. A dict with the keys being the same as in eq_params
201
+ and the value being `time`, `space`, `both` or None which corresponds to
202
+ the heterogeneity of a given parameter. A value can be missing, in
203
+ this case there is no heterogeneity (=None). If
204
+ eq_params_heterogeneity is None this means there is no
205
+ heterogeneity for no parameters.
206
+ """
207
+
208
+ def equation(
209
+ self,
210
+ t: Float[Array, "1"],
211
+ x: Float[Array, "dim"],
212
+ u: eqx.Module,
213
+ params: Params,
214
+ ) -> Float[Array, "1"]:
137
215
  r"""
138
216
  Evaluate the dynamic loss at :math:`(t,x)`.
139
217
 
@@ -142,14 +220,11 @@ class BurgerEquation(PDENonStatio):
142
220
  t
143
221
  A time point
144
222
  x
145
- A point in :math:`\Omega`
223
+ A point in $\Omega$
146
224
  u
147
225
  The PINN
148
226
  params
149
227
  The dictionary of parameters of the model.
150
- Typically, it is a dictionary of
151
- dictionaries: `eq_params` and `nn_params``, respectively the
152
- differential equation parameters and the neural network parameter
153
228
  """
154
229
  if isinstance(u, PINN):
155
230
  # Note that the last dim of u is nec. 1
@@ -162,8 +237,7 @@ class BurgerEquation(PDENonStatio):
162
237
  )
163
238
 
164
239
  return du_dt(t, x) + self.Tmax * (
165
- u(t, x, params) * du_dx(t, x)
166
- - params["eq_params"]["nu"] * d2u_dx2(t, x)
240
+ u(t, x, params) * du_dx(t, x) - params.eq_params["nu"] * d2u_dx2(t, x)
167
241
  )
168
242
 
169
243
  if isinstance(u, SPINN):
@@ -181,161 +255,68 @@ class BurgerEquation(PDENonStatio):
181
255
  )[1]
182
256
  du_dx, d2u_dx2 = jax.jvp(du_dx_fun, (x,), (jnp.ones_like(x),))
183
257
  # Note that ones_like(x) works because x is Bx1 !
184
- return du_dt + self.Tmax * (
185
- u_tx * du_dx - params["eq_params"]["nu"] * d2u_dx2
186
- )
258
+ return du_dt + self.Tmax * (u_tx * du_dx - params.eq_params["nu"] * d2u_dx2)
187
259
  raise ValueError("u is not among the recognized types (PINN or SPINN)")
188
260
 
189
261
 
190
- class GeneralizedLotkaVolterra(ODE):
191
- r"""
192
- Return a dynamic loss from an equation of a Generalized Lotka Volterra
193
- system. Say we implement the equation for population :math:`i`:
194
-
195
- .. math::
196
- \frac{\partial}{\partial t}u_i(t) = r_iu_i(t) - \sum_{j\neq i}\alpha_{ij}u_j(t)
197
- -\alpha_{i,i}u_i(t) + c_iu_i(t) + \sum_{j \neq i} c_ju_j(t)
198
-
199
- with :math:`r_i` the growth rate parameter, :math:`c_i` the carrying
200
- capacities and :math:`\alpha_{ij}` the interaction terms.
201
-
202
- """
203
-
204
- def __init__(
205
- self,
206
- key_main,
207
- keys_other,
208
- Tmax=1,
209
- eq_params_heterogeneity=None,
210
- ):
211
- """
212
- Parameters
213
- ----------
214
- key_main
215
- The dictionary key (in the dictionaries ``u`` and ``params`` that
216
- are arguments of the ``evaluate`` function) of the main population
217
- :math:`i` of the particular equation of the system implemented
218
- by this dynamic loss
219
- keys_other
220
- The list of dictionary keys (in the dictionaries ``u`` and ``params`` that
221
- are arguments of the ``evaluate`` function) of the other
222
- populations that appear in the equation of the system implemented
223
- by this dynamic loss
224
- Tmax
225
- Tmax needs to be given when the PINN time input is normalized in
226
- [0, 1], ie. we have performed renormalization of the differential
227
- equation
228
- eq_params_heterogeneity
229
- Default None. A dict with the keys being the same as in eq_params
230
- and the value being `time`, `space`, `both` or None which corresponds to
231
- the heterogeneity of a given parameter. A value can be missing, in
232
- this case there is no heterogeneity (=None). If
233
- eq_params_heterogeneity is None this means there is no
234
- heterogeneity for no parameters.
235
- """
236
- super().__init__(Tmax, eq_params_heterogeneity)
237
- self.key_main = key_main
238
- self.keys_other = keys_other
239
-
240
- def evaluate(self, t, u_dict, params_dict):
241
- """
242
- Evaluate the dynamic loss at `t`.
243
- For stability we implement the dynamic loss in log space.
244
-
245
- Parameters
246
- ---------
247
- t
248
- A time point
249
- u_dict
250
- A dictionary of PINNS. Must have the same keys as `params_dict`
251
- params_dict
252
- The dictionary of dictionaries of parameters of the model. Keys at
253
- top level are "nn_params" and "eq_params"
254
- """
255
- params_main = _extract_nn_params(params_dict, self.key_main)
256
-
257
- u = u_dict[self.key_main]
258
- # need to index with [0] since u output is nec (1,)
259
- du_dt = grad(lambda t: jnp.log(u(t, params_main)[0]), 0)(t)
260
- carrying_term = params_main["eq_params"]["carrying_capacity"] * u(
261
- t, params_main
262
- )
263
- # NOTE the following assumes interaction term with oneself is at idx 0
264
- interaction_terms = params_main["eq_params"]["interactions"][0] * u(
265
- t, params_main
266
- )
267
-
268
- # TODO write this for loop with tree_util functions?
269
- for i, k in enumerate(self.keys_other):
270
- params_k = _extract_nn_params(params_dict, k)
271
- carrying_term += params_main["eq_params"]["carrying_capacity"] * u_dict[k](
272
- t, params_k
273
- )
274
- interaction_terms += params_main["eq_params"]["interactions"][
275
- i + 1
276
- ] * u_dict[k](t, params_k)
277
-
278
- return du_dt + self.Tmax * (
279
- -params_main["eq_params"]["growth_rate"] - interaction_terms + carrying_term
280
- )
281
-
282
-
283
262
  class FPENonStatioLoss2D(PDENonStatio):
284
263
  r"""
285
264
  Return the dynamic loss for a non-stationary Fokker Planck Equation in two
286
265
  dimensions:
287
266
 
288
- .. math::
267
+ $$
289
268
  -\sum_{i=1}^2\frac{\partial}{\partial \mathbf{x}}
290
269
  \left[\mu(t, \mathbf{x})u(t, \mathbf{x})\right] +
291
270
  \sum_{i=1}^2\sum_{j=1}^2\frac{\partial^2}{\partial x_i \partial x_j}
292
271
  \left[D(t, \mathbf{x})u(t, \mathbf{x})\right]= \frac{\partial}
293
272
  {\partial t}u(t,\mathbf{x})
294
-
295
- where :math:`\mu(t, \mathbf{x})` is the drift term and :math:`D(t, \mathbf{x})` is the diffusion
296
- term.
273
+ $$
274
+ where $\mu(t, \mathbf{x})$ is the drift term and $D(t, \mathbf{x})$ is the
275
+ diffusion term.
297
276
 
298
277
  The drift and diffusion terms are not specified here, hence this class
299
278
  is `abstract`.
300
279
  Other classes inherit from FPENonStatioLoss2D and define the drift and
301
280
  diffusion terms, which then defines several other dynamic losses
302
281
  (Ornstein-Uhlenbeck, Cox-Ingersoll-Ross, ...)
303
- """
304
282
 
305
- def __init__(self, Tmax, eq_params_heterogeneity=None):
306
- """
307
- Parameters
308
- ----------
309
- Tmax
310
- Tmax needs to be given when the PINN time input is normalized in
311
- [0, 1], ie. we have performed renormalization of the differential
312
- equation
313
- eq_params_heterogeneity
314
- Default None. A dict with the keys being the same as in eq_params
315
- and the value being `time`, `space`, `both` or None which corresponds to
316
- the heterogeneity of a given parameter. A value can be missing, in
317
- this case there is no heterogeneity (=None). If
318
- eq_params_heterogeneity is None this means there is no
319
- heterogeneity for no parameters.
320
- """
321
- super().__init__(Tmax, eq_params_heterogeneity)
283
+ Parameters
284
+ ----------
285
+ Tmax
286
+ Tmax needs to be given when the PINN time input is normalized in
287
+ [0, 1], ie. we have performed renormalization of the differential
288
+ equation
289
+ eq_params_heterogeneity
290
+ Default None. A dict with the keys being the same as in eq_params
291
+ and the value being `time`, `space`, `both` or None which corresponds to
292
+ the heterogeneity of a given parameter. A value can be missing, in
293
+ this case there is no heterogeneity (=None). If
294
+ eq_params_heterogeneity is None this means there is no
295
+ heterogeneity for no parameters.
296
+ """
322
297
 
323
- def evaluate(self, t, x, u, params):
298
+ def equation(
299
+ self,
300
+ t: Float[Array, "1"],
301
+ x: Float[Array, "dim"],
302
+ u: eqx.Module,
303
+ params: Params,
304
+ ) -> Float[Array, "1"]:
324
305
  r"""
325
- Evaluate the dynamic loss at :math:`(t,\mathbf{x})`.
306
+ Evaluate the dynamic loss at $(t,\mathbf{x})$.
326
307
 
327
308
  Parameters
328
309
  ---------
329
310
  t
330
311
  A time point
331
312
  x
332
- A point in :math:`\Omega`
313
+ A point in $\Omega$
333
314
  u
334
315
  The PINN
335
316
  params
336
317
  The dictionary of parameters of the model.
337
318
  Typically, it is a dictionary of
338
- dictionaries: `eq_params` and `nn_params``, respectively the
319
+ dictionaries: `eq_params` and `nn_params`, respectively the
339
320
  differential equation parameters and the neural network parameter
340
321
  """
341
322
  if isinstance(u, PINN):
@@ -344,11 +325,13 @@ class FPENonStatioLoss2D(PDENonStatio):
344
325
 
345
326
  order_1 = (
346
327
  grad(
347
- lambda t, x: self.drift(t, x, params["eq_params"])[0] * u_(t, x),
328
+ lambda t, x: self.drift(t, x, params.eq_params)[0] * u_(t, x),
348
329
  1,
349
- )(t, x)[0:1]
330
+ )(
331
+ t, x
332
+ )[0:1]
350
333
  + grad(
351
- lambda t, x: self.drift(t, x, params["eq_params"])[1] * u_(t, x),
334
+ lambda t, x: self.drift(t, x, params.eq_params)[1] * u_(t, x),
352
335
  1,
353
336
  )(t, x)[1:2]
354
337
  )
@@ -357,7 +340,7 @@ class FPENonStatioLoss2D(PDENonStatio):
357
340
  grad(
358
341
  lambda t, x: grad(
359
342
  lambda t, x: u_(t, x)
360
- * self.diffusion(t, x, params["eq_params"])[0, 0],
343
+ * self.diffusion(t, x, params.eq_params)[0, 0],
361
344
  1,
362
345
  )(t, x)[0],
363
346
  1,
@@ -365,7 +348,7 @@ class FPENonStatioLoss2D(PDENonStatio):
365
348
  + grad(
366
349
  lambda t, x: grad(
367
350
  lambda t, x: u_(t, x)
368
- * self.diffusion(t, x, params["eq_params"])[1, 0],
351
+ * self.diffusion(t, x, params.eq_params)[1, 0],
369
352
  1,
370
353
  )(t, x)[1],
371
354
  1,
@@ -373,7 +356,7 @@ class FPENonStatioLoss2D(PDENonStatio):
373
356
  + grad(
374
357
  lambda t, x: grad(
375
358
  lambda t, x: u_(t, x)
376
- * self.diffusion(t, x, params["eq_params"])[0, 1],
359
+ * self.diffusion(t, x, params.eq_params)[0, 1],
377
360
  1,
378
361
  )(t, x)[0],
379
362
  1,
@@ -381,7 +364,7 @@ class FPENonStatioLoss2D(PDENonStatio):
381
364
  + grad(
382
365
  lambda t, x: grad(
383
366
  lambda t, x: u_(t, x)
384
- * self.diffusion(t, x, params["eq_params"])[1, 1],
367
+ * self.diffusion(t, x, params.eq_params)[1, 1],
385
368
  1,
386
369
  )(t, x)[1],
387
370
  1,
@@ -406,24 +389,20 @@ class FPENonStatioLoss2D(PDENonStatio):
406
389
  tangent_vec_0 = jnp.repeat(jnp.array([1.0, 0.0])[None], x.shape[0], axis=0)
407
390
  tangent_vec_1 = jnp.repeat(jnp.array([0.0, 1.0])[None], x.shape[0], axis=0)
408
391
  _, dau_dx1 = jax.jvp(
409
- lambda x: self.drift(t, _get_grid(x), params["eq_params"])[
410
- None, ..., 0:1
411
- ]
392
+ lambda x: self.drift(t, _get_grid(x), params.eq_params)[None, ..., 0:1]
412
393
  * u(t, x, params)[..., 0:1],
413
394
  (x,),
414
395
  (tangent_vec_0,),
415
396
  )
416
397
  _, dau_dx2 = jax.jvp(
417
- lambda x: self.drift(t, _get_grid(x), params["eq_params"])[
418
- None, ..., 1:2
419
- ]
398
+ lambda x: self.drift(t, _get_grid(x), params.eq_params)[None, ..., 1:2]
420
399
  * u(t, x, params)[..., 0:1],
421
400
  (x,),
422
401
  (tangent_vec_1,),
423
402
  )
424
403
 
425
404
  dsu_dx1_fun = lambda x, i, j: jax.jvp(
426
- lambda x: self.diffusion(t, _get_grid(x), params["eq_params"], i, j)[
405
+ lambda x: self.diffusion(t, _get_grid(x), params.eq_params, i, j)[
427
406
  None, None, None, None
428
407
  ]
429
408
  * u(t, x, params)[..., 0:1],
@@ -431,7 +410,7 @@ class FPENonStatioLoss2D(PDENonStatio):
431
410
  (tangent_vec_0,),
432
411
  )[1]
433
412
  dsu_dx2_fun = lambda x, i, j: jax.jvp(
434
- lambda x: self.diffusion(t, _get_grid(x), params["eq_params"], i, j)[
413
+ lambda x: self.diffusion(t, _get_grid(x), params.eq_params, i, j)[
435
414
  None, None, None, None
436
415
  ]
437
416
  * u(t, x, params)[..., 0:1],
@@ -459,11 +438,11 @@ class FPENonStatioLoss2D(PDENonStatio):
459
438
 
460
439
  def drift(self, *args, **kwargs):
461
440
  # To be implemented in child classes
462
- pass
441
+ raise NotImplementedError("Drift function should be implemented")
463
442
 
464
443
  def diffusion(self, *args, **kwargs):
465
444
  # To be implemented in child classes
466
- pass
445
+ raise NotImplementedError("Diffusion function should be implemented")
467
446
 
468
447
 
469
448
  class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
@@ -471,34 +450,30 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
471
450
  Return the dynamic loss for a stationary Fokker Planck Equation in two
472
451
  dimensions:
473
452
 
474
- .. math::
453
+ $$
475
454
  -\sum_{i=1}^2\frac{\partial}{\partial \mathbf{x}}
476
455
  \left[(\alpha(\mu - \mathbf{x}))u(t,\mathbf{x})\right] +
477
456
  \sum_{i=1}^2\sum_{j=1}^2\frac{\partial^2}{\partial x_i \partial x_j}
478
457
  \left[\frac{\sigma^2}{2}u(t,\mathbf{x})\right]=
479
458
  \frac{\partial}
480
459
  {\partial t}u(t,\mathbf{x})
481
-
460
+ $$
461
+
462
+ Parameters
463
+ ----------
464
+ Tmax
465
+ Tmax needs to be given when the PINN time input is normalized in
466
+ [0, 1], ie. we have performed renormalization of the differential
467
+ equation
468
+ eq_params_heterogeneity
469
+ Default None. A dict with the keys being the same as in eq_params
470
+ and the value being `time`, `space`, `both` or None which corresponds to
471
+ the heterogeneity of a given parameter. A value can be missing, in
472
+ this case there is no heterogeneity (=None). If
473
+ eq_params_heterogeneity is None this means there is no
474
+ heterogeneity for no parameters.
482
475
  """
483
476
 
484
- def __init__(self, Tmax=1, eq_params_heterogeneity=None):
485
- """
486
- Parameters
487
- ----------
488
- Tmax
489
- Tmax needs to be given when the PINN time input is normalized in
490
- [0, 1], ie. we have performed renormalization of the differential
491
- equation
492
- eq_params_heterogeneity
493
- Default None. A dict with the keys being the same as in eq_params
494
- and the value being `time`, `space`, `both` or None which corresponds to
495
- the heterogeneity of a given parameter. A value can be missing, in
496
- this case there is no heterogeneity (=None). If
497
- eq_params_heterogeneity is None this means there is no
498
- heterogeneity for no parameters.
499
- """
500
- super().__init__(Tmax, eq_params_heterogeneity)
501
-
502
477
  def drift(self, t, x, eq_params):
503
478
  r"""
504
479
  Return the drift term
@@ -508,7 +483,7 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
508
483
  t
509
484
  A time point
510
485
  x
511
- A point in :math:`\Omega`
486
+ A point in $\Omega$
512
487
  eq_params
513
488
  A dictionary containing the equation parameters
514
489
  """
@@ -524,7 +499,7 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
524
499
  t
525
500
  A time point
526
501
  x
527
- A point in :math:`\Omega`
502
+ A point in $\Omega$
528
503
  eq_params
529
504
  A dictionary containing the equation parameters
530
505
  """
@@ -541,7 +516,7 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
541
516
  t
542
517
  A time point
543
518
  x
544
- A point in :math:`\Omega`
519
+ A point in $\Omega$
545
520
  eq_params
546
521
  A dictionary containing the equation parameters
547
522
  """
@@ -564,33 +539,36 @@ class MassConservation2DStatio(PDEStatio):
564
539
  r"""
565
540
  Returns the so-called mass conservation equation.
566
541
 
567
- .. math::
542
+ $$
568
543
  \nabla \cdot \mathbf{u} = \frac{\partial}{\partial x}u(x,y) +
569
544
  \frac{\partial}{\partial y}u(x,y) = 0,
570
-
571
- where :math:`u` is a stationary function, i.e., it does not depend on
572
- :math:`t`.
545
+ $$
546
+ where $u$ is a stationary function, i.e., it does not depend on
547
+ $t$.
548
+
549
+ Parameters
550
+ ----------
551
+ nn_key
552
+ A dictionary key which identifies, in `u_dict`, the PINN that
553
+ appears in the mass conservation equation.
554
+ eq_params_heterogeneity
555
+ Default None. A dict with the keys being the same as in eq_params
556
+ and the value being `time`, `space`, `both` or None which corresponds to
557
+ the heterogeneity of a given parameter. A value can be missing, in
558
+ this case there is no heterogeneity (=None). If
559
+ eq_params_heterogeneity is None this means there is no
560
+ heterogeneity for no parameters.
573
561
  """
574
562
 
575
- def __init__(self, nn_key, eq_params_heterogeneity=None):
576
- """
577
- Parameters
578
- ----------
579
- nn_key
580
- A dictionary key which identifies, in `u_dict` the PINN that
581
- appears in the mass conservation equation.
582
- eq_params_heterogeneity
583
- Default None. A dict with the keys being the same as in eq_params
584
- and the value being `time`, `space`, `both` or None which corresponds to
585
- the heterogeneity of a given parameter. A value can be missing, in
586
- this case there is no heterogeneity (=None). If
587
- eq_params_heterogeneity is None this means there is no
588
- heterogeneity for no parameters.
589
- """
590
- self.nn_key = nn_key
591
- super().__init__(eq_params_heterogeneity)
563
+ # an str field should be static (not a valid JAX type)
564
+ nn_key: str = eqx.field(static=True)
592
565
 
593
- def evaluate(self, x, u_dict, params_dict):
566
+ def equation(
567
+ self,
568
+ x: Float[Array, "dim"],
569
+ u_dict: Dict[str, eqx.Module],
570
+ params_dict: ParamsDict,
571
+ ) -> Float[Array, "1"]:
594
572
  r"""
595
573
  Evaluate the dynamic loss at `\mathbf{x}`.
596
574
  For stability we implement the dynamic loss in log space.
@@ -598,17 +576,17 @@ class MassConservation2DStatio(PDEStatio):
598
576
  Parameters
599
577
  ---------
600
578
  x
601
- A point in :math:`\Omega\subset\mathbb{R}^2`
579
+ A point in $\Omega\subset\mathbb{R}^2$
602
580
  u_dict
603
581
  A dictionary of PINNs. Must have the same keys as `params_dict`
604
582
  params_dict
605
583
  The dictionary of dictionaries of parameters of the model.
606
584
  Typically, each sub-dictionary is a dictionary
607
- with keys: `eq_params` and `nn_params``, respectively the
585
+ with keys: `eq_params` and `nn_params`, respectively the
608
586
  differential equation parameters and the neural network parameter.
609
587
  Must have the same keys as `u_dict`
610
588
  """
611
- params = _extract_nn_params(params_dict, self.nn_key)
589
+ params = params_dict.extract_params(self.nn_key)
612
590
 
613
591
  if isinstance(u_dict[self.nn_key], PINN):
614
592
  u = u_dict[self.nn_key]
@@ -627,15 +605,15 @@ class NavierStokes2DStatio(PDEStatio):
627
605
  Return the dynamic loss for all the components of the stationary Navier Stokes
628
606
  equation which is a 2D vectorial PDE.
629
607
 
630
- .. math::
608
+ $$
631
609
  (\mathbf{u}\cdot\nabla)\mathbf{u} + \frac{1}{\rho}\nabla p - \theta
632
610
  \nabla^2\mathbf{u}=0,
633
-
611
+ $$
634
612
 
635
613
  or, in 2D,
636
614
 
637
615
 
638
- .. math::
616
+ $$
639
617
  \begin{pmatrix}u_x\frac{\partial}{\partial x} u_x + u_y\frac{\partial}{\partial y} u_x \\
640
618
  u_x\frac{\partial}{\partial x} u_y + u_y\frac{\partial}{\partial y} u_y \end{pmatrix} +
641
619
  \frac{1}{\rho} \begin{pmatrix} \frac{\partial}{\partial x} p \\ \frac{\partial}{\partial y} p \end{pmatrix}
@@ -644,61 +622,60 @@ class NavierStokes2DStatio(PDEStatio):
644
622
  \frac{\partial^2}{\partial x^2} u_x + \frac{\partial^2}{\partial y^2} u_x \\
645
623
  \frac{\partial^2}{\partial x^2} u_y + \frac{\partial^2}{\partial y^2} u_y
646
624
  \end{pmatrix} = 0,
647
-
625
+ $$
648
626
  with $\theta$ the viscosity coefficient and $\rho$ the density coefficient.
649
627
 
650
- **Note:** Note that the solution to the Navier Stokes equation is a vector
651
- field. Hence the MSE must concern all the axes.
628
+ Parameters
629
+ ----------
630
+ u_key
631
+ A dictionary key which indexes the NN u in `u_dict`:
632
+ the PINN with the role of the velocity in the equation.
633
+ Its input is two-dimensional (points in $\Omega\subset\mathbb{R}^2$).
634
+ Its output is two-dimensional as it represents a velocity vector
635
+ field
636
+ p_key
637
+ A dictionary key which indexes the NN p in `u_dict`:
638
+ the PINN with the role of the pressure in the equation.
639
+ Its input is two-dimensional (points in $\Omega\subset\mathbb{R}^2$).
640
+ Its output is unidimensional as it represents a pressure scalar
641
+ field
642
+ eq_params_heterogeneity
643
+ Default None. A dict with the keys being the same as in eq_params
644
+ and the value being `time`, `space`, `both` or None which corresponds to
645
+ the heterogeneity of a given parameter. A value can be missing, in
646
+ this case there is no heterogeneity (=None). If
647
+ eq_params_heterogeneity is None this means there is no
648
+ heterogeneity for any parameter.
652
649
  """
653
650
 
654
- def __init__(self, u_key, p_key, eq_params_heterogeneity=None):
655
- r"""
656
- Parameters
657
- ----------
658
- u_key
659
- A dictionary key which indices in `u_dict`
660
- the PINN with the role of the velocity in the equation.
661
- Its input is bimensional (points in :math:`\Omega\subset\mathbb{R}^2`).
662
- Its output is bimensional as it represents a velocity vector
663
- field
664
- p_key
665
- A dictionary key which indices in `u_dict`
666
- the PINN with the role of the pressure in the equation.
667
- Its input is bimensional (points in :math:`\Omega\subset\mathbb{R}^2`).
668
- Its output is unidimensional as it represents a pressure scalar
669
- field
670
- eq_params_heterogeneity
671
- Default None. A dict with the keys being the same as in eq_params
672
- and the value being `time`, `space`, `both` or None which corresponds to
673
- the heterogeneity of a given parameter. A value can be missing, in
674
- this case there is no heterogeneity (=None). If
675
- eq_params_heterogeneity is None this means there is no
676
- heterogeneity for no parameters.
677
- """
678
- self.u_key = u_key
679
- self.p_key = p_key
680
- super().__init__(eq_params_heterogeneity)
651
+ u_key: str = eqx.field(static=True)
652
+ p_key: str = eqx.field(static=True)
681
653
 
682
- def evaluate(self, x, u_dict, params_dict):
654
+ def equation(
655
+ self,
656
+ x: Float[Array, "dim"],
657
+ u_dict: Dict[str, eqx.Module],
658
+ params_dict: ParamsDict,
659
+ ) -> Float[Array, "1"]:
683
660
  r"""
684
- Evaluate the dynamic loss at `\mathbf{x}`.
661
+ Evaluate the dynamic loss at `x`.
685
662
  For stability we implement the dynamic loss in log space.
686
663
 
687
664
  Parameters
688
665
  ---------
689
666
  x
690
- A point in :math:`\Omega\subset\mathbb{R}^2`
667
+ A point in $\Omega\subset\mathbb{R}^2$
691
668
  u_dict
692
669
  A dictionary of PINNs. Must have the same keys as `params_dict`
693
670
  params_dict
694
671
  The dictionary of dictionaries of parameters of the model.
695
672
  Typically, each sub-dictionary is a dictionary
696
- with keys: `eq_params` and `nn_params``, respectively the
673
+ with keys: `eq_params` and `nn_params`, respectively the
697
674
  differential equation parameters and the neural network parameter.
698
675
  Must have the same keys as `u_dict`
699
676
  """
700
- u_params = _extract_nn_params(params_dict, self.u_key)
701
- p_params = _extract_nn_params(params_dict, self.p_key)
677
+ u_params = params_dict.extract_params(self.u_key)
678
+ p_params = params_dict.extract_params(self.p_key)
702
679
 
703
680
  if isinstance(u_dict[self.u_key], PINN):
704
681
  u = u_dict[self.u_key]
@@ -706,22 +683,22 @@ class NavierStokes2DStatio(PDEStatio):
706
683
  u_dot_nabla_x_u = _u_dot_nabla_times_u_rev(None, x, u, u_params)
707
684
 
708
685
  p = lambda x: u_dict[self.p_key](x, p_params)
709
- jac_p = jacrev(p, 0)(x) # compute the gradient
686
+ jac_p = jax.jacrev(p, 0)(x) # compute the gradient
710
687
 
711
688
  vec_laplacian_u = _vectorial_laplacian(None, x, u, u_params, u_vec_ndim=2)
712
689
 
713
690
  # dynamic loss on x axis
714
691
  result_x = (
715
692
  u_dot_nabla_x_u[0]
716
- + 1 / params_dict["eq_params"]["rho"] * jac_p[0, 0]
717
- - params_dict["eq_params"]["nu"] * vec_laplacian_u[0]
693
+ + 1 / params_dict.eq_params["rho"] * jac_p[0, 0]
694
+ - params_dict.eq_params["nu"] * vec_laplacian_u[0]
718
695
  )
719
696
 
720
697
  # dynamic loss on y axis
721
698
  result_y = (
722
699
  u_dot_nabla_x_u[1]
723
- + 1 / params_dict["eq_params"]["rho"] * jac_p[0, 1]
724
- - params_dict["eq_params"]["nu"] * vec_laplacian_u[1]
700
+ + 1 / params_dict.eq_params["rho"] * jac_p[0, 1]
701
+ - params_dict.eq_params["nu"] * vec_laplacian_u[1]
725
702
  )
726
703
 
727
704
  # output is 2D
@@ -748,14 +725,14 @@ class NavierStokes2DStatio(PDEStatio):
748
725
  # dynamic loss on x axis
749
726
  result_x = (
750
727
  u_dot_nabla_x_u[..., 0]
751
- + 1 / params_dict["eq_params"]["rho"] * dp_dx.squeeze()
752
- - params_dict["eq_params"]["nu"] * vec_laplacian_u[..., 0]
728
+ + 1 / params_dict.eq_params["rho"] * dp_dx.squeeze()
729
+ - params_dict.eq_params["nu"] * vec_laplacian_u[..., 0]
753
730
  )
754
731
  # dynamic loss on y axis
755
732
  result_y = (
756
733
  u_dot_nabla_x_u[..., 1]
757
- + 1 / params_dict["eq_params"]["rho"] * dp_dy.squeeze()
758
- - params_dict["eq_params"]["nu"] * vec_laplacian_u[..., 1]
734
+ + 1 / params_dict.eq_params["rho"] * dp_dy.squeeze()
735
+ - params_dict.eq_params["nu"] * vec_laplacian_u[..., 1]
759
736
  )
760
737
 
761
738
  # output is 2D