jinns 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +74 -69
- jinns/loss/_LossODE.py +132 -348
- jinns/loss/_LossPDE.py +262 -549
- jinns/loss/__init__.py +32 -6
- jinns/loss/_abstract_loss.py +128 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_components.py +43 -0
- jinns/loss/_loss_utils.py +85 -179
- jinns/loss/_loss_weight_updates.py +202 -0
- jinns/loss/_loss_weights.py +64 -40
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +207 -92
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +16 -10
- jinns/utils/_types.py +20 -54
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
- jinns-1.5.0.dist-info/RECORD +55 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -2,73 +2,45 @@
|
|
|
2
2
|
Formalize the data structure for the derivative keys
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
from dataclasses import fields, InitVar
|
|
5
|
+
from dataclasses import InitVar
|
|
7
6
|
from typing import Literal
|
|
7
|
+
from jaxtyping import Array
|
|
8
8
|
import jax
|
|
9
9
|
import equinox as eqx
|
|
10
10
|
|
|
11
|
-
from jinns.parameters._params import Params
|
|
11
|
+
from jinns.parameters._params import Params
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
def _get_masked_parameters(
|
|
15
|
-
derivative_mask_str: str, params: Params
|
|
16
|
-
) -> Params
|
|
15
|
+
derivative_mask_str: str, params: Params[Array]
|
|
16
|
+
) -> Params[bool]:
|
|
17
17
|
"""
|
|
18
18
|
Creates the Params object with True values where we want to differentiate
|
|
19
19
|
"""
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
and not isinstance(x, Params), # do not travers nn_params, more
|
|
28
|
-
# granularity could be imagined here, in the future
|
|
29
|
-
)
|
|
30
|
-
if derivative_mask_str == "both":
|
|
31
|
-
return diff_params
|
|
32
|
-
if derivative_mask_str == "eq_params":
|
|
33
|
-
return eqx.tree_at(lambda p: p.nn_params, diff_params, False)
|
|
34
|
-
if derivative_mask_str == "nn_params":
|
|
35
|
-
return eqx.tree_at(
|
|
36
|
-
lambda p: p.eq_params,
|
|
37
|
-
diff_params,
|
|
38
|
-
jax.tree.map(lambda x: False, params.eq_params),
|
|
39
|
-
)
|
|
40
|
-
raise ValueError(
|
|
41
|
-
"Bad value for DerivativeKeys. Got "
|
|
42
|
-
f'{derivative_mask_str}, expected "both", "nn_params" or '
|
|
43
|
-
' "eq_params"'
|
|
44
|
-
)
|
|
45
|
-
elif isinstance(params, ParamsDict):
|
|
46
|
-
# do not travers nn_params, more
|
|
20
|
+
# start with a params object with True everywhere. We will update to False
|
|
21
|
+
# for parameters wrt which we do want not to differentiate the loss
|
|
22
|
+
diff_params = jax.tree.map(
|
|
23
|
+
lambda x: True,
|
|
24
|
+
params,
|
|
25
|
+
is_leaf=lambda x: isinstance(x, eqx.Module)
|
|
26
|
+
and not isinstance(x, Params), # do not travers nn_params, more
|
|
47
27
|
# granularity could be imagined here, in the future
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
diff_params,
|
|
59
|
-
jax.tree.map(lambda x: False, params.eq_params),
|
|
60
|
-
)
|
|
61
|
-
raise ValueError(
|
|
62
|
-
"Bad value for DerivativeKeys. Got "
|
|
63
|
-
f'{derivative_mask_str}, expected "both", "nn_params" or '
|
|
64
|
-
' "eq_params"'
|
|
65
|
-
)
|
|
66
|
-
|
|
67
|
-
else:
|
|
68
|
-
raise ValueError(
|
|
69
|
-
f"Bad value for params. Got {type(params)}, expected Params "
|
|
70
|
-
" or ParamsDict"
|
|
28
|
+
)
|
|
29
|
+
if derivative_mask_str == "both":
|
|
30
|
+
return diff_params
|
|
31
|
+
if derivative_mask_str == "eq_params":
|
|
32
|
+
return eqx.tree_at(lambda p: p.nn_params, diff_params, False)
|
|
33
|
+
if derivative_mask_str == "nn_params":
|
|
34
|
+
return eqx.tree_at(
|
|
35
|
+
lambda p: p.eq_params,
|
|
36
|
+
diff_params,
|
|
37
|
+
jax.tree.map(lambda x: False, params.eq_params),
|
|
71
38
|
)
|
|
39
|
+
raise ValueError(
|
|
40
|
+
"Bad value for DerivativeKeys. Got "
|
|
41
|
+
f'{derivative_mask_str}, expected "both", "nn_params" or '
|
|
42
|
+
' "eq_params"'
|
|
43
|
+
)
|
|
72
44
|
|
|
73
45
|
|
|
74
46
|
class DerivativeKeysODE(eqx.Module):
|
|
@@ -89,7 +61,7 @@ class DerivativeKeysODE(eqx.Module):
|
|
|
89
61
|
1. For unspecified loss term, the default is to differentiate with
|
|
90
62
|
respect to `"nn_params"` only.
|
|
91
63
|
2. No granularity inside `Params.nn_params` is currently supported.
|
|
92
|
-
3. Note that the main Params
|
|
64
|
+
3. Note that the main Params object of the problem is mandatory if initialization via `from_str()`.
|
|
93
65
|
|
|
94
66
|
A typical specification is of the form:
|
|
95
67
|
```python
|
|
@@ -105,67 +77,69 @@ class DerivativeKeysODE(eqx.Module):
|
|
|
105
77
|
|
|
106
78
|
Parameters
|
|
107
79
|
----------
|
|
108
|
-
dyn_loss : Params |
|
|
80
|
+
dyn_loss : Params[bool] | None, default=None
|
|
109
81
|
Tell wrt which node of `Params` we will differentiate the
|
|
110
82
|
dynamic loss. To do so, the fields of `Params` contain True (if
|
|
111
83
|
differentiation) or False (if no differentiation).
|
|
112
|
-
observations : Params |
|
|
84
|
+
observations : Params[bool] | None, default=None
|
|
113
85
|
Tell wrt which parameters among Params we will differentiate the
|
|
114
86
|
observation loss. To do so, the fields of Params contain True (if
|
|
115
87
|
differentiation) or False (if no differentiation).
|
|
116
|
-
initial_condition : Params |
|
|
88
|
+
initial_condition : Params[bool] | None, default=None
|
|
117
89
|
Tell wrt which parameters among Params we will differentiate the
|
|
118
90
|
initial condition loss. To do so, the fields of Params contain True (if
|
|
119
91
|
differentiation) or False (if no differentiation).
|
|
120
|
-
params : InitVar[Params
|
|
92
|
+
params : InitVar[Params[Array]], default=None
|
|
121
93
|
The main Params object of the problem. It is required
|
|
122
94
|
if some terms are unspecified (None). This is because, jinns cannot
|
|
123
95
|
infer the content of `Params.eq_params`.
|
|
124
96
|
"""
|
|
125
97
|
|
|
126
|
-
dyn_loss: Params |
|
|
127
|
-
observations: Params |
|
|
128
|
-
initial_condition: Params |
|
|
129
|
-
|
|
130
|
-
)
|
|
131
|
-
|
|
132
|
-
params:
|
|
133
|
-
|
|
134
|
-
|
|
98
|
+
dyn_loss: Params[bool] | None = eqx.field(kw_only=True, default=None)
|
|
99
|
+
observations: Params[bool] | None = eqx.field(kw_only=True, default=None)
|
|
100
|
+
initial_condition: Params[bool] | None = eqx.field(kw_only=True, default=None)
|
|
101
|
+
|
|
102
|
+
params: InitVar[Params[Array] | None] = eqx.field(kw_only=True, default=None)
|
|
103
|
+
|
|
104
|
+
def __post_init__(self, params: Params[Array] | None = None):
|
|
105
|
+
if params is None and (
|
|
106
|
+
self.dyn_loss is None
|
|
107
|
+
or self.observations is None
|
|
108
|
+
or self.initial_condition is None
|
|
109
|
+
):
|
|
110
|
+
raise ValueError(
|
|
111
|
+
"params cannot be None since at least one loss "
|
|
112
|
+
"term has an undefined derivative key Params PyTree"
|
|
113
|
+
)
|
|
135
114
|
if self.dyn_loss is None:
|
|
136
|
-
|
|
137
|
-
self.dyn_loss
|
|
138
|
-
|
|
139
|
-
raise ValueError(
|
|
140
|
-
"self.dyn_loss is None, hence params should be " "passed"
|
|
141
|
-
)
|
|
115
|
+
if params is None:
|
|
116
|
+
raise ValueError("self.dyn_loss is None, hence params should be passed")
|
|
117
|
+
self.dyn_loss = _get_masked_parameters("nn_params", params)
|
|
142
118
|
if self.observations is None:
|
|
143
|
-
|
|
144
|
-
self.observations = _get_masked_parameters("nn_params", params)
|
|
145
|
-
except AttributeError:
|
|
119
|
+
if params is None:
|
|
146
120
|
raise ValueError(
|
|
147
|
-
"self.observations is None, hence params should be
|
|
121
|
+
"self.observations is None, hence params should be passed"
|
|
148
122
|
)
|
|
123
|
+
self.observations = _get_masked_parameters("nn_params", params)
|
|
149
124
|
if self.initial_condition is None:
|
|
150
|
-
|
|
151
|
-
self.initial_condition = _get_masked_parameters("nn_params", params)
|
|
152
|
-
except AttributeError:
|
|
125
|
+
if params is None:
|
|
153
126
|
raise ValueError(
|
|
154
|
-
"self.initial_condition is None, hence params should be
|
|
127
|
+
"self.initial_condition is None, hence params should be passed"
|
|
155
128
|
)
|
|
129
|
+
self.initial_condition = _get_masked_parameters("nn_params", params)
|
|
156
130
|
|
|
157
131
|
@classmethod
|
|
158
132
|
def from_str(
|
|
159
133
|
cls,
|
|
160
|
-
params: Params
|
|
134
|
+
params: Params[Array],
|
|
161
135
|
dyn_loss: (
|
|
162
|
-
Literal["nn_params", "eq_params", "both"] | Params
|
|
136
|
+
Literal["nn_params", "eq_params", "both"] | Params[bool]
|
|
163
137
|
) = "nn_params",
|
|
164
138
|
observations: (
|
|
165
|
-
Literal["nn_params", "eq_params", "both"] | Params
|
|
139
|
+
Literal["nn_params", "eq_params", "both"] | Params[bool]
|
|
166
140
|
) = "nn_params",
|
|
167
141
|
initial_condition: (
|
|
168
|
-
Literal["nn_params", "eq_params", "both"] | Params
|
|
142
|
+
Literal["nn_params", "eq_params", "both"] | Params[bool]
|
|
169
143
|
) = "nn_params",
|
|
170
144
|
):
|
|
171
145
|
"""
|
|
@@ -181,19 +155,19 @@ class DerivativeKeysODE(eqx.Module):
|
|
|
181
155
|
Parameters
|
|
182
156
|
----------
|
|
183
157
|
params
|
|
184
|
-
The actual Params
|
|
158
|
+
The actual Params object of the problem.
|
|
185
159
|
dyn_loss
|
|
186
160
|
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
187
161
|
`"both"` we will differentiate the dynamic loss. Default is
|
|
188
|
-
`"nn_params"`. Specifying a Params
|
|
162
|
+
`"nn_params"`. Specifying a Params is also possible.
|
|
189
163
|
observations
|
|
190
164
|
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
191
165
|
`"both"` we will differentiate the observations. Default is
|
|
192
|
-
`"nn_params"`. Specifying a Params
|
|
166
|
+
`"nn_params"`. Specifying a Params is also possible.
|
|
193
167
|
initial_condition
|
|
194
168
|
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
195
169
|
`"both"` we will differentiate the initial condition. Default is
|
|
196
|
-
`"nn_params"`. Specifying a Params
|
|
170
|
+
`"nn_params"`. Specifying a Params is also possible.
|
|
197
171
|
"""
|
|
198
172
|
return DerivativeKeysODE(
|
|
199
173
|
dyn_loss=(
|
|
@@ -220,78 +194,74 @@ class DerivativeKeysPDEStatio(eqx.Module):
|
|
|
220
194
|
|
|
221
195
|
Parameters
|
|
222
196
|
----------
|
|
223
|
-
dyn_loss : Params |
|
|
197
|
+
dyn_loss : Params[bool] | None, default=None
|
|
224
198
|
Tell wrt which parameters among Params we will differentiate the
|
|
225
199
|
dynamic loss. To do so, the fields of Params contain True (if
|
|
226
200
|
differentiation) or False (if no differentiation).
|
|
227
|
-
observations : Params |
|
|
201
|
+
observations : Params[bool] | None, default=None
|
|
228
202
|
Tell wrt which parameters among Params we will differentiate the
|
|
229
203
|
observation loss. To do so, the fields of Params contain True (if
|
|
230
204
|
differentiation) or False (if no differentiation).
|
|
231
|
-
boundary_loss : Params |
|
|
205
|
+
boundary_loss : Params[bool] | None, default=None
|
|
232
206
|
Tell wrt which parameters among Params we will differentiate the
|
|
233
207
|
boundary loss. To do so, the fields of Params contain True (if
|
|
234
208
|
differentiation) or False (if no differentiation).
|
|
235
|
-
norm_loss : Params |
|
|
209
|
+
norm_loss : Params[bool] | None, default=None
|
|
236
210
|
Tell wrt which parameters among Params we will differentiate the
|
|
237
211
|
normalization loss. To do so, the fields of Params contain True (if
|
|
238
212
|
differentiation) or False (if no differentiation).
|
|
239
|
-
params : InitVar[Params
|
|
213
|
+
params : InitVar[Params[Array]], default=None
|
|
240
214
|
The main Params object of the problem. It is required
|
|
241
215
|
if some terms are unspecified (None). This is because, jinns cannot infer the
|
|
242
216
|
content of `Params.eq_params`.
|
|
243
217
|
"""
|
|
244
218
|
|
|
245
|
-
dyn_loss: Params |
|
|
246
|
-
observations: Params |
|
|
247
|
-
boundary_loss: Params |
|
|
248
|
-
norm_loss: Params |
|
|
219
|
+
dyn_loss: Params[bool] | None = eqx.field(kw_only=True, default=None)
|
|
220
|
+
observations: Params[bool] | None = eqx.field(kw_only=True, default=None)
|
|
221
|
+
boundary_loss: Params[bool] | None = eqx.field(kw_only=True, default=None)
|
|
222
|
+
norm_loss: Params[bool] | None = eqx.field(kw_only=True, default=None)
|
|
249
223
|
|
|
250
|
-
params: InitVar[Params |
|
|
224
|
+
params: InitVar[Params[Array] | None] = eqx.field(kw_only=True, default=None)
|
|
251
225
|
|
|
252
|
-
def __post_init__(self, params=None):
|
|
226
|
+
def __post_init__(self, params: Params[Array] | None = None):
|
|
253
227
|
if self.dyn_loss is None:
|
|
254
|
-
|
|
255
|
-
self.dyn_loss = _get_masked_parameters("nn_params", params)
|
|
256
|
-
except AttributeError:
|
|
228
|
+
if params is None:
|
|
257
229
|
raise ValueError("self.dyn_loss is None, hence params should be passed")
|
|
230
|
+
self.dyn_loss = _get_masked_parameters("nn_params", params)
|
|
258
231
|
if self.observations is None:
|
|
259
|
-
|
|
260
|
-
self.observations = _get_masked_parameters("nn_params", params)
|
|
261
|
-
except AttributeError:
|
|
232
|
+
if params is None:
|
|
262
233
|
raise ValueError(
|
|
263
234
|
"self.observations is None, hence params should be passed"
|
|
264
235
|
)
|
|
236
|
+
self.observations = _get_masked_parameters("nn_params", params)
|
|
265
237
|
if self.boundary_loss is None:
|
|
266
|
-
|
|
267
|
-
self.boundary_loss = _get_masked_parameters("nn_params", params)
|
|
268
|
-
except AttributeError:
|
|
238
|
+
if params is None:
|
|
269
239
|
raise ValueError(
|
|
270
240
|
"self.boundary_loss is None, hence params should be passed"
|
|
271
241
|
)
|
|
242
|
+
self.boundary_loss = _get_masked_parameters("nn_params", params)
|
|
272
243
|
if self.norm_loss is None:
|
|
273
|
-
|
|
274
|
-
self.norm_loss = _get_masked_parameters("nn_params", params)
|
|
275
|
-
except AttributeError:
|
|
244
|
+
if params is None:
|
|
276
245
|
raise ValueError(
|
|
277
246
|
"self.norm_loss is None, hence params should be passed"
|
|
278
247
|
)
|
|
248
|
+
self.norm_loss = _get_masked_parameters("nn_params", params)
|
|
279
249
|
|
|
280
250
|
@classmethod
|
|
281
251
|
def from_str(
|
|
282
252
|
cls,
|
|
283
|
-
params: Params
|
|
253
|
+
params: Params[Array],
|
|
284
254
|
dyn_loss: (
|
|
285
|
-
Literal["nn_params", "eq_params", "both"] | Params
|
|
255
|
+
Literal["nn_params", "eq_params", "both"] | Params[bool]
|
|
286
256
|
) = "nn_params",
|
|
287
257
|
observations: (
|
|
288
|
-
Literal["nn_params", "eq_params", "both"] | Params
|
|
258
|
+
Literal["nn_params", "eq_params", "both"] | Params[bool]
|
|
289
259
|
) = "nn_params",
|
|
290
260
|
boundary_loss: (
|
|
291
|
-
Literal["nn_params", "eq_params", "both"] | Params
|
|
261
|
+
Literal["nn_params", "eq_params", "both"] | Params[bool]
|
|
292
262
|
) = "nn_params",
|
|
293
263
|
norm_loss: (
|
|
294
|
-
Literal["nn_params", "eq_params", "both"] | Params
|
|
264
|
+
Literal["nn_params", "eq_params", "both"] | Params[bool]
|
|
295
265
|
) = "nn_params",
|
|
296
266
|
):
|
|
297
267
|
"""
|
|
@@ -300,23 +270,23 @@ class DerivativeKeysPDEStatio(eqx.Module):
|
|
|
300
270
|
Parameters
|
|
301
271
|
----------
|
|
302
272
|
params
|
|
303
|
-
The actual Param
|
|
273
|
+
The actual Param object of the problem.
|
|
304
274
|
dyn_loss
|
|
305
275
|
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
306
276
|
`"both"` we will differentiate the dynamic loss. Default is
|
|
307
|
-
`"nn_params"`. Specifying a Params
|
|
277
|
+
`"nn_params"`. Specifying a Params is also possible.
|
|
308
278
|
observations
|
|
309
279
|
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
310
280
|
`"both"` we will differentiate the observations. Default is
|
|
311
|
-
`"nn_params"`. Specifying a Params
|
|
281
|
+
`"nn_params"`. Specifying a Params is also possible.
|
|
312
282
|
boundary_loss
|
|
313
283
|
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
314
284
|
`"both"` we will differentiate the boundary loss. Default is
|
|
315
|
-
`"nn_params"`. Specifying a Params
|
|
285
|
+
`"nn_params"`. Specifying a Params is also possible.
|
|
316
286
|
norm_loss
|
|
317
287
|
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
318
288
|
`"both"` we will differentiate the normalization loss. Default is
|
|
319
|
-
`"nn_params"`. Specifying a Params
|
|
289
|
+
`"nn_params"`. Specifying a Params is also possible.
|
|
320
290
|
"""
|
|
321
291
|
return DerivativeKeysPDEStatio(
|
|
322
292
|
dyn_loss=(
|
|
@@ -348,64 +318,61 @@ class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
|
|
|
348
318
|
|
|
349
319
|
Parameters
|
|
350
320
|
----------
|
|
351
|
-
dyn_loss : Params |
|
|
321
|
+
dyn_loss : Params[bool] | None, default=None
|
|
352
322
|
Tell wrt which parameters among Params we will differentiate the
|
|
353
323
|
dynamic loss. To do so, the fields of Params contain True (if
|
|
354
324
|
differentiation) or False (if no differentiation).
|
|
355
|
-
observations : Params |
|
|
325
|
+
observations : Params[bool] | None, default=None
|
|
356
326
|
Tell wrt which parameters among Params we will differentiate the
|
|
357
327
|
observation loss. To do so, the fields of Params contain True (if
|
|
358
328
|
differentiation) or False (if no differentiation).
|
|
359
|
-
boundary_loss : Params |
|
|
329
|
+
boundary_loss : Params[bool] | None, default=None
|
|
360
330
|
Tell wrt which parameters among Params we will differentiate the
|
|
361
331
|
boundary loss. To do so, the fields of Params contain True (if
|
|
362
332
|
differentiation) or False (if no differentiation).
|
|
363
|
-
norm_loss : Params |
|
|
333
|
+
norm_loss : Params[bool] | None, default=None
|
|
364
334
|
Tell wrt which parameters among Params we will differentiate the
|
|
365
335
|
normalization loss. To do so, the fields of Params contain True (if
|
|
366
336
|
differentiation) or False (if no differentiation).
|
|
367
|
-
initial_condition : Params |
|
|
337
|
+
initial_condition : Params[bool] | None, default=None
|
|
368
338
|
Tell wrt which parameters among Params we will differentiate the
|
|
369
339
|
initial_condition loss. To do so, the fields of Params contain True (if
|
|
370
340
|
differentiation) or False (if no differentiation).
|
|
371
|
-
params : InitVar[Params
|
|
341
|
+
params : InitVar[Params[Array]], default=None
|
|
372
342
|
The main Params object of the problem. It is required
|
|
373
343
|
if some terms are unspecified (None). This is because, jinns cannot infer the
|
|
374
344
|
content of `Params.eq_params`.
|
|
375
345
|
"""
|
|
376
346
|
|
|
377
|
-
initial_condition: Params |
|
|
378
|
-
kw_only=True, default=None
|
|
379
|
-
)
|
|
347
|
+
initial_condition: Params[bool] | None = eqx.field(kw_only=True, default=None)
|
|
380
348
|
|
|
381
|
-
def __post_init__(self, params=None):
|
|
349
|
+
def __post_init__(self, params: Params[Array] | None = None):
|
|
382
350
|
super().__post_init__(params=params)
|
|
383
351
|
if self.initial_condition is None:
|
|
384
|
-
|
|
385
|
-
self.initial_condition = _get_masked_parameters("nn_params", params)
|
|
386
|
-
except AttributeError:
|
|
352
|
+
if params is None:
|
|
387
353
|
raise ValueError(
|
|
388
354
|
"self.initial_condition is None, hence params should be passed"
|
|
389
355
|
)
|
|
356
|
+
self.initial_condition = _get_masked_parameters("nn_params", params)
|
|
390
357
|
|
|
391
358
|
@classmethod
|
|
392
359
|
def from_str(
|
|
393
360
|
cls,
|
|
394
|
-
params: Params
|
|
361
|
+
params: Params[Array],
|
|
395
362
|
dyn_loss: (
|
|
396
|
-
Literal["nn_params", "eq_params", "both"] | Params
|
|
363
|
+
Literal["nn_params", "eq_params", "both"] | Params[bool]
|
|
397
364
|
) = "nn_params",
|
|
398
365
|
observations: (
|
|
399
|
-
Literal["nn_params", "eq_params", "both"] | Params
|
|
366
|
+
Literal["nn_params", "eq_params", "both"] | Params[bool]
|
|
400
367
|
) = "nn_params",
|
|
401
368
|
boundary_loss: (
|
|
402
|
-
Literal["nn_params", "eq_params", "both"] | Params
|
|
369
|
+
Literal["nn_params", "eq_params", "both"] | Params[bool]
|
|
403
370
|
) = "nn_params",
|
|
404
371
|
norm_loss: (
|
|
405
|
-
Literal["nn_params", "eq_params", "both"] | Params
|
|
372
|
+
Literal["nn_params", "eq_params", "both"] | Params[bool]
|
|
406
373
|
) = "nn_params",
|
|
407
374
|
initial_condition: (
|
|
408
|
-
Literal["nn_params", "eq_params", "both"] | Params
|
|
375
|
+
Literal["nn_params", "eq_params", "both"] | Params[bool]
|
|
409
376
|
) = "nn_params",
|
|
410
377
|
):
|
|
411
378
|
"""
|
|
@@ -414,27 +381,27 @@ class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
|
|
|
414
381
|
Parameters
|
|
415
382
|
----------
|
|
416
383
|
params
|
|
417
|
-
The actual Params
|
|
384
|
+
The actual Params object of the problem.
|
|
418
385
|
dyn_loss
|
|
419
386
|
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
420
387
|
`"both"` we will differentiate the dynamic loss. Default is
|
|
421
|
-
`"nn_params"`. Specifying a Params
|
|
388
|
+
`"nn_params"`. Specifying a Params is also possible.
|
|
422
389
|
observations
|
|
423
390
|
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
424
391
|
`"both"` we will differentiate the observations. Default is
|
|
425
|
-
`"nn_params"`. Specifying a Params
|
|
392
|
+
`"nn_params"`. Specifying a Params is also possible.
|
|
426
393
|
boundary_loss
|
|
427
394
|
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
428
395
|
`"both"` we will differentiate the boundary loss. Default is
|
|
429
|
-
`"nn_params"`. Specifying a Params
|
|
396
|
+
`"nn_params"`. Specifying a Params is also possible.
|
|
430
397
|
norm_loss
|
|
431
398
|
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
432
399
|
`"both"` we will differentiate the normalization loss. Default is
|
|
433
|
-
`"nn_params"`. Specifying a Params
|
|
400
|
+
`"nn_params"`. Specifying a Params is also possible.
|
|
434
401
|
initial_condition
|
|
435
402
|
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
436
403
|
`"both"` we will differentiate the initial_condition loss. Default is
|
|
437
|
-
`"nn_params"`. Specifying a Params
|
|
404
|
+
`"nn_params"`. Specifying a Params is also possible.
|
|
438
405
|
"""
|
|
439
406
|
return DerivativeKeysPDENonStatio(
|
|
440
407
|
dyn_loss=(
|
|
@@ -471,32 +438,6 @@ def _set_derivatives(params, derivative_keys):
|
|
|
471
438
|
has a copy of the params with appropriate derivatives set
|
|
472
439
|
"""
|
|
473
440
|
|
|
474
|
-
def _set_derivatives_ParamsDict(params_, derivative_mask):
|
|
475
|
-
"""
|
|
476
|
-
The next lines put a stop_gradient around the fields that do not
|
|
477
|
-
differentiate the loss term
|
|
478
|
-
**Note:** **No granularity inside `ParamsDict.nn_params` is currently
|
|
479
|
-
supported.**
|
|
480
|
-
This means a typical Params specification is of the form:
|
|
481
|
-
`ParamsDict(nn_params=True | False, eq_params={"0":{"alpha":True | False,
|
|
482
|
-
"beta":True | False}}, "1":{"alpha":True | False, "beta":True | False}})`.
|
|
483
|
-
"""
|
|
484
|
-
# a ParamsDict object is reconstructed by hand since we do not want to
|
|
485
|
-
# traverse nn_params, for now...
|
|
486
|
-
return ParamsDict(
|
|
487
|
-
nn_params=jax.lax.cond(
|
|
488
|
-
derivative_mask.nn_params,
|
|
489
|
-
lambda p: p,
|
|
490
|
-
jax.lax.stop_gradient,
|
|
491
|
-
params_.nn_params,
|
|
492
|
-
),
|
|
493
|
-
eq_params=jax.tree.map(
|
|
494
|
-
lambda p, d: jax.lax.cond(d, lambda p: p, jax.lax.stop_gradient, p),
|
|
495
|
-
params_.eq_params,
|
|
496
|
-
derivative_mask.eq_params,
|
|
497
|
-
),
|
|
498
|
-
)
|
|
499
|
-
|
|
500
441
|
def _set_derivatives_(params_, derivative_mask):
|
|
501
442
|
"""
|
|
502
443
|
The next lines put a stop_gradient around the fields that do not
|
|
@@ -516,6 +457,4 @@ def _set_derivatives(params, derivative_keys):
|
|
|
516
457
|
# granularity could be imagined here, in the future
|
|
517
458
|
)
|
|
518
459
|
|
|
519
|
-
if isinstance(params, ParamsDict):
|
|
520
|
-
return _set_derivatives_ParamsDict(params, derivative_keys)
|
|
521
460
|
return _set_derivatives_(params, derivative_keys)
|
jinns/parameters/_params.py
CHANGED
|
@@ -2,67 +2,36 @@
|
|
|
2
2
|
Formalize the data structure for the parameters
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
+
from typing import Generic, TypeVar
|
|
5
6
|
import jax
|
|
6
7
|
import equinox as eqx
|
|
7
|
-
from
|
|
8
|
-
from jaxtyping import Array, PyTree
|
|
8
|
+
from jaxtyping import Array, PyTree, Float
|
|
9
9
|
|
|
10
|
+
T = TypeVar("T") # the generic type for what is in the Params PyTree because we
|
|
11
|
+
# have possibly Params of Arrays, boolean, ...
|
|
10
12
|
|
|
11
|
-
|
|
13
|
+
|
|
14
|
+
class Params(eqx.Module, Generic[T]):
|
|
12
15
|
"""
|
|
13
16
|
The equinox module for the parameters
|
|
14
17
|
|
|
15
18
|
Parameters
|
|
16
19
|
----------
|
|
17
|
-
nn_params :
|
|
20
|
+
nn_params : PyTree[T]
|
|
18
21
|
A PyTree of the non-static part of the PINN eqx.Module, i.e., the
|
|
19
22
|
parameters of the PINN
|
|
20
|
-
eq_params :
|
|
23
|
+
eq_params : dict[str, T]
|
|
21
24
|
A dictionary of the equation parameters. Keys are the parameter name,
|
|
22
25
|
values are their corresponding value
|
|
23
26
|
"""
|
|
24
27
|
|
|
25
|
-
nn_params: PyTree = eqx.field(kw_only=True, default=None)
|
|
26
|
-
eq_params:
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
class ParamsDict(eqx.Module):
|
|
30
|
-
"""
|
|
31
|
-
The equinox module for a dictionnary of parameters with different keys
|
|
32
|
-
corresponding to different equations.
|
|
33
|
-
|
|
34
|
-
Parameters
|
|
35
|
-
----------
|
|
36
|
-
nn_params : Dict[str, PyTree]
|
|
37
|
-
The neural network's parameters. Most of the time, it will be the
|
|
38
|
-
Array part of an `eqx.Module` obtained by
|
|
39
|
-
`eqx.partition(module, eqx.is_inexact_array)`.
|
|
40
|
-
eq_params : Dict[str, Array]
|
|
41
|
-
A dictionary of the equation parameters. Dict keys are the parameter name as defined your custom loss.
|
|
42
|
-
"""
|
|
43
|
-
|
|
44
|
-
nn_params: Dict[str, PyTree] = eqx.field(kw_only=True, default=None)
|
|
45
|
-
eq_params: Dict[str, Array] = eqx.field(kw_only=True, default=None)
|
|
46
|
-
|
|
47
|
-
def extract_params(self, nn_key: str) -> Params:
|
|
48
|
-
"""
|
|
49
|
-
Extract the corresponding `nn_params` and `eq_params` for `nn_key` and
|
|
50
|
-
return them in the form of a `Params` object.
|
|
51
|
-
"""
|
|
52
|
-
try:
|
|
53
|
-
return Params(
|
|
54
|
-
nn_params=self.nn_params[nn_key],
|
|
55
|
-
eq_params=self.eq_params[nn_key],
|
|
56
|
-
)
|
|
57
|
-
except (KeyError, IndexError) as e:
|
|
58
|
-
return Params(
|
|
59
|
-
nn_params=self.nn_params[nn_key],
|
|
60
|
-
eq_params=self.eq_params,
|
|
61
|
-
)
|
|
28
|
+
nn_params: PyTree[T] = eqx.field(kw_only=True, default=None)
|
|
29
|
+
eq_params: dict[str, T] = eqx.field(kw_only=True, default=None)
|
|
62
30
|
|
|
63
31
|
|
|
64
32
|
def _update_eq_params_dict(
|
|
65
|
-
params: Params
|
|
33
|
+
params: Params[Array],
|
|
34
|
+
param_batch_dict: dict[str, Float[Array, " param_batch_size dim"]],
|
|
66
35
|
) -> Params:
|
|
67
36
|
"""
|
|
68
37
|
Update params.eq_params with a batch of eq_params for given key(s)
|
|
@@ -89,13 +58,16 @@ def _update_eq_params_dict(
|
|
|
89
58
|
|
|
90
59
|
|
|
91
60
|
def _get_vmap_in_axes_params(
|
|
92
|
-
eq_params_batch_dict:
|
|
93
|
-
) -> tuple[Params]:
|
|
61
|
+
eq_params_batch_dict: dict[str, Array], params: Params[Array]
|
|
62
|
+
) -> tuple[Params[int | None] | None]:
|
|
94
63
|
"""
|
|
95
64
|
Return the input vmap axes when there is batch(es) of parameters to vmap
|
|
96
65
|
over. The latter are designated by keys in eq_params_batch_dict.
|
|
97
66
|
If eq_params_batch_dict is None (i.e. no additional parameter batch), we
|
|
98
67
|
return (None,).
|
|
68
|
+
|
|
69
|
+
Note that we return a Params PyTree with an integer to designate the
|
|
70
|
+
vmapped axis or None if there is not
|
|
99
71
|
"""
|
|
100
72
|
if eq_params_batch_dict is None:
|
|
101
73
|
return (None,)
|
|
@@ -104,7 +76,7 @@ def _get_vmap_in_axes_params(
|
|
|
104
76
|
# this is for a fine-grained vmaping
|
|
105
77
|
# scheme over the params
|
|
106
78
|
vmap_in_axes_params = (
|
|
107
|
-
|
|
79
|
+
Params(
|
|
108
80
|
nn_params=None,
|
|
109
81
|
eq_params={
|
|
110
82
|
k: (0 if k in eq_params_batch_dict.keys() else None)
|