jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +904 -1203
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +322 -167
- jinns/loss/_LossODE.py +324 -322
- jinns/loss/_LossPDE.py +652 -1027
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +87 -41
- jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +521 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +183 -39
- jinns/solver/_solve.py +151 -124
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -44
- jinns/utils/_hyperpinn.py +224 -119
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +113 -86
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +48 -140
- jinns-1.1.0.dist-info/AUTHORS +2 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
- jinns-1.1.0.dist-info/RECORD +39 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.9.0.dist-info/RECORD +0 -36
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0

jinns/parameters/_derivative_keys.py (new file)

@@ -0,0 +1,521 @@
"""
Formalize the data structure for the derivative keys
"""

from functools import partial
from dataclasses import fields, InitVar
from typing import Literal
import jax
import equinox as eqx

from jinns.parameters._params import Params, ParamsDict


def _get_masked_parameters(
    derivative_mask_str: str, params: Params | ParamsDict
) -> Params | ParamsDict:
    """
    Creates the Params object with True values where we want to differentiate
    """
    if isinstance(params, Params):
        # start with a params object with True everywhere. We will update to False
        # for parameters wrt which we do not want to differentiate the loss
        diff_params = jax.tree.map(
            lambda x: True,
            params,
            is_leaf=lambda x: isinstance(x, eqx.Module)
            and not isinstance(x, Params),  # do not traverse nn_params, more
            # granularity could be imagined here, in the future
        )
        if derivative_mask_str == "both":
            return diff_params
        if derivative_mask_str == "eq_params":
            return eqx.tree_at(lambda p: p.nn_params, diff_params, False)
        if derivative_mask_str == "nn_params":
            return eqx.tree_at(
                lambda p: p.eq_params,
                diff_params,
                jax.tree.map(lambda x: False, params.eq_params),
            )
        raise ValueError(
            "Bad value for DerivativeKeys. Got "
            f'{derivative_mask_str}, expected "both", "nn_params" or '
            ' "eq_params"'
        )
    elif isinstance(params, ParamsDict):
        # do not traverse nn_params, more
        # granularity could be imagined here, in the future
        diff_params = ParamsDict(
            nn_params=True, eq_params=jax.tree.map(lambda x: True, params.eq_params)
        )
        if derivative_mask_str == "both":
            return diff_params
        if derivative_mask_str == "eq_params":
            return eqx.tree_at(lambda p: p.nn_params, diff_params, False)
        if derivative_mask_str == "nn_params":
            return eqx.tree_at(
                lambda p: p.eq_params,
                diff_params,
                jax.tree.map(lambda x: False, params.eq_params),
            )
        raise ValueError(
            "Bad value for DerivativeKeys. Got "
            f'{derivative_mask_str}, expected "both", "nn_params" or '
            ' "eq_params"'
        )

    else:
        raise ValueError(
            f"Bad value for params. Got {type(params)}, expected Params "
            " or ParamsDict"
        )
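
To make the masking concrete, here is an illustrative sketch (not part of the released code) of what this private helper returns for a toy `Params`. The toy values, and importing the helper straight from the private module, are assumptions made only for illustration.

```python
import jax.numpy as jnp
from jinns.parameters._params import Params
from jinns.parameters._derivative_keys import _get_masked_parameters

params = Params(
    nn_params=jnp.ones((2,)),             # stand-in for the PINN parameters
    eq_params={"alpha": jnp.array(1.0)},  # one equation parameter
)
mask = _get_masked_parameters("nn_params", params)
# mask == Params(nn_params=True, eq_params={"alpha": False}):
# only the network parameters will receive gradients for this loss term.
```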

class DerivativeKeysODE(eqx.Module):
    """
    A class that specifies with respect to which parameter(s) each term of the
    loss is differentiated. For example, you can specify that the
    [`DynamicLoss`][jinns.loss.DynamicLoss] should be differentiated both with
    respect to the neural network parameters *and* the equation parameters,
    or only some of them.

    To do so, the user can either use strings or a `Params` object
    with PyTree structure matching the parameters of the problem at
    hand, and booleans indicating if gradient is to be taken or not. Internally,
    a `jax.lax.stop_gradient()` is appropriately set to each `True` node when
    computing each loss term.

    !!! note

        1. For unspecified loss terms, the default is to differentiate with
           respect to `"nn_params"` only.
        2. No granularity inside `Params.nn_params` is currently supported.
        3. The main Params or ParamsDict object of the problem is mandatory
           for initialization via `from_str()`.

    A typical specification is of the form:
    ```python
    Params(
        nn_params=True | False,
        eq_params={
            "alpha":True | False,
            "beta":True | False,
            ...
        }
    )
    ```

    Parameters
    ----------
    dyn_loss : Params | ParamsDict | None, default=None
        Tell wrt which node of `Params` we will differentiate the
        dynamic loss. To do so, the fields of `Params` contain True (if
        differentiation) or False (if no differentiation).
    observations : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        observation loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    initial_condition : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        initial condition loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    params : InitVar[Params | ParamsDict], default=None
        The main Params object of the problem. It is required
        if some terms are unspecified (None). This is because jinns cannot
        infer the content of `Params.eq_params`.
    """

    dyn_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
    observations: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
    initial_condition: Params | ParamsDict | None = eqx.field(
        kw_only=True, default=None
    )

    params: InitVar[Params | ParamsDict] = eqx.field(kw_only=True, default=None)

    def __post_init__(self, params=None):
        if self.dyn_loss is None:
            try:
                self.dyn_loss = _get_masked_parameters("nn_params", params)
            except AttributeError:
                raise ValueError(
                    "self.dyn_loss is None, hence params should be " "passed"
                )
        if self.observations is None:
            try:
                self.observations = _get_masked_parameters("nn_params", params)
            except AttributeError:
                raise ValueError(
                    "self.observations is None, hence params should be " "passed"
                )
        if self.initial_condition is None:
            try:
                self.initial_condition = _get_masked_parameters("nn_params", params)
            except AttributeError:
                raise ValueError(
                    "self.initial_condition is None, hence params should be " "passed"
                )

    @classmethod
    def from_str(
        cls,
        params: Params | ParamsDict,
        dyn_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        observations: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        initial_condition: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
    ):
        """
        Construct the DerivativeKeysODE from strings. For each term of the
        loss, specify whether to differentiate wrt the neural network
        parameters, the equation parameters or both. The `Params` object, which
        contains the actual array of parameters, must be passed to
        construct the fields with the appropriate PyTree structure.

        !!! note
            You can mix strings and `Params` if you need granularity.

        Parameters
        ----------
        params
            The actual Params or ParamsDict object of the problem.
        dyn_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the dynamic loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        observations
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the observations. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        initial_condition
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the initial condition. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        """
        return DerivativeKeysODE(
            dyn_loss=(
                _get_masked_parameters(dyn_loss, params)
                if isinstance(dyn_loss, str)
                else dyn_loss
            ),
            observations=(
                _get_masked_parameters(observations, params)
                if isinstance(observations, str)
                else observations
            ),
            initial_condition=(
                _get_masked_parameters(initial_condition, params)
                if isinstance(initial_condition, str)
                else initial_condition
            ),
        )
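
A minimal usage sketch of the new class follows (again, an illustration rather than part of the release). It assumes `DerivativeKeysODE` and `Params` are re-exported by the updated `jinns/parameters/__init__.py`; if not, they can be imported from the private modules above. Parameter names and values are hypothetical.

```python
import jax.numpy as jnp
from jinns.parameters import Params, DerivativeKeysODE  # assumed re-exports

# Hypothetical problem parameters
params = Params(
    nn_params=jnp.ones((8,)),  # stand-in for the PINN's filtered parameters
    eq_params={"alpha": jnp.array(0.1), "beta": jnp.array(2.0)},
)

# Differentiate the dynamic loss wrt both network and equation parameters;
# the other loss terms keep the default "nn_params".
derivative_keys = DerivativeKeysODE.from_str(params, dyn_loss="both")
```

Under the hood, `from_str` simply builds the boolean `Params` masks that the constructor accepts directly.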

class DerivativeKeysPDEStatio(eqx.Module):
    """
    See [jinns.parameters.DerivativeKeysODE][].

    Parameters
    ----------
    dyn_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        dynamic loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    observations : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        observation loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    boundary_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        boundary loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    norm_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        normalization loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    params : InitVar[Params | ParamsDict], default=None
        The main Params object of the problem. It is required
        if some terms are unspecified (None). This is because jinns cannot infer the
        content of `Params.eq_params`.
    """

    dyn_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
    observations: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
    boundary_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
    norm_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)

    params: InitVar[Params | ParamsDict] = eqx.field(kw_only=True, default=None)

    def __post_init__(self, params=None):
        if self.dyn_loss is None:
            try:
                self.dyn_loss = _get_masked_parameters("nn_params", params)
            except AttributeError:
                raise ValueError("self.dyn_loss is None, hence params should be passed")
        if self.observations is None:
            try:
                self.observations = _get_masked_parameters("nn_params", params)
            except AttributeError:
                raise ValueError(
                    "self.observations is None, hence params should be passed"
                )
        if self.boundary_loss is None:
            try:
                self.boundary_loss = _get_masked_parameters("nn_params", params)
            except AttributeError:
                raise ValueError(
                    "self.boundary_loss is None, hence params should be passed"
                )
        if self.norm_loss is None:
            try:
                self.norm_loss = _get_masked_parameters("nn_params", params)
            except AttributeError:
                raise ValueError(
                    "self.norm_loss is None, hence params should be passed"
                )

    @classmethod
    def from_str(
        cls,
        params: Params | ParamsDict,
        dyn_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        observations: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        boundary_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        norm_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
    ):
        """
        See [jinns.parameters.DerivativeKeysODE.from_str][].

        Parameters
        ----------
        params
            The actual Params or ParamsDict object of the problem.
        dyn_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the dynamic loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        observations
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the observations. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        boundary_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the boundary loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        norm_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the normalization loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        """
        return DerivativeKeysPDEStatio(
            dyn_loss=(
                _get_masked_parameters(dyn_loss, params)
                if isinstance(dyn_loss, str)
                else dyn_loss
            ),
            observations=(
                _get_masked_parameters(observations, params)
                if isinstance(observations, str)
                else observations
            ),
            boundary_loss=(
                _get_masked_parameters(boundary_loss, params)
                if isinstance(boundary_loss, str)
                else boundary_loss
            ),
            norm_loss=(
                _get_masked_parameters(norm_loss, params)
                if isinstance(norm_loss, str)
                else norm_loss
            ),
        )


class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
    """
    See [jinns.parameters.DerivativeKeysODE][].

    Parameters
    ----------
    dyn_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        dynamic loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    observations : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        observation loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    boundary_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        boundary loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    norm_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        normalization loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    initial_condition : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        initial condition loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    params : InitVar[Params | ParamsDict], default=None
        The main Params object of the problem. It is required
        if some terms are unspecified (None). This is because jinns cannot infer the
        content of `Params.eq_params`.
    """

    initial_condition: Params | ParamsDict | None = eqx.field(
        kw_only=True, default=None
    )

    def __post_init__(self, params=None):
        super().__post_init__(params=params)
        if self.initial_condition is None:
            try:
                self.initial_condition = _get_masked_parameters("nn_params", params)
            except AttributeError:
                raise ValueError(
                    "self.initial_condition is None, hence params should be passed"
                )

    @classmethod
    def from_str(
        cls,
        params: Params | ParamsDict,
        dyn_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        observations: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        boundary_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        norm_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        initial_condition: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
    ):
        """
        See [jinns.parameters.DerivativeKeysODE.from_str][].

        Parameters
        ----------
        params
            The actual Params or ParamsDict object of the problem.
        dyn_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the dynamic loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        observations
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the observations. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        boundary_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the boundary loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        norm_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the normalization loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        initial_condition
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the initial condition loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        """
        return DerivativeKeysPDENonStatio(
            dyn_loss=(
                _get_masked_parameters(dyn_loss, params)
                if isinstance(dyn_loss, str)
                else dyn_loss
            ),
            observations=(
                _get_masked_parameters(observations, params)
                if isinstance(observations, str)
                else observations
            ),
            boundary_loss=(
                _get_masked_parameters(boundary_loss, params)
                if isinstance(boundary_loss, str)
                else boundary_loss
            ),
            norm_loss=(
                _get_masked_parameters(norm_loss, params)
                if isinstance(norm_loss, str)
                else norm_loss
            ),
            initial_condition=(
                _get_masked_parameters(initial_condition, params)
                if isinstance(initial_condition, str)
                else initial_condition
            ),
        )
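
The PDE variants follow the same pattern. Below is a sketch that mixes a string with an explicit `Params` mask for per-parameter granularity, as the `from_str` docstring allows; the import path, names and values are again assumptions for illustration.

```python
import jax.numpy as jnp
from jinns.parameters import Params, DerivativeKeysPDENonStatio  # assumed re-exports

params = Params(
    nn_params=jnp.ones((8,)),  # stand-in for the PINN's filtered parameters
    eq_params={"nu": jnp.array(0.01), "rho": jnp.array(1.0)},
)

derivative_keys = DerivativeKeysPDENonStatio.from_str(
    params,
    # dynamic loss: gradients wrt the network and "nu", but not "rho"
    dyn_loss=Params(nn_params=True, eq_params={"nu": True, "rho": False}),
    initial_condition="nn_params",  # the remaining terms keep the default too
)
```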

def _set_derivatives(params, derivative_keys):
    """
    We construct an eqx.Module with the fields of derivative_keys, each field
    has a copy of the params with appropriate derivatives set
    """

    def _set_derivatives_ParamsDict(params_, derivative_mask):
        """
        The next lines put a stop_gradient around the fields that do not
        differentiate the loss term
        **Note:** **No granularity inside `ParamsDict.nn_params` is currently
        supported.**
        This means a typical ParamsDict specification is of the form:
        `ParamsDict(nn_params=True | False, eq_params={"0":{"alpha":True | False,
        "beta":True | False}, "1":{"alpha":True | False, "beta":True | False}})`.
        """
        # a ParamsDict object is reconstructed by hand since we do not want to
        # traverse nn_params, for now...
        return ParamsDict(
            nn_params=jax.lax.cond(
                derivative_mask.nn_params,
                lambda p: p,
                jax.lax.stop_gradient,
                params_.nn_params,
            ),
            eq_params=jax.tree.map(
                lambda p, d: jax.lax.cond(d, lambda p: p, jax.lax.stop_gradient, p),
                params_.eq_params,
                derivative_mask.eq_params,
            ),
        )

    def _set_derivatives_(params_, derivative_mask):
        """
        The next lines put a stop_gradient around the fields that do not
        differentiate the loss term
        **Note:** **No granularity inside `Params.nn_params` is currently
        supported.**
        This means a typical Params specification is of the form:
        `Params(nn_params=True | False, eq_params={"alpha":True | False,
        "beta":True | False})`.
        """
        return jax.tree.map(
            lambda p, d: jax.lax.cond(d, lambda p: p, jax.lax.stop_gradient, p),
            params_,
            derivative_mask,
            is_leaf=lambda x: isinstance(x, eqx.Module)
            and not isinstance(x, Params),  # do not traverse nn_params, more
            # granularity could be imagined here, in the future
        )

    if isinstance(params, ParamsDict):
        return _set_derivatives_ParamsDict(params, derivative_keys)
    return _set_derivatives_(params, derivative_keys)
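
The following sketch shows how this masking interacts with differentiation: gradients flow only through entries marked `True`, the rest are wrapped in `stop_gradient`. The toy loss and values are made up, and the direct import of the private helpers is an assumption.

```python
import jax
import jax.numpy as jnp
from jinns.parameters._params import Params
from jinns.parameters._derivative_keys import _get_masked_parameters, _set_derivatives

params = Params(nn_params=jnp.array(2.0), eq_params={"alpha": jnp.array(3.0)})
mask = _get_masked_parameters("nn_params", params)  # True for nn, False for alpha

def toy_loss(p):
    p = _set_derivatives(p, mask)  # stop_gradient wraps the masked-out leaves
    return p.nn_params ** 2 + p.eq_params["alpha"] ** 2

grads = jax.grad(toy_loss)(params)
# grads.nn_params -> 4.0, while grads.eq_params["alpha"] -> 0.0:
# the equation parameter is frozen for this loss term.
```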

jinns/parameters/_params.py (new file)

@@ -0,0 +1,115 @@
"""
Formalize the data structure for the parameters
"""

import jax
import equinox as eqx
from typing import Dict
from jaxtyping import Array, PyTree


class Params(eqx.Module):
    """
    The equinox module for the parameters

    Parameters
    ----------
    nn_params : PyTree
        A PyTree of the non-static part of the PINN eqx.Module, i.e., the
        parameters of the PINN
    eq_params : Dict[str, Array]
        A dictionary of the equation parameters. Keys are the parameter names,
        values are their corresponding values
    """

    nn_params: PyTree = eqx.field(kw_only=True, default=None)
    eq_params: Dict[str, Array] = eqx.field(kw_only=True, default=None)


class ParamsDict(eqx.Module):
    """
    The equinox module for a dictionary of parameters with different keys
    corresponding to different equations.

    Parameters
    ----------
    nn_params : Dict[str, PyTree]
        The neural networks' parameters. Most of the time, it will be the
        Array part of an `eqx.Module` obtained by
        `eqx.partition(module, eqx.is_inexact_array)`.
    eq_params : Dict[str, Array]
        A dictionary of the equation parameters. Dict keys are the parameter
        names as defined in your custom loss.
    """

    nn_params: Dict[str, PyTree] = eqx.field(kw_only=True, default=None)
    eq_params: Dict[str, Array] = eqx.field(kw_only=True, default=None)

    def extract_params(self, nn_key: str) -> Params:
        """
        Extract the corresponding `nn_params` and `eq_params` for `nn_key` and
        return them in the form of a `Params` object.
        """
        try:
            return Params(
                nn_params=self.nn_params[nn_key],
                eq_params=self.eq_params[nn_key],
            )
        except (KeyError, IndexError) as e:
            return Params(
                nn_params=self.nn_params[nn_key],
                eq_params=self.eq_params,
            )
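
A short sketch of `extract_params` on a toy `ParamsDict` (illustrative only; the import path and values are assumptions):

```python
import jax.numpy as jnp
from jinns.parameters import Params, ParamsDict  # assumed re-exports

params = ParamsDict(
    nn_params={"u": jnp.ones((4,)), "v": jnp.zeros((4,))},  # two networks
    eq_params={"alpha": jnp.array(0.5)},                    # shared parameters
)
params_u = params.extract_params("u")
# eq_params has no "u" entry, so the KeyError branch returns the shared dict:
# params_u == Params(nn_params=jnp.ones((4,)), eq_params={"alpha": 0.5})
```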

def _update_eq_params_dict(
    params: Params, param_batch_dict: Dict[str, Array]
) -> Params:
    """
    Update params.eq_params with a batch of eq_params for given key(s)
    """

    # artificially "complete" `param_batch_dict` with None to match `params`
    # PyTree structure
    param_batch_dict_ = param_batch_dict | {
        k: None for k in set(params.eq_params.keys()) - set(param_batch_dict.keys())
    }

    # Replace at non None leafs
    params = eqx.tree_at(
        lambda p: p.eq_params,
        params,
        jax.tree_util.tree_map(
            lambda p, q: q if q is not None else p,
            params.eq_params,
            param_batch_dict_,
        ),
    )

    return params


def _get_vmap_in_axes_params(
    eq_params_batch_dict: Dict[str, Array], params: Params | ParamsDict
) -> tuple[Params]:
    """
    Return the input vmap axes when there is batch(es) of parameters to vmap
    over. The latter are designated by keys in eq_params_batch_dict.
    If eq_params_batch_dict is None (i.e. no additional parameter batch), we
    return (None,).
    """
    if eq_params_batch_dict is None:
        return (None,)
    # We use pytree indexing of vmapped axes and vmap on axis
    # 0 of the eq_parameters for which we have a batch
    # this is for a fine-grained vmaping
    # scheme over the params
    vmap_in_axes_params = (
        type(params)(
            nn_params=None,
            eq_params={
                k: (0 if k in eq_params_batch_dict.keys() else None)
                for k in params.eq_params.keys()
            },
        ),
    )
    return vmap_in_axes_params
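
Finally, a sketch of how these two helpers are intended to combine with `jax.vmap` when a batch of equation parameters is available. The toy loss, the values and the direct use of the private helpers are assumptions; the `in_axes` pytree shown in the comment is what `_get_vmap_in_axes_params` returns for this input.

```python
import jax
import jax.numpy as jnp
from jinns.parameters import Params  # assumed re-export
from jinns.parameters._params import (
    _get_vmap_in_axes_params,
    _update_eq_params_dict,
)

params = Params(nn_params=jnp.ones((3,)), eq_params={"alpha": jnp.array(1.0)})
alpha_batch = {"alpha": jnp.linspace(0.0, 1.0, 8)}  # a batch of 8 values

in_axes = _get_vmap_in_axes_params(alpha_batch, params)
# in_axes == (Params(nn_params=None, eq_params={"alpha": 0}),)

def toy_loss(p):
    return jnp.sum(p.nn_params) * p.eq_params["alpha"]

batched_params = _update_eq_params_dict(params, alpha_batch)
losses = jax.vmap(toy_loss, in_axes=in_axes)(batched_params)  # shape (8,)
```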