jinns-1.5.0-py3-none-any.whl → jinns-1.6.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +7 -7
- jinns/data/_AbstractDataGenerator.py +1 -1
- jinns/data/_Batchs.py +47 -13
- jinns/data/_CubicMeshPDENonStatio.py +203 -54
- jinns/data/_CubicMeshPDEStatio.py +190 -54
- jinns/data/_DataGeneratorODE.py +48 -22
- jinns/data/_DataGeneratorObservations.py +75 -32
- jinns/data/_DataGeneratorParameter.py +152 -101
- jinns/data/__init__.py +2 -1
- jinns/data/_utils.py +22 -10
- jinns/loss/_DynamicLoss.py +21 -20
- jinns/loss/_DynamicLossAbstract.py +51 -36
- jinns/loss/_LossODE.py +210 -191
- jinns/loss/_LossPDE.py +441 -368
- jinns/loss/_abstract_loss.py +60 -25
- jinns/loss/_loss_components.py +4 -25
- jinns/loss/_loss_utils.py +23 -0
- jinns/loss/_loss_weight_updates.py +6 -7
- jinns/loss/_loss_weights.py +34 -35
- jinns/nn/_abstract_pinn.py +0 -2
- jinns/nn/_hyperpinn.py +34 -23
- jinns/nn/_mlp.py +5 -4
- jinns/nn/_pinn.py +1 -16
- jinns/nn/_ppinn.py +5 -16
- jinns/nn/_save_load.py +11 -4
- jinns/nn/_spinn.py +1 -16
- jinns/nn/_spinn_mlp.py +5 -5
- jinns/nn/_utils.py +33 -38
- jinns/parameters/__init__.py +3 -1
- jinns/parameters/_derivative_keys.py +99 -41
- jinns/parameters/_params.py +58 -25
- jinns/solver/_solve.py +14 -8
- jinns/utils/_DictToModuleMeta.py +66 -0
- jinns/utils/_ItemizableModule.py +19 -0
- jinns/utils/__init__.py +2 -1
- jinns/utils/_types.py +25 -15
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
- jinns-1.6.0.dist-info/RECORD +57 -0
- jinns-1.5.0.dist-info/RECORD +0 -55
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
jinns/nn/_save_load.py
CHANGED
@@ -3,6 +3,7 @@ Implements save and load functions
 """
 
 from typing import Callable, Literal
+from dataclasses import fields
 import pickle
 import jax
 import equinox as eqx
@@ -130,8 +131,14 @@ def save_pinn(
     u = eqx.tree_at(lambda m: m.init_params, u, params)
     eqx.tree_serialise_leaves(filename + "-module.eqx", u)
 
+    # The class EqParams is malformed for pickling, hence we pickle it under
+    # its dictionary form
+    eq_params_as_dict = {
+        k.name: getattr(params.eq_params, k.name) for k in fields(params.eq_params)
+    }
+
     with open(filename + "-eq_params.pkl", "wb") as f:
-        pickle.dump(
+        pickle.dump(eq_params_as_dict, f)
 
     kwargs_creation = kwargs_creation.copy()  # avoid side-effect that would be
     # very probably harmless anyway
@@ -187,9 +194,9 @@ def load_pinn(
     try:
         with open(filename + "-eq_params.pkl", "rb") as f:
             eq_params_reloaded = pickle.load(f)
-    except FileNotFoundError:
+    except FileNotFoundError as e:
+        raise e
+
     kwargs_reloaded["eqx_list"] = string_to_function(kwargs_reloaded["eqx_list"])
     if type_ == "pinn_mlp":
         # next line creates a shallow model, the jax arrays are just shapes and
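For reference, a minimal sketch of the dict round-trip introduced above: an eqx.Module is a dataclass, so `dataclasses.fields` gives its field names and the module can be pickled under dictionary form. `MyEqParams` and its field names below are illustrative stand-ins, not jinns classes.

    # Sketch of pickling an eqx.Module under its dictionary form.
    from dataclasses import fields
    import pickle
    import equinox as eqx
    import jax.numpy as jnp

    class MyEqParams(eqx.Module):  # illustrative, not jinns' EqParams
        alpha: jnp.ndarray
        beta: jnp.ndarray

    eq_params = MyEqParams(alpha=jnp.array(1.0), beta=jnp.array(2.0))

    # pickle the module's fields as a plain dict
    eq_params_as_dict = {f.name: getattr(eq_params, f.name) for f in fields(eq_params)}
    with open("eq_params.pkl", "wb") as f:
        pickle.dump(eq_params_as_dict, f)

    # reload: the dict can later be handed back to a constructor that accepts dicts
    with open("eq_params.pkl", "rb") as f:
        reloaded = pickle.load(f)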
jinns/nn/_spinn.py
CHANGED
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Union, Callable, Any, Literal
+from typing import Union, Callable, Any, Literal
 from dataclasses import InitVar
 from jaxtyping import PyTree, Float, Array
 import jax
@@ -8,7 +8,6 @@ import equinox as eqx
 
 from jinns.parameters._params import Params
 from jinns.nn._abstract_pinn import AbstractPINN
-from jinns.nn._utils import _PyTree_to_Params
 
 
 class SPINN(AbstractPINN):
@@ -72,17 +71,6 @@
             eqx_spinn_network, self.filter_spec
         )
 
-    @overload
-    @_PyTree_to_Params
-    def __call__(
-        self,
-        inputs: Float[Array, " input_dim"],
-        params: PyTree,
-        *args,
-        **kwargs,
-    ) -> Float[Array, " output_dim"]: ...
-
-    @_PyTree_to_Params
     def __call__(
         self,
         t_x: Float[Array, " batch_size 1+dim"],
@@ -94,10 +82,7 @@
         Note that that thanks to the decorator, params can also directly be the
         PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
         """
-        # try:
         spinn = eqx.combine(params.nn_params, self.static)
-        # except (KeyError, AttributeError, TypeError) as e:
-        #     spinn = eqx.combine(params, self.static)
         v_model = jax.vmap(spinn)
         res = v_model(t_x)  # type: ignore
 
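With the `@overload`/`_PyTree_to_Params` path removed, `__call__` now always receives a full `Params` object and recombines its `nn_params` with the stored static structure before vmapping over the batch axis. A minimal sketch of that combine-then-vmap pattern with plain equinox follows; `net` and the partition filter are illustrative, not the jinns classes themselves.

    # Sketch of the eqx.partition / eqx.combine + vmap pattern used in __call__.
    import jax
    import jax.numpy as jnp
    import equinox as eqx

    net = eqx.nn.MLP(in_size=2, out_size=1, width_size=16, depth=2, key=jax.random.PRNGKey(0))
    nn_params, static = eqx.partition(net, eqx.is_inexact_array)

    def forward(nn_params, t_x):
        model = eqx.combine(nn_params, static)  # rebuild the callable module
        return jax.vmap(model)(t_x)             # vectorise over the batch axis

    batch = jnp.ones((8, 2))
    out = forward(nn_params, batch)  # shape (8, 1)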
jinns/nn/_spinn_mlp.py
CHANGED
@@ -8,7 +8,7 @@ from typing import Callable, Literal, Self, Union, Any, TypeGuard
 import jax
 import jax.numpy as jnp
 import equinox as eqx
-from jaxtyping import
+from jaxtyping import PRNGKeyArray, Array, Float, PyTree
 
 from jinns.nn._mlp import MLP
 from jinns.nn._spinn import SPINN
@@ -20,7 +20,7 @@ class SMLP(eqx.Module):
 
     Parameters
     ----------
-    key : InitVar[
+    key : InitVar[PRNGKeyArray]
         A jax random key for the layer initializations.
     d : int
         The number of dimensions to treat separately, including time `t` if
@@ -42,7 +42,7 @@
        )`.
     """
 
-    key: InitVar[
+    key: InitVar[PRNGKeyArray] = eqx.field(kw_only=True)
     eqx_list: InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] = (
         eqx.field(kw_only=True)
     )
@@ -74,7 +74,7 @@ class SPINN_MLP(SPINN):
     @classmethod
     def create(
         cls,
-        key:
+        key: PRNGKeyArray,
         d: int,
         r: int,
         eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...],
@@ -93,7 +93,7 @@
 
         Parameters
         ----------
-        key :
+        key : PRNGKeyArray
            A JAX random key that will be used to initialize the network parameters
        d : int
           The number of dimensions to treat separately.
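The changes in this file fill in previously truncated `PRNGKeyArray` annotations. A tiny sketch of how such an annotated key is produced and consumed (assuming jaxtyping is available; `init_weights` is an illustrative helper):

    # Sketch of the PRNGKeyArray annotation: a key from jax.random.key
    # (or the older jax.random.PRNGKey) satisfies this type.
    import jax
    from jaxtyping import PRNGKeyArray

    def init_weights(key: PRNGKeyArray, n: int):
        return jax.random.normal(key, (n,))

    w = init_weights(jax.random.key(0), 4)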
jinns/nn/_utils.py
CHANGED
@@ -1,38 +1,33 @@
-        *args
-
-        if isinstance(params, PyTree) and not isinstance(params, Params):
-            params = Params(nn_params=params, eq_params={})
-        return call_fun(self, inputs, params, *args, **kwargs)
-
-    return wrapper
+# P = ParamSpec("P")
+#
+#
+# def _PyTree_to_Params(
+#     call_fun: Callable[
+#         Concatenate[Any, Any, PyTree | Params[Array], P],
+#         Any,
+#     ],
+# ) -> Callable[
+#     Concatenate[Any, Any, PyTree | Params[Array], P],
+#     Any,
+# ]:
+#     """
+#     Decorator to be used around __call__ functions of PINNs, SPINNs, etc. It
+#     authorizes the __call__ with `params` being directly be the
+#     PyTree (SPINN, PINN_MLP, ...) that we get out of `eqx.combine`
+#
+#     This generic approach enables to cleanly handle type hints, up to the small
+#     effort required to understand type hints for decorators (ie ParamSpec).
+#     """
+#
+#     def wrapper(
+#         self: Any,
+#         inputs: Any,
+#         params: PyTree | Params[Array],
+#         *args: P.args,
+#         **kwargs: P.kwargs,
+#     ):
+#         if isinstance(params, PyTree) and not isinstance(params, Params):
+#             params = Params(nn_params=params, eq_params={})
+#         return call_fun(self, inputs, params, *args, **kwargs)
+#
+#     return wrapper
jinns/parameters/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-from ._params import Params
+from ._params import EqParams, Params, update_eq_params
 from ._derivative_keys import (
     DerivativeKeysODE,
     DerivativeKeysPDEStatio,
@@ -6,8 +6,10 @@ from ._derivative_keys import (
 )
 
 __all__ = [
+    "EqParams",
     "Params",
     "DerivativeKeysODE",
     "DerivativeKeysPDEStatio",
     "DerivativeKeysPDENonStatio",
+    "update_eq_params",
 ]
jinns/parameters/_derivative_keys.py
CHANGED

@@ -19,13 +19,10 @@ def _get_masked_parameters(
     """
     # start with a params object with True everywhere. We will update to False
     # for parameters wrt which we do want not to differentiate the loss
-    diff_params =
-        lambda
-        and not isinstance(x, Params),  # do not travers nn_params, more
-        # granularity could be imagined here, in the future
-    )
+    diff_params = Params(
+        nn_params=True, eq_params=jax.tree.map(lambda _: True, params.eq_params)
+    )  # do not travers nn_params, more
+    # granularity could be imagined here, in the future
     if derivative_mask_str == "both":
         return diff_params
     if derivative_mask_str == "eq_params":
@@ -60,7 +57,7 @@ class DerivativeKeysODE(eqx.Module):
 
     1. For unspecified loss term, the default is to differentiate with
        respect to `"nn_params"` only.
-    2. No granularity inside `Params.nn_params` is currently supported.
+    2. No granularity inside `Params.nn_params` is currently supported. An easy way to do freeze part of a custom PINN module is to use `jax.lax.stop_gradient` as explained [here](https://docs.kidger.site/equinox/faq/#how-to-mark-arrays-as-non-trainable-like-pytorchs-buffers).
     3. Note that the main Params object of the problem is mandatory if initialization via `from_str()`.
 
     A typical specification is of the form:
@@ -95,38 +92,52 @@ class DerivativeKeysODE(eqx.Module):
     infer the content of `Params.eq_params`.
     """
 
-    dyn_loss: Params[bool]
-    observations: Params[bool]
-    initial_condition: Params[bool]
+    dyn_loss: Params[bool]
+    observations: Params[bool]
+    initial_condition: Params[bool]
 
-    params: InitVar[Params[Array] | None]
+    params: InitVar[Params[Array] | None]
 
-    def
+    def __init__(
+        self,
+        *,
+        dyn_loss: Params[bool] | None = None,
+        observations: Params[bool] | None = None,
+        initial_condition: Params[bool] | None = None,
+        params: Params[Array] | None = None,
+    ):
+        super().__init__()
         if params is None and (
-            or self.observations is None
-            or self.initial_condition is None
+            dyn_loss is None or observations is None or initial_condition is None
         ):
             raise ValueError(
                 "params cannot be None since at least one loss "
                 "term has an undefined derivative key Params PyTree"
             )
-        if
+        if dyn_loss is None:
            if params is None:
                raise ValueError("self.dyn_loss is None, hence params should be passed")
            self.dyn_loss = _get_masked_parameters("nn_params", params)
+        else:
+            self.dyn_loss = dyn_loss
+
+        if observations is None:
            if params is None:
                raise ValueError(
                    "self.observations is None, hence params should be passed"
                )
            self.observations = _get_masked_parameters("nn_params", params)
+        else:
+            self.observations = observations
+
+        if initial_condition is None:
            if params is None:
                raise ValueError(
                    "self.initial_condition is None, hence params should be passed"
                )
            self.initial_condition = _get_masked_parameters("nn_params", params)
+        else:
+            self.initial_condition = initial_condition
 
     @classmethod
     def from_str(
@@ -216,36 +227,56 @@ class DerivativeKeysPDEStatio(eqx.Module):
     content of `Params.eq_params`.
     """
 
-    dyn_loss: Params[bool]
-    observations: Params[bool]
-    boundary_loss: Params[bool]
-    norm_loss: Params[bool]
+    dyn_loss: Params[bool] = eqx.field(kw_only=True, default=None)
+    observations: Params[bool] = eqx.field(kw_only=True, default=None)
+    boundary_loss: Params[bool] = eqx.field(kw_only=True, default=None)
+    norm_loss: Params[bool] = eqx.field(kw_only=True, default=None)
 
     params: InitVar[Params[Array] | None] = eqx.field(kw_only=True, default=None)
 
-    def
+    def __init__(
+        self,
+        *,
+        dyn_loss: Params[bool] | None = None,
+        observations: Params[bool] | None = None,
+        boundary_loss: Params[bool] | None = None,
+        norm_loss: Params[bool] | None = None,
+        params: Params[Array] | None = None,
+    ):
+        super().__init__()
+        if dyn_loss is None:
            if params is None:
                raise ValueError("self.dyn_loss is None, hence params should be passed")
            self.dyn_loss = _get_masked_parameters("nn_params", params)
+        else:
+            self.dyn_loss = dyn_loss
+
+        if observations is None:
            if params is None:
                raise ValueError(
                    "self.observations is None, hence params should be passed"
                )
            self.observations = _get_masked_parameters("nn_params", params)
+        else:
+            self.observations = observations
+
+        if boundary_loss is None:
            if params is None:
                raise ValueError(
                    "self.boundary_loss is None, hence params should be passed"
                )
            self.boundary_loss = _get_masked_parameters("nn_params", params)
+        else:
+            self.boundary_loss = boundary_loss
+
+        if norm_loss is None:
            if params is None:
                raise ValueError(
                    "self.norm_loss is None, hence params should be passed"
                )
            self.norm_loss = _get_masked_parameters("nn_params", params)
+        else:
+            self.norm_loss = norm_loss
 
     @classmethod
     def from_str(
@@ -344,16 +375,33 @@ class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
     content of `Params.eq_params`.
     """
 
-    initial_condition: Params[bool]
-
-    def
+    initial_condition: Params[bool] = eqx.field(kw_only=True, default=None)
+
+    def __init__(
+        self,
+        *,
+        dyn_loss: Params[bool] | None = None,
+        observations: Params[bool] | None = None,
+        boundary_loss: Params[bool] | None = None,
+        norm_loss: Params[bool] | None = None,
+        initial_condition: Params[bool] | None = None,
+        params: Params[Array] | None = None,
+    ):
+        super().__init__(
+            dyn_loss=dyn_loss,
+            observations=observations,
+            boundary_loss=boundary_loss,
+            norm_loss=norm_loss,
+            params=params,
+        )
+        if initial_condition is None:
            if params is None:
                raise ValueError(
                    "self.initial_condition is None, hence params should be passed"
                )
            self.initial_condition = _get_masked_parameters("nn_params", params)
+        else:
+            self.initial_condition = initial_condition
 
     @classmethod
     def from_str(
@@ -432,7 +480,9 @@
         )
 
 
-def _set_derivatives(
+def _set_derivatives(
+    params: Params[Array], derivative_keys: Params[bool]
+) -> Params[Array]:
     """
     We construct an eqx.Module with the fields of derivative_keys, each field
     has a copy of the params with appropriate derivatives set
@@ -448,13 +498,21 @@ def _set_derivatives(params, derivative_keys):
         `Params(nn_params=True | False, eq_params={"alpha":True | False,
         "beta":True | False})`.
         """
+
+        return Params(
+            nn_params=jax.lax.cond(
+                derivative_mask.nn_params,
+                lambda p: p,
+                jax.lax.stop_gradient,
+                params_.nn_params,
+            ),
+            eq_params=jax.tree.map(
+                lambda p, d: jax.lax.cond(d, lambda p: p, jax.lax.stop_gradient, p),
+                params_.eq_params,
+                derivative_mask.eq_params,
+            ),
         )
+        # NOTE that currently we do not travers nn_params, more
+        # granularity could be imagined here, in the future
 
     return _set_derivatives_(params, derivative_keys)
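The new `_set_derivatives` body gates gradients leaf by leaf: a `True` mask keeps the leaf as-is, a `False` mask routes it through `jax.lax.stop_gradient`. A standalone sketch of the same effect on a toy loss, using a Python conditional instead of `jax.lax.cond` since the mask is static here (the parameter names and the loss are illustrative):

    # Sketch of per-leaf gradient gating with stop_gradient.
    import jax
    import jax.numpy as jnp

    params = {"alpha": jnp.array(2.0), "beta": jnp.array(3.0)}
    mask = {"alpha": True, "beta": False}  # True = differentiate, False = freeze

    def gate(params, mask):
        # keep the leaf when its mask is True, otherwise cut the gradient
        return jax.tree.map(
            lambda p, m: p if m else jax.lax.stop_gradient(p), params, mask
        )

    def loss(params):
        p = gate(params, mask)
        return p["alpha"] ** 2 + p["beta"] ** 2

    grads = jax.grad(loss)(params)
    print(grads)  # {'alpha': 4.0, 'beta': 0.0}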
jinns/parameters/_params.py
CHANGED
@@ -2,14 +2,34 @@
 Formalize the data structure for the parameters
 """
 
+from dataclasses import fields
 from typing import Generic, TypeVar
-import jax
 import equinox as eqx
-from jaxtyping import Array, PyTree
+from jaxtyping import Array, PyTree
+
+from jinns.utils._DictToModuleMeta import DictToModuleMeta
 
 T = TypeVar("T")  # the generic type for what is in the Params PyTree because we
 # have possibly Params of Arrays, boolean, ...
 
+### NOTE
+### We are taking derivatives with respect to Params eqx.Modules.
+### This has been shown to behave weirdly if some fields of eqx.Modules have
+### been set as `field(init=False)`, we then should never create such fields in
+### jinns' Params modules.
+### We currently have silenced the warning related to this (see jinns.__init__
+### see https://github.com/patrick-kidger/equinox/pull/1043/commits/f88e62ab809140334c2f987ed13eff0d80b8be13
+
+
+class EqParams(metaclass=DictToModuleMeta):
+    """
+    Note that this is exposed to the user for the particular case where the
+    user, during its work, wants to change the equation parameters. In this
+    case, the user must import EqParams and call `EqParams.clear()`
+    """
+
+    pass
+
 
 class Params(eqx.Module, Generic[T]):
     """
@@ -20,37 +40,47 @@ class Params(eqx.Module, Generic[T]):
     nn_params : PyTree[T]
         A PyTree of the non-static part of the PINN eqx.Module, i.e., the
         parameters of the PINN
-    eq_params :
-        A
-        values are their corresponding
+    eq_params : PyTree[T]
+        A PyTree of the equation parameters. For retrocompatibility it us
+        provided as a dictionary of the equation parameters where keys are the parameter names, and values are their corresponding values. Internally,
+        it will be transformed to a custom instance of `EqParams`.
     """
 
-    nn_params: PyTree[T]
-    eq_params:
+    nn_params: PyTree[T]
+    eq_params: PyTree[T]
+
+    def __init__(
+        self,
+        nn_params: PyTree[T] | None = None,
+        eq_params: dict[str, T] | None = None,
+    ):
+        self.nn_params = nn_params
+        if isinstance(eq_params, dict):
+            self.eq_params = EqParams(eq_params, "EqParams")
+        else:
+            self.eq_params = eq_params
 
 
-def
+def update_eq_params(
     params: Params[Array],
-) -> Params:
+    eq_param_batch: PyTree[Array] | None,
+) -> Params[Array]:
     """
     Update params.eq_params with a batch of eq_params for given key(s)
     """
 
-    param_batch_dict_ = param_batch_dict | {
-        k: None for k in set(params.eq_params.keys()) - set(param_batch_dict.keys())
-    }
+    if eq_param_batch is None:
+        return params
 
+    param_names_to_update = tuple(f.name for f in fields(eq_param_batch))
     params = eqx.tree_at(
         lambda p: p.eq_params,
         params,
-        lambda
+        eqx.tree_at(
+            lambda pt: tuple(getattr(pt, f) for f in param_names_to_update),
             params.eq_params,
+            tuple(getattr(eq_param_batch, f) for f in param_names_to_update),
+            is_leaf=lambda x: x is None or eqx.is_inexact_array(x),
         ),
     )
 
@@ -58,7 +88,7 @@ def _update_eq_params_dict(
 
 
 def _get_vmap_in_axes_params(
+    eq_param_batch: eqx.Module | None, params: Params[Array]
 ) -> tuple[Params[int | None] | None]:
     """
     Return the input vmap axes when there is batch(es) of parameters to vmap
@@ -69,19 +99,22 @@ def _get_vmap_in_axes_params(
     Note that we return a Params PyTree with an integer to designate the
     vmapped axis or None if there is not
     """
-    if
+    if eq_param_batch is None:
         return (None,)
     # We use pytree indexing of vmapped axes and vmap on axis
     # 0 of the eq_parameters for which we have a batch
     # this is for a fine-grained vmaping
     # scheme over the params
+    param_names_to_vmap = tuple(f.name for f in fields(eq_param_batch))
+    vmap_axes_dict = {
+        k.name: (0 if k.name in param_names_to_vmap else None)
+        for k in fields(params.eq_params)
+    }
+    eq_param_vmap_axes = type(params.eq_params)(**vmap_axes_dict)
     vmap_in_axes_params = (
         Params(
             nn_params=None,
-            eq_params=
-                k: (0 if k in eq_params_batch_dict.keys() else None)
-                for k in params.eq_params.keys()
-            },
+            eq_params=eq_param_vmap_axes,
         ),
     )
     return vmap_in_axes_params
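Independently of the `EqParams`/`DictToModuleMeta` machinery, the idea behind `_get_vmap_in_axes_params` is to build a PyTree of `in_axes` that mirrors the parameter structure, with `0` only for the parameters that come in batches. A minimal sketch with a plain dict standing in for the generated `EqParams` module (`rhs` and the parameter names are illustrative):

    # Sketch of vmapping over a batch of one equation parameter only,
    # using a PyTree of in_axes that mirrors the parameter structure.
    import jax
    import jax.numpy as jnp

    eq_params = {"alpha": jnp.linspace(0.1, 1.0, 8), "beta": jnp.array(3.0)}
    in_axes = ({"alpha": 0, "beta": None},)  # batch axis only for alpha

    def rhs(eq_params):
        return eq_params["alpha"] * eq_params["beta"]

    batched_rhs = jax.vmap(rhs, in_axes=in_axes)
    print(batched_rhs(eq_params).shape)  # (8,)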
jinns/solver/_solve.py
CHANGED
@@ -14,7 +14,7 @@ import optax
 import jax
 from jax import jit
 import jax.numpy as jnp
-from jaxtyping import Float, Array, PyTree,
+from jaxtyping import Float, Array, PyTree, PRNGKeyArray
 import equinox as eqx
 from jinns.solver._rar import init_rar, trigger_rar
 from jinns.utils._utils import _check_nan_in_pytree
@@ -47,7 +47,7 @@ if TYPE_CHECKING:
         LossContainer,
         StoredObjectContainer,
         Float[Array, " n_iter"] | None,
+        PRNGKeyArray | None,
     ]
 
 
@@ -66,7 +66,7 @@ def solve(
     obs_batch_sharding: jax.sharding.Sharding | None = None,
     verbose: bool = True,
     ahead_of_time: bool = True,
-    key:
+    key: PRNGKeyArray | None = None,
 ) -> tuple[
     Params[Array],
     Float[Array, " n_iter"],
@@ -179,6 +179,7 @@ def solve(
     best_val_params
         The best parameters according to the validation criterion
     """
+    initialization_time = time.time()
     if n_iter < 1:
         raise ValueError("Cannot run jinns.solve for n_iter<1")
 
@@ -225,11 +226,6 @@
     # get_batch with device_put, the latter is not jittable
     get_batch = _get_get_batch(obs_batch_sharding)
 
-    # initialize the dict for stored parameter values
-    # we need to get a loss_term to init stuff
-    batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
-    _, loss_terms = loss(init_params, batch_ini)
-
     # initialize parameter tracking
     if tracked_params is None:
         tracked_params = jax.tree.map(lambda p: None, init_params)
@@ -247,6 +243,13 @@
         # being a complex data structure
     )
 
+    # initialize the dict for stored parameter values
+    # we need to get a loss_term to init stuff
+    # NOTE: we use jax.eval_shape to avoid FLOPS since we only need the tree
+    # structure
+    batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
+    _, loss_terms = jax.eval_shape(loss, init_params, batch_ini)
+
     # initialize the PyTree for stored loss values
     stored_loss_terms = jax.tree_util.tree_map(
         lambda _: jnp.zeros((n_iter)), loss_terms
@@ -475,6 +478,9 @@
         key,
     )
 
+    if verbose:
+        print("Initialization time:", time.time() - initialization_time)
+
     # Main optimization loop. We use the LAX while loop (fully jitted) version
     # if no mixing devices. Otherwise we use the standard while loop. Here devices only
     # concern obs_batch, but it could lead to more complex scheme in the future
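A small sketch of the `jax.eval_shape` trick used in the new initialization path: the loss is traced for shapes and dtypes only, no FLOPs are spent, and the resulting structure is enough to build the zero-initialised storage PyTree (`loss` and its terms are illustrative, not jinns' actual loss):

    # Sketch of jax.eval_shape: trace a function without computing anything,
    # then reuse the returned tree structure to allocate storage arrays.
    import jax
    import jax.numpy as jnp

    def loss(params, batch):
        terms = {"dyn_loss": jnp.mean(params * batch), "observations": jnp.sum(batch)}
        return sum(terms.values()), terms

    params = jnp.ones((3,))
    batch = jnp.ones((3,))
    _, loss_terms_shapes = jax.eval_shape(loss, params, batch)  # ShapeDtypeStructs only

    n_iter = 100
    stored_loss_terms = jax.tree_util.tree_map(
        lambda _: jnp.zeros((n_iter,)), loss_terms_shapes
    )
    print(jax.tree_util.tree_map(lambda a: a.shape, stored_loss_terms))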