jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +904 -1203
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +322 -167
- jinns/loss/_LossODE.py +324 -322
- jinns/loss/_LossPDE.py +652 -1027
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +87 -41
- jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +521 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +183 -39
- jinns/solver/_solve.py +151 -124
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -44
- jinns/utils/_hyperpinn.py +224 -119
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +113 -86
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +48 -140
- jinns-1.1.0.dist-info/AUTHORS +2 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
- jinns-1.1.0.dist-info/RECORD +39 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.9.0.dist-info/RECORD +0 -36
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -2,219 +2,374 @@
|
|
|
2
2
|
Implements abstract classes for dynamic losses
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
+
from __future__ import (
|
|
6
|
+
annotations,
|
|
7
|
+
) # https://docs.python.org/3/library/typing.html#constant
|
|
5
8
|
|
|
6
|
-
|
|
9
|
+
import equinox as eqx
|
|
10
|
+
from typing import Callable, Dict, TYPE_CHECKING, ClassVar
|
|
11
|
+
from jaxtyping import Float, Array
|
|
12
|
+
from functools import partial
|
|
13
|
+
import abc
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# See : https://docs.kidger.site/equinox/api/module/advanced_fields/#equinox.AbstractClassVar--known-issues
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from typing import ClassVar as AbstractClassVar
|
|
19
|
+
from jinns.parameters import Params, ParamsDict
|
|
20
|
+
else:
|
|
21
|
+
from equinox import AbstractClassVar
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _decorator_heteregeneous_params(evaluate, eq_type):
|
|
25
|
+
|
|
26
|
+
def wrapper_ode(*args):
|
|
27
|
+
self, t, u, params = args
|
|
28
|
+
_params = eqx.tree_at(
|
|
29
|
+
lambda p: p.eq_params,
|
|
30
|
+
params,
|
|
31
|
+
self._eval_heterogeneous_parameters(
|
|
32
|
+
t, None, u, params, self.eq_params_heterogeneity
|
|
33
|
+
),
|
|
34
|
+
)
|
|
35
|
+
new_args = args[:-1] + (_params,)
|
|
36
|
+
res = evaluate(*new_args)
|
|
37
|
+
return res
|
|
38
|
+
|
|
39
|
+
def wrapper_pde_statio(*args):
|
|
40
|
+
self, x, u, params = args
|
|
41
|
+
_params = eqx.tree_at(
|
|
42
|
+
lambda p: p.eq_params,
|
|
43
|
+
params,
|
|
44
|
+
self._eval_heterogeneous_parameters(
|
|
45
|
+
None, x, u, params, self.eq_params_heterogeneity
|
|
46
|
+
),
|
|
47
|
+
)
|
|
48
|
+
new_args = args[:-1] + (_params,)
|
|
49
|
+
res = evaluate(*new_args)
|
|
50
|
+
return res
|
|
51
|
+
|
|
52
|
+
def wrapper_pde_non_statio(*args):
|
|
53
|
+
self, t, x, u, params = args
|
|
54
|
+
_params = eqx.tree_at(
|
|
55
|
+
lambda p: p.eq_params,
|
|
56
|
+
params,
|
|
57
|
+
self._eval_heterogeneous_parameters(
|
|
58
|
+
t, x, u, params, self.eq_params_heterogeneity
|
|
59
|
+
),
|
|
60
|
+
)
|
|
61
|
+
new_args = args[:-1] + (_params,)
|
|
62
|
+
res = evaluate(*new_args)
|
|
63
|
+
return res
|
|
64
|
+
|
|
65
|
+
if eq_type == "ODE":
|
|
66
|
+
return wrapper_ode
|
|
67
|
+
elif eq_type == "Statio PDE":
|
|
68
|
+
return wrapper_pde_statio
|
|
69
|
+
elif eq_type == "Non-statio PDE":
|
|
70
|
+
return wrapper_pde_non_statio
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class DynamicLoss(eqx.Module):
    r"""
    Abstract base class for dynamic losses. Implements the physical term:

    $$
    \mathcal{N}[u](t, x) = 0
    $$

    for **one** point $t$, $x$ or $(t, x)$, depending on the context.

    Parameters
    ----------
    Tmax : Float, default=1
        Tmax needs to be given when the PINN time input is normalized in
        [0, 1], ie. we have performed renormalization of the differential
        equation
    eq_params_heterogeneity : Dict[str, Callable | None], default=None
        A dict with the same keys as eq_params and the value being either None
        (no heterogeneity) or a function which encodes for the spatio-temporal
        heterogeneity of the parameter.
        Such a function must be jittable. It is called with the same point
        arguments as `equation`, followed by `u` and `params`:
        `(t, u, params)` for an ODE, `(x, u, params)` for a stationary PDE
        and `(t, x, u, params)` for a non-stationary PDE. Therefore,
        one can introduce spatio-temporal covariates upon which a particular
        parameter can depend, e.g. in a Generalized Linear Model fashion. The
        effect of these covariates can themselves be estimated by being in
        `eq_params` too.
        A key can be missing, in this case there is no heterogeneity (=None).
        Default None, meaning there is no heterogeneity in the equation
        parameters.
    """

    # Class variable denoting the type of differential equation. Concrete
    # subclasses override it with `_eq_type: ClassVar[str] = ...`.
    # NOTE(review): this is an assignment, not an annotation. Equinox
    # therefore does not enforce it as an abstract class variable, and an
    # unsubclassed value makes `_evaluate` raise NotImplementedError. The
    # annotation form may not be accepted by equinox under
    # `from __future__ import annotations`; confirm before changing.
    _eq_type = AbstractClassVar[str]
    Tmax: Float = eqx.field(kw_only=True, default=1)
    eq_params_heterogeneity: Dict[str, Callable | None] | None = eqx.field(
        kw_only=True, default=None, static=True
    )

    def _point_args(
        self, t: Float[Array, "1"] | None, x: Float[Array, "dim"] | None
    ) -> tuple:
        """
        Return the point arguments expected by `equation` (and by the
        heterogeneity functions) for this `_eq_type`: `(t,)` for an ODE,
        `(x,)` for a stationary PDE and `(t, x)` for a non-stationary PDE.

        Raises
        ------
        NotImplementedError
            If `_eq_type` is not one of the handled equation types.
        """
        if self._eq_type == "ODE":
            return (t,)
        if self._eq_type == "Statio PDE":
            return (x,)
        if self._eq_type == "Non-statio PDE":
            return (t, x)
        raise NotImplementedError("the equation type is not handled.")

    def _eval_heterogeneous_parameters(
        self,
        t: Float[Array, "1"],
        x: Float[Array, "dim"],
        u: eqx.Module,
        params: Params | ParamsDict,
        eq_params_heterogeneity: Dict[str, Callable | None] | None = None,
    ) -> Dict[str, float | Float[Array, "parameter_dimension"]]:
        """
        Compute the equation parameters at the current point.

        Parameters
        ----------
        t : Float[Array, "1"]
            The time point (None for a stationary PDE).
        x : Float[Array, "dim"]
            The spatial point (None for an ODE).
        u : eqx.Module
            The network.
        params : Params | ParamsDict
            The equation and network parameters.
        eq_params_heterogeneity : Dict[str, Callable | None] | None
            Per-parameter heterogeneity functions; see the class docstring.

        Returns
        -------
        Dict[str, float | Float[Array, "parameter_dimension"]]
            `params.eq_params` itself when `eq_params_heterogeneity` is None.
            Otherwise a new dict in which each parameter with a heterogeneity
            function is replaced by that function's output at the point.

        Raises
        ------
        NotImplementedError
            If a heterogeneity function has to be called and `_eq_type` is
            not handled.
        """
        if eq_params_heterogeneity is None:
            return params.eq_params
        eq_params_ = {}
        for k, p in params.eq_params.items():
            # Missing keys are authorized and mean "no heterogeneity". The
            # lookup is kept separate from the call so that a KeyError raised
            # inside a user heterogeneity function is not silently swallowed.
            heterogeneity_fn = eq_params_heterogeneity.get(k)
            if heterogeneity_fn is None:
                eq_params_[k] = p
            else:
                # The function's signature varies according to _eq_type.
                eq_params_[k] = heterogeneity_fn(
                    *self._point_args(t, x), u, params
                )
        return eq_params_

    def _evaluate(
        self,
        t: Float[Array, "1"],
        x: Float[Array, "dim"],
        u: eqx.Module,
        params: Params | ParamsDict,
    ) -> float:
        """
        Call `equation` with the signature matching `_eq_type`.

        Raises
        ------
        NotImplementedError
            If `_eq_type` is not one of the handled equation types.
        """
        return self.equation(*self._point_args(t, x), u, params)

    @abc.abstractmethod
    def equation(self, *args, **kwargs):
        """
        Point-wise evaluation of the differential operator N[u](.).

        To be implemented by subclasses.
        """
        raise NotImplementedError("You should implement your equation.")
|
|
165
|
+
|
|
62
166
|
|
|
63
167
|
class ODE(DynamicLoss):
    r"""
    Abstract base class for ODE dynamic losses. Every ODE dynamic loss must
    subclass this class and override the abstract method `equation`.

    Attributes
    ----------
    Tmax : float, default=1
        Tmax needs to be given when the PINN time input is normalized in
        [0, 1], ie. we have performed renormalization of the differential
        equation
    eq_params_heterogeneity : Dict[str, Callable | None], default=None
        Default None. A dict with the keys being the same as in eq_params
        and the value being either None (no heterogeneity) or a function
        which encodes for the temporal heterogeneity of the parameter.
        Such a function must be jittable and take three arguments `t`,
        `u` and `params` even if some are not used. Therefore,
        one can introduce temporal covariates upon which a particular
        parameter can depend, e.g. in a GLM fashion. The effect of these
        covariates can themselves be estimated by being in `eq_params` too.
        Some keys can be missing, in this case there is no heterogeneity
        (=None). If eq_params_heterogeneity is None, no parameter is
        heterogeneous.
    """

    _eq_type: ClassVar[str] = "ODE"

    # Reuse the class variable so the decorator and `_eq_type` cannot drift.
    @partial(_decorator_heteregeneous_params, eq_type=_eq_type)
    def evaluate(
        self,
        t: Float[Array, "1"],
        u: eqx.Module | Dict[str, eqx.Module],
        params: Params | ParamsDict,
    ) -> float:
        """
        Evaluate the residual at time `t`.

        The decorator first replaces `params.eq_params` with the (possibly
        heterogeneous) parameter values. Then `DynamicLoss._evaluate` is
        called with `x=None`.
        """
        return self._evaluate(t, None, u, params)

    @abc.abstractmethod
    def equation(
        self, t: Float[Array, "1"], u: eqx.Module, params: Params | ParamsDict
    ) -> float:
        r"""
        The differential operator defining the ODE.

        !!! warning

            This is an abstract method to be implemented by users.

        Parameters
        ----------
        t : Float[Array, "1"]
            A 1-dimensional jnp.array representing the time point.
        u : eqx.Module
            The network with a call signature `u(t, params)`.
        params : Params | ParamsDict
            The equation and neural network parameters $\theta$ and $\nu$.

        Returns
        -------
        float
            The residual, *i.e.* the differential operator $\mathcal{N}_\theta[u_\nu](t)$ evaluated at point `t`.

        Raises
        ------
        NotImplementedError
            This is an abstract method to be implemented.
        """
        raise NotImplementedError
|
|
115
235
|
|
|
116
236
|
|
|
117
237
|
class PDEStatio(DynamicLoss):
    r"""
    Abstract base class for stationary PDE dynamic losses. Every stationary
    PDE dynamic loss must subclass this class and override the abstract
    method `equation`.

    Attributes
    ----------
    Tmax : float, default=1
        Tmax needs to be given when the PINN time input is normalized in
        [0, 1], ie. we have performed renormalization of the differential
        equation
    eq_params_heterogeneity : Dict[str, Callable | None], default=None
        Default None. A dict with the keys being the same as in eq_params
        and the value being either None (no heterogeneity) or a function
        which encodes for the spatial heterogeneity of the parameter.
        Such a function must be jittable and take three arguments `x`,
        `u` and `params` even if some are not used. Therefore,
        one can introduce spatial covariates upon which a particular
        parameter can depend, e.g. in a GLM fashion. The effect of these
        covariates can themselves be estimated by being in `eq_params` too.
        Some keys can be missing, in this case there is no heterogeneity
        (=None). If eq_params_heterogeneity is None, no parameter is
        heterogeneous.
    """

    _eq_type: ClassVar[str] = "Statio PDE"

    # Reuse the class variable so the decorator and `_eq_type` cannot drift.
    @partial(_decorator_heteregeneous_params, eq_type=_eq_type)
    def evaluate(
        self, x: Float[Array, "dim"], u: eqx.Module, params: Params | ParamsDict
    ) -> float:
        """
        Evaluate the residual at the spatial point `x`.

        The decorator first replaces `params.eq_params` with the (possibly
        heterogeneous) parameter values. Then `DynamicLoss._evaluate` is
        called with `t=None`.
        """
        return self._evaluate(None, x, u, params)

    @abc.abstractmethod
    def equation(
        self, x: Float[Array, "dim"], u: eqx.Module, params: Params | ParamsDict
    ) -> float:
        r"""The differential operator defining the stationary PDE.

        !!! warning

            This is an abstract method to be implemented by users.

        Parameters
        ----------
        x : Float[Array, "dim"]
            A `dim` dimensional jnp.array representing a point in the spatial domain $\Omega$.
        u : eqx.Module
            The neural network.
        params : Params | ParamsDict
            The parameters of the equation and the networks, $\theta$ and $\nu$ respectively.

        Returns
        -------
        float
            The residual, *i.e.* the differential operator $\mathcal{N}_\theta[u_\nu](x)$ evaluated at point `x`.

        Raises
        ------
        NotImplementedError
            This is an abstract method to be implemented.
        """
        raise NotImplementedError
|
|
165
300
|
|
|
166
301
|
|
|
167
302
|
class PDENonStatio(DynamicLoss):
    r"""
    Abstract base class for non-stationary PDE dynamic losses. Every
    non-stationary PDE dynamic loss must subclass this class and override the
    abstract method `equation`.

    Attributes
    ----------
    Tmax : float, default=1
        Tmax needs to be given when the PINN time input is normalized in
        [0, 1], ie. we have performed renormalization of the differential
        equation
    eq_params_heterogeneity : Dict[str, Callable | None], default=None
        Default None. A dict with the keys being the same as in eq_params
        and the value being either None (no heterogeneity) or a function
        which encodes for the spatio-temporal heterogeneity of the parameter.
        Such a function must be jittable and take four arguments `t`, `x`,
        `u` and `params` even if some are not used. Therefore,
        one can introduce spatio-temporal covariates upon which a particular
        parameter can depend, e.g. in a GLM fashion. The effect of these
        covariates can themselves be estimated by being in `eq_params` too.
        Some keys can be missing, in this case there is no heterogeneity
        (=None). If eq_params_heterogeneity is None, no parameter is
        heterogeneous.
    """

    _eq_type: ClassVar[str] = "Non-statio PDE"

    # Reuse the class variable so the decorator and `_eq_type` cannot drift.
    @partial(_decorator_heteregeneous_params, eq_type=_eq_type)
    def evaluate(
        self,
        t: Float[Array, "1"],
        x: Float[Array, "dim"],
        u: eqx.Module,
        params: Params | ParamsDict,
    ) -> float:
        """
        Evaluate the residual at the space-time point `(t, x)`.

        The decorator first replaces `params.eq_params` with the (possibly
        heterogeneous) parameter values. Then `DynamicLoss._evaluate` is
        called with the full arguments.
        """
        return self._evaluate(t, x, u, params)

    @abc.abstractmethod
    def equation(
        self,
        t: Float[Array, "1"],
        x: Float[Array, "dim"],
        u: eqx.Module,
        params: Params | ParamsDict,
    ) -> float:
        r"""The differential operator defining the non-stationary PDE.

        !!! warning

            This is an abstract method to be implemented by users.

        Parameters
        ----------
        t : Float[Array, "1"]
            A 1-dimensional jnp.array representing the time point.
        x : Float[Array, "dim"]
            A `dim` dimensional jnp.array representing a point in the spatial domain $\Omega$.
        u : eqx.Module
            The neural network.
        params : Params | ParamsDict
            The parameters of the equation and the networks, $\theta$ and $\nu$ respectively.

        Returns
        -------
        float
            The residual, *i.e.* the differential operator $\mathcal{N}_\theta[u_\nu](t, x)$ evaluated at point `(t, x)`.

        Raises
        ------
        NotImplementedError
            This is an abstract method to be implemented.
        """
        raise NotImplementedError
|