jinns 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +532 -341
- jinns/loss/_DynamicLoss.py +150 -173
- jinns/loss/_DynamicLossAbstract.py +27 -73
- jinns/loss/_LossODE.py +45 -26
- jinns/loss/_LossPDE.py +85 -84
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +148 -279
- jinns/loss/_loss_utils.py +85 -58
- jinns/loss/_operators.py +441 -184
- jinns/parameters/_derivative_keys.py +487 -60
- jinns/plot/_plot.py +111 -98
- jinns/solver/_rar.py +102 -407
- jinns/solver/_solve.py +73 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +3 -1
- jinns/utils/_hyperpinn.py +17 -7
- jinns/utils/_pinn.py +17 -27
- jinns/utils/_ppinn.py +227 -0
- jinns/utils/_save_load.py +13 -13
- jinns/utils/_spinn.py +24 -43
- jinns/utils/_types.py +1 -0
- jinns/utils/_utils.py +40 -12
- jinns-1.2.0.dist-info/AUTHORS +2 -0
- jinns-1.2.0.dist-info/METADATA +127 -0
- jinns-1.2.0.dist-info/RECORD +41 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/WHEEL +1 -1
- jinns-1.0.0.dist-info/METADATA +0 -84
- jinns-1.0.0.dist-info/RECORD +0 -38
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/LICENSE +0 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/top_level.txt +0 -0
|
@@ -2,50 +2,468 @@
|
|
|
2
2
|
Formalize the data structure for the derivative keys
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
5
|
+
from functools import partial
|
|
6
|
+
from dataclasses import fields, InitVar
|
|
6
7
|
from typing import Literal
|
|
7
8
|
import jax
|
|
8
9
|
import equinox as eqx
|
|
9
10
|
|
|
10
|
-
from jinns.parameters._params import Params
|
|
11
|
+
from jinns.parameters._params import Params, ParamsDict
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _get_masked_parameters(
    derivative_mask_str: str, params: Params | ParamsDict
) -> Params | ParamsDict:
    """
    Create the boolean mask, with the PyTree structure of `params`, that
    tells with respect to which parameters a loss term is differentiated.

    Parameters
    ----------
    derivative_mask_str
        One of `"nn_params"`, `"eq_params"` or `"both"`.
    params
        The `Params` or `ParamsDict` object of the problem. Only its
        structure is used, its values are not read.

    Returns
    -------
    Params | ParamsDict
        An object of the same container type as `params` holding `True`
        where we differentiate and `False` elsewhere.

    Raises
    ------
    ValueError
        If `params` is neither a `Params` nor a `ParamsDict`, or if
        `derivative_mask_str` is not one of the accepted strings. The
        `params` check is performed first.
    """
    # Build a mask with True everywhere; the requested parts are switched
    # off below.
    if isinstance(params, Params):
        # Any equinox Module other than Params itself (e.g. nn_params) is
        # kept as a single leaf: no granularity inside nn_params for now.
        diff_params = jax.tree.map(
            lambda x: True,
            params,
            is_leaf=lambda x: isinstance(x, eqx.Module)
            and not isinstance(x, Params),
        )
    elif isinstance(params, ParamsDict):
        # The ParamsDict is rebuilt by hand so that nn_params is a single
        # boolean (nn_params is not traversed).
        diff_params = ParamsDict(
            nn_params=True, eq_params=jax.tree.map(lambda x: True, params.eq_params)
        )
    else:
        raise ValueError(
            f"Bad value for params. Got {type(params)}, expected Params "
            " or ParamsDict"
        )

    # The selection logic is identical for Params and ParamsDict.
    if derivative_mask_str == "both":
        return diff_params
    if derivative_mask_str == "eq_params":
        # Only eq_params are differentiated: switch nn_params off.
        return eqx.tree_at(lambda p: p.nn_params, diff_params, False)
    if derivative_mask_str == "nn_params":
        # Only nn_params are differentiated: switch every eq_params leaf off.
        return eqx.tree_at(
            lambda p: p.eq_params,
            diff_params,
            jax.tree.map(lambda x: False, params.eq_params),
        )
    raise ValueError(
        "Bad value for DerivativeKeys. Got "
        f'{derivative_mask_str}, expected "both", "nn_params" or '
        ' "eq_params"'
    )
|
|
11
72
|
|
|
12
73
|
|
|
13
74
|
class DerivativeKeysODE(eqx.Module):
    """
    A class that specifies with respect to which parameter(s) each term of the
    loss is differentiated. For example, you can specify that the
    [`DynamicLoss`][jinns.loss.DynamicLoss] should be differentiated both with
    respect to the neural network parameters *and* the equation parameters, or only some of them.

    To do so, user can either use strings or a `Params` object
    with PyTree structure matching the parameters of the problem at
    hand, and booleans indicating if gradient is to be taken or not. Internally,
    a `jax.lax.stop_gradient()` is appropriately set to each `False` node when
    computing each loss term.

    !!! note

        1. For unspecified loss term, the default is to differentiate with
        respect to `"nn_params"` only.
        2. No granularity inside `Params.nn_params` is currently supported.
        3. Note that the main Params or ParamsDict object of the problem is mandatory if initialization via `from_str()`.

    A typical specification is of the form:
    ```python
    Params(
        nn_params=True | False,
        eq_params={
            "alpha":True | False,
            "beta":True | False,
            ...
        }
    )
    ```

    Parameters
    ----------
    dyn_loss : Params | ParamsDict | None, default=None
        Tell wrt which node of `Params` we will differentiate the
        dynamic loss. To do so, the fields of `Params` contain True (if
        differentiation) or False (if no differentiation).
    observations : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        observation loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    initial_condition : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        initial condition loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    params : InitVar[Params | ParamsDict], default=None
        The main Params object of the problem. It is required
        if some terms are unspecified (None). This is because, jinns cannot
        infer the content of `Params.eq_params`.

    Raises
    ------
    ValueError
        If some term is None and `params` is not given.
    """

    dyn_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
    observations: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
    initial_condition: Params | ParamsDict | None = eqx.field(
        kw_only=True, default=None
    )

    params: InitVar[Params | ParamsDict] = eqx.field(kw_only=True, default=None)

    def __post_init__(self, params=None):
        def _default_mask(term_name):
            # Explicit check: _get_masked_parameters raises a generic
            # ValueError on None, which would hide which term is missing.
            if params is None:
                raise ValueError(
                    f"self.{term_name} is None, hence params should be passed"
                )
            return _get_masked_parameters("nn_params", params)

        # Unspecified terms default to differentiating wrt nn_params only.
        if self.dyn_loss is None:
            self.dyn_loss = _default_mask("dyn_loss")
        if self.observations is None:
            self.observations = _default_mask("observations")
        if self.initial_condition is None:
            self.initial_condition = _default_mask("initial_condition")

    @classmethod
    def from_str(
        cls,
        params: Params | ParamsDict,
        dyn_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        observations: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        initial_condition: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
    ):
        """
        Construct the DerivativeKeysODE from strings. For each term of the
        loss, specify whether to differentiate wrt the neural network
        parameters, the equation parameters or both. The `Params` object, which
        contains the actual array of parameters must be passed to
        construct the fields with the appropriate PyTree structure.

        !!! note
            You can mix strings and `Params` if you need granularity.

        Parameters
        ----------
        params
            The actual Params or ParamsDict object of the problem.
        dyn_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the dynamic loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        observations
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the observations. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        initial_condition
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the initial condition. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        """

        def _as_mask(term):
            # Strings are expanded into a mask; masks are passed through as is.
            if isinstance(term, str):
                return _get_masked_parameters(term, params)
            return term

        # `cls` (not a hard-coded class) so that subclasses get instances of
        # their own type; `params` is forwarded so that __post_init__ can
        # fill any term left to None.
        return cls(
            dyn_loss=_as_mask(dyn_loss),
            observations=_as_mask(observations),
            initial_condition=_as_mask(initial_condition),
            params=params,
        )
|
|
215
|
+
|
|
26
216
|
|
|
27
217
|
class DerivativeKeysPDEStatio(eqx.Module):
    """
    See [jinns.parameters.DerivativeKeysODE][].

    Parameters
    ----------
    dyn_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        dynamic loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    observations : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        observation loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    boundary_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        boundary loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    norm_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        normalization loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    params : InitVar[Params | ParamsDict], default=None
        The main Params object of the problem. It is required
        if some terms are unspecified (None). This is because, jinns cannot infer the
        content of `Params.eq_params`.

    Raises
    ------
    ValueError
        If some term is None and `params` is not given.
    """

    dyn_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
    observations: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
    boundary_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
    norm_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)

    params: InitVar[Params | ParamsDict] = eqx.field(kw_only=True, default=None)

    def __post_init__(self, params=None):
        def _default_mask(term_name):
            # Explicit check: _get_masked_parameters raises a generic
            # ValueError on None, which would hide which term is missing.
            if params is None:
                raise ValueError(
                    f"self.{term_name} is None, hence params should be passed"
                )
            return _get_masked_parameters("nn_params", params)

        # Unspecified terms default to differentiating wrt nn_params only.
        if self.dyn_loss is None:
            self.dyn_loss = _default_mask("dyn_loss")
        if self.observations is None:
            self.observations = _default_mask("observations")
        if self.boundary_loss is None:
            self.boundary_loss = _default_mask("boundary_loss")
        if self.norm_loss is None:
            self.norm_loss = _default_mask("norm_loss")

    @classmethod
    def from_str(
        cls,
        params: Params | ParamsDict,
        dyn_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        observations: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        boundary_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        norm_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
    ):
        """
        See [jinns.parameters.DerivativeKeysODE.from_str][].

        Parameters
        ----------
        params
            The actual Params or ParamsDict object of the problem.
        dyn_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the dynamic loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        observations
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the observations. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        boundary_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the boundary loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        norm_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the normalization loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        """

        def _as_mask(term):
            # Strings are expanded into a mask; masks are passed through as is.
            if isinstance(term, str):
                return _get_masked_parameters(term, params)
            return term

        # `cls` (not a hard-coded class) so that subclasses get instances of
        # their own type; `params` is forwarded so that __post_init__ can
        # fill any term left to None (e.g. extra fields of a subclass).
        return cls(
            dyn_loss=_as_mask(dyn_loss),
            observations=_as_mask(observations),
            boundary_loss=_as_mask(boundary_loss),
            norm_loss=_as_mask(norm_loss),
            params=params,
        )
|
|
41
343
|
|
|
42
344
|
|
|
43
345
|
class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
    """
    See [jinns.parameters.DerivativeKeysODE][].

    Parameters
    ----------
    dyn_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        dynamic loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    observations : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        observation loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    boundary_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        boundary loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    norm_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        normalization loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    initial_condition : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        initial_condition loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    params : InitVar[Params | ParamsDict], default=None
        The main Params object of the problem. It is required
        if some terms are unspecified (None). This is because, jinns cannot infer the
        content of `Params.eq_params`.

    Raises
    ------
    ValueError
        If some term is None and `params` is not given.
    """

    initial_condition: Params | ParamsDict | None = eqx.field(
        kw_only=True, default=None
    )

    def __post_init__(self, params=None):
        # The parent class fills the stationary terms (dyn_loss,
        # observations, boundary_loss, norm_loss).
        super().__post_init__(params=params)
        if self.initial_condition is None:
            # Explicit check: _get_masked_parameters raises a generic
            # ValueError on None, which would hide which term is missing.
            if params is None:
                raise ValueError(
                    "self.initial_condition is None, hence params should be passed"
                )
            self.initial_condition = _get_masked_parameters("nn_params", params)

    @classmethod
    def from_str(
        cls,
        params: Params | ParamsDict,
        dyn_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        observations: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        boundary_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        norm_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
        initial_condition: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
        ) = "nn_params",
    ):
        """
        See [jinns.parameters.DerivativeKeysODE.from_str][].

        Parameters
        ----------
        params
            The actual Params | ParamsDict object of the problem.
        dyn_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the dynamic loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        observations
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the observations. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        boundary_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the boundary loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        norm_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the normalization loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        initial_condition
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the initial_condition loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        """

        def _as_mask(term):
            # Strings are expanded into a mask; masks are passed through as is.
            if isinstance(term, str):
                return _get_masked_parameters(term, params)
            return term

        # `cls` (not a hard-coded class) so that subclasses get instances of
        # their own type; `params` is forwarded so that __post_init__ can
        # fill any term left to None.
        return cls(
            dyn_loss=_as_mask(dyn_loss),
            observations=_as_mask(observations),
            boundary_loss=_as_mask(boundary_loss),
            norm_loss=_as_mask(norm_loss),
            initial_condition=_as_mask(initial_condition),
            params=params,
        )
|
|
466
|
+
|
|
49
467
|
|
|
50
468
|
def _set_derivatives(params, derivative_keys):
    """
    Return a copy of `params` in which `jax.lax.stop_gradient` wraps every
    node whose entry in `derivative_keys` is `False`. The loss term that uses
    this copy is therefore only differentiated wrt the `True` nodes.

    Parameters
    ----------
    params
        The `Params` or `ParamsDict` object of the problem.
    derivative_keys
        A boolean mask with the same structure as `params`, i.e. one of the
        fields of a `DerivativeKeys*` object.

    Returns
    -------
    A copy of the params with appropriate derivatives set.
    """

    def _set_derivatives_ParamsDict(params_, derivative_mask):
        """
        Put a stop_gradient around the fields of a `ParamsDict` wrt which the
        loss term is not differentiated.

        **Note:** **No granularity inside `ParamsDict.nn_params` is currently
        supported.**
        This means a typical ParamsDict specification is of the form:
        `ParamsDict(nn_params=True | False, eq_params={"0": {"alpha": True | False,
        "beta": True | False}, "1": {"alpha": True | False, "beta": True | False}})`.
        """
        # A ParamsDict object is reconstructed by hand since we do not want to
        # traverse nn_params, for now: nn_params is handled by a single cond.
        return ParamsDict(
            nn_params=jax.lax.cond(
                derivative_mask.nn_params,
                lambda p: p,
                jax.lax.stop_gradient,
                params_.nn_params,
            ),
            eq_params=jax.tree.map(
                lambda p, d: jax.lax.cond(d, lambda p: p, jax.lax.stop_gradient, p),
                params_.eq_params,
                derivative_mask.eq_params,
            ),
        )

    def _set_derivatives_(params_, derivative_mask):
        """
        Put a stop_gradient around the fields of a `Params` wrt which the
        loss term is not differentiated.

        **Note:** **No granularity inside `Params.nn_params` is currently
        supported.**
        This means a typical Params specification is of the form:
        `Params(nn_params=True | False, eq_params={"alpha": True | False,
        "beta": True | False})`.
        """
        return jax.tree.map(
            lambda p, d: jax.lax.cond(d, lambda p: p, jax.lax.stop_gradient, p),
            params_,
            derivative_mask,
            # Do not traverse equinox Modules other than Params itself
            # (e.g. nn_params); more granularity could be imagined here in
            # the future.
            is_leaf=lambda x: isinstance(x, eqx.Module)
            and not isinstance(x, Params),
        )

    if isinstance(params, ParamsDict):
        return _set_derivatives_ParamsDict(params, derivative_keys)
    return _set_derivatives_(params, derivative_keys)
|