jinns 1.5.1__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_AbstractDataGenerator.py +1 -1
- jinns/data/_Batchs.py +47 -13
- jinns/data/_CubicMeshPDENonStatio.py +55 -34
- jinns/data/_CubicMeshPDEStatio.py +63 -35
- jinns/data/_DataGeneratorODE.py +48 -22
- jinns/data/_DataGeneratorObservations.py +86 -32
- jinns/data/_DataGeneratorParameter.py +152 -101
- jinns/data/__init__.py +2 -1
- jinns/data/_utils.py +22 -10
- jinns/loss/_DynamicLoss.py +21 -20
- jinns/loss/_DynamicLossAbstract.py +51 -36
- jinns/loss/_LossODE.py +139 -184
- jinns/loss/_LossPDE.py +440 -358
- jinns/loss/_abstract_loss.py +60 -25
- jinns/loss/_loss_components.py +4 -25
- jinns/loss/_loss_weight_updates.py +6 -7
- jinns/loss/_loss_weights.py +34 -35
- jinns/nn/_abstract_pinn.py +0 -2
- jinns/nn/_hyperpinn.py +34 -23
- jinns/nn/_mlp.py +5 -4
- jinns/nn/_pinn.py +1 -16
- jinns/nn/_ppinn.py +5 -16
- jinns/nn/_save_load.py +11 -4
- jinns/nn/_spinn.py +1 -16
- jinns/nn/_spinn_mlp.py +5 -5
- jinns/nn/_utils.py +33 -38
- jinns/parameters/__init__.py +3 -1
- jinns/parameters/_derivative_keys.py +99 -41
- jinns/parameters/_params.py +50 -25
- jinns/solver/_solve.py +3 -3
- jinns/utils/_DictToModuleMeta.py +66 -0
- jinns/utils/_ItemizableModule.py +19 -0
- jinns/utils/__init__.py +2 -1
- jinns/utils/_types.py +25 -15
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/METADATA +2 -2
- jinns-1.6.1.dist-info/RECORD +57 -0
- jinns-1.5.1.dist-info/RECORD +0 -55
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/WHEEL +0 -0
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/top_level.txt +0 -0
jinns/nn/_spinn.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
from typing import Union, Callable, Any, Literal
|
|
2
|
+
from typing import Union, Callable, Any, Literal
|
|
3
3
|
from dataclasses import InitVar
|
|
4
4
|
from jaxtyping import PyTree, Float, Array
|
|
5
5
|
import jax
|
|
@@ -8,7 +8,6 @@ import equinox as eqx
|
|
|
8
8
|
|
|
9
9
|
from jinns.parameters._params import Params
|
|
10
10
|
from jinns.nn._abstract_pinn import AbstractPINN
|
|
11
|
-
from jinns.nn._utils import _PyTree_to_Params
|
|
12
11
|
|
|
13
12
|
|
|
14
13
|
class SPINN(AbstractPINN):
|
|
@@ -72,17 +71,6 @@ class SPINN(AbstractPINN):
|
|
|
72
71
|
eqx_spinn_network, self.filter_spec
|
|
73
72
|
)
|
|
74
73
|
|
|
75
|
-
@overload
|
|
76
|
-
@_PyTree_to_Params
|
|
77
|
-
def __call__(
|
|
78
|
-
self,
|
|
79
|
-
inputs: Float[Array, " input_dim"],
|
|
80
|
-
params: PyTree,
|
|
81
|
-
*args,
|
|
82
|
-
**kwargs,
|
|
83
|
-
) -> Float[Array, " output_dim"]: ...
|
|
84
|
-
|
|
85
|
-
@_PyTree_to_Params
|
|
86
74
|
def __call__(
|
|
87
75
|
self,
|
|
88
76
|
t_x: Float[Array, " batch_size 1+dim"],
|
|
@@ -94,10 +82,7 @@ class SPINN(AbstractPINN):
|
|
|
94
82
|
Note that that thanks to the decorator, params can also directly be the
|
|
95
83
|
PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
|
|
96
84
|
"""
|
|
97
|
-
# try:
|
|
98
85
|
spinn = eqx.combine(params.nn_params, self.static)
|
|
99
|
-
# except (KeyError, AttributeError, TypeError) as e:
|
|
100
|
-
# spinn = eqx.combine(params, self.static)
|
|
101
86
|
v_model = jax.vmap(spinn)
|
|
102
87
|
res = v_model(t_x) # type: ignore
|
|
103
88
|
|
jinns/nn/_spinn_mlp.py
CHANGED
|
@@ -8,7 +8,7 @@ from typing import Callable, Literal, Self, Union, Any, TypeGuard
|
|
|
8
8
|
import jax
|
|
9
9
|
import jax.numpy as jnp
|
|
10
10
|
import equinox as eqx
|
|
11
|
-
from jaxtyping import
|
|
11
|
+
from jaxtyping import PRNGKeyArray, Array, Float, PyTree
|
|
12
12
|
|
|
13
13
|
from jinns.nn._mlp import MLP
|
|
14
14
|
from jinns.nn._spinn import SPINN
|
|
@@ -20,7 +20,7 @@ class SMLP(eqx.Module):
|
|
|
20
20
|
|
|
21
21
|
Parameters
|
|
22
22
|
----------
|
|
23
|
-
key : InitVar[
|
|
23
|
+
key : InitVar[PRNGKeyArray]
|
|
24
24
|
A jax random key for the layer initializations.
|
|
25
25
|
d : int
|
|
26
26
|
The number of dimensions to treat separately, including time `t` if
|
|
@@ -42,7 +42,7 @@ class SMLP(eqx.Module):
|
|
|
42
42
|
)`.
|
|
43
43
|
"""
|
|
44
44
|
|
|
45
|
-
key: InitVar[
|
|
45
|
+
key: InitVar[PRNGKeyArray] = eqx.field(kw_only=True)
|
|
46
46
|
eqx_list: InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] = (
|
|
47
47
|
eqx.field(kw_only=True)
|
|
48
48
|
)
|
|
@@ -74,7 +74,7 @@ class SPINN_MLP(SPINN):
|
|
|
74
74
|
@classmethod
|
|
75
75
|
def create(
|
|
76
76
|
cls,
|
|
77
|
-
key:
|
|
77
|
+
key: PRNGKeyArray,
|
|
78
78
|
d: int,
|
|
79
79
|
r: int,
|
|
80
80
|
eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...],
|
|
@@ -93,7 +93,7 @@ class SPINN_MLP(SPINN):
|
|
|
93
93
|
|
|
94
94
|
Parameters
|
|
95
95
|
----------
|
|
96
|
-
key :
|
|
96
|
+
key : PRNGKeyArray
|
|
97
97
|
A JAX random key that will be used to initialize the network parameters
|
|
98
98
|
d : int
|
|
99
99
|
The number of dimensions to treat separately.
|
jinns/nn/_utils.py
CHANGED
|
@@ -1,38 +1,33 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
*args
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
if isinstance(params, PyTree) and not isinstance(params, Params):
|
|
35
|
-
params = Params(nn_params=params, eq_params={})
|
|
36
|
-
return call_fun(self, inputs, params, *args, **kwargs)
|
|
37
|
-
|
|
38
|
-
return wrapper
|
|
1
|
+
# P = ParamSpec("P")
|
|
2
|
+
#
|
|
3
|
+
#
|
|
4
|
+
# def _PyTree_to_Params(
|
|
5
|
+
# call_fun: Callable[
|
|
6
|
+
# Concatenate[Any, Any, PyTree | Params[Array], P],
|
|
7
|
+
# Any,
|
|
8
|
+
# ],
|
|
9
|
+
# ) -> Callable[
|
|
10
|
+
# Concatenate[Any, Any, PyTree | Params[Array], P],
|
|
11
|
+
# Any,
|
|
12
|
+
# ]:
|
|
13
|
+
# """
|
|
14
|
+
# Decorator to be used around __call__ functions of PINNs, SPINNs, etc. It
|
|
15
|
+
# authorizes the __call__ with `params` being directly be the
|
|
16
|
+
# PyTree (SPINN, PINN_MLP, ...) that we get out of `eqx.combine`
|
|
17
|
+
#
|
|
18
|
+
# This generic approach enables to cleanly handle type hints, up to the small
|
|
19
|
+
# effort required to understand type hints for decorators (ie ParamSpec).
|
|
20
|
+
# """
|
|
21
|
+
#
|
|
22
|
+
# def wrapper(
|
|
23
|
+
# self: Any,
|
|
24
|
+
# inputs: Any,
|
|
25
|
+
# params: PyTree | Params[Array],
|
|
26
|
+
# *args: P.args,
|
|
27
|
+
# **kwargs: P.kwargs,
|
|
28
|
+
# ):
|
|
29
|
+
# if isinstance(params, PyTree) and not isinstance(params, Params):
|
|
30
|
+
# params = Params(nn_params=params, eq_params={})
|
|
31
|
+
# return call_fun(self, inputs, params, *args, **kwargs)
|
|
32
|
+
#
|
|
33
|
+
# return wrapper
|
jinns/parameters/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from ._params import Params
|
|
1
|
+
from ._params import EqParams, Params, update_eq_params
|
|
2
2
|
from ._derivative_keys import (
|
|
3
3
|
DerivativeKeysODE,
|
|
4
4
|
DerivativeKeysPDEStatio,
|
|
@@ -6,8 +6,10 @@ from ._derivative_keys import (
|
|
|
6
6
|
)
|
|
7
7
|
|
|
8
8
|
__all__ = [
|
|
9
|
+
"EqParams",
|
|
9
10
|
"Params",
|
|
10
11
|
"DerivativeKeysODE",
|
|
11
12
|
"DerivativeKeysPDEStatio",
|
|
12
13
|
"DerivativeKeysPDENonStatio",
|
|
14
|
+
"update_eq_params",
|
|
13
15
|
]
|
|
@@ -19,13 +19,10 @@ def _get_masked_parameters(
|
|
|
19
19
|
"""
|
|
20
20
|
# start with a params object with True everywhere. We will update to False
|
|
21
21
|
# for parameters wrt which we do not want to differentiate the loss
|
|
22
|
-
diff_params =
|
|
23
|
-
lambda
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
and not isinstance(x, Params), # do not travers nn_params, more
|
|
27
|
-
# granularity could be imagined here, in the future
|
|
28
|
-
)
|
|
22
|
+
diff_params = Params(
|
|
23
|
+
nn_params=True, eq_params=jax.tree.map(lambda _: True, params.eq_params)
|
|
24
|
+
) # do not traverse nn_params, more
|
|
25
|
+
# granularity could be imagined here, in the future
|
|
29
26
|
if derivative_mask_str == "both":
|
|
30
27
|
return diff_params
|
|
31
28
|
if derivative_mask_str == "eq_params":
|
|
@@ -60,7 +57,7 @@ class DerivativeKeysODE(eqx.Module):
|
|
|
60
57
|
|
|
61
58
|
1. For unspecified loss term, the default is to differentiate with
|
|
62
59
|
respect to `"nn_params"` only.
|
|
63
|
-
2. No granularity inside `Params.nn_params` is currently supported.
|
|
60
|
+
2. No granularity inside `Params.nn_params` is currently supported. An easy way to freeze part of a custom PINN module is to use `jax.lax.stop_gradient` as explained [here](https://docs.kidger.site/equinox/faq/#how-to-mark-arrays-as-non-trainable-like-pytorchs-buffers).
|
|
64
61
|
3. Note that the main Params object of the problem is mandatory if initialization via `from_str()`.
|
|
65
62
|
|
|
66
63
|
A typical specification is of the form:
|
|
@@ -95,38 +92,52 @@ class DerivativeKeysODE(eqx.Module):
|
|
|
95
92
|
infer the content of `Params.eq_params`.
|
|
96
93
|
"""
|
|
97
94
|
|
|
98
|
-
dyn_loss: Params[bool]
|
|
99
|
-
observations: Params[bool]
|
|
100
|
-
initial_condition: Params[bool]
|
|
95
|
+
dyn_loss: Params[bool]
|
|
96
|
+
observations: Params[bool]
|
|
97
|
+
initial_condition: Params[bool]
|
|
101
98
|
|
|
102
|
-
params: InitVar[Params[Array] | None]
|
|
99
|
+
params: InitVar[Params[Array] | None]
|
|
103
100
|
|
|
104
|
-
def
|
|
101
|
+
def __init__(
|
|
102
|
+
self,
|
|
103
|
+
*,
|
|
104
|
+
dyn_loss: Params[bool] | None = None,
|
|
105
|
+
observations: Params[bool] | None = None,
|
|
106
|
+
initial_condition: Params[bool] | None = None,
|
|
107
|
+
params: Params[Array] | None = None,
|
|
108
|
+
):
|
|
109
|
+
super().__init__()
|
|
105
110
|
if params is None and (
|
|
106
|
-
|
|
107
|
-
or self.observations is None
|
|
108
|
-
or self.initial_condition is None
|
|
111
|
+
dyn_loss is None or observations is None or initial_condition is None
|
|
109
112
|
):
|
|
110
113
|
raise ValueError(
|
|
111
114
|
"params cannot be None since at least one loss "
|
|
112
115
|
"term has an undefined derivative key Params PyTree"
|
|
113
116
|
)
|
|
114
|
-
if
|
|
117
|
+
if dyn_loss is None:
|
|
115
118
|
if params is None:
|
|
116
119
|
raise ValueError("self.dyn_loss is None, hence params should be passed")
|
|
117
120
|
self.dyn_loss = _get_masked_parameters("nn_params", params)
|
|
118
|
-
|
|
121
|
+
else:
|
|
122
|
+
self.dyn_loss = dyn_loss
|
|
123
|
+
|
|
124
|
+
if observations is None:
|
|
119
125
|
if params is None:
|
|
120
126
|
raise ValueError(
|
|
121
127
|
"self.observations is None, hence params should be passed"
|
|
122
128
|
)
|
|
123
129
|
self.observations = _get_masked_parameters("nn_params", params)
|
|
124
|
-
|
|
130
|
+
else:
|
|
131
|
+
self.observations = observations
|
|
132
|
+
|
|
133
|
+
if initial_condition is None:
|
|
125
134
|
if params is None:
|
|
126
135
|
raise ValueError(
|
|
127
136
|
"self.initial_condition is None, hence params should be passed"
|
|
128
137
|
)
|
|
129
138
|
self.initial_condition = _get_masked_parameters("nn_params", params)
|
|
139
|
+
else:
|
|
140
|
+
self.initial_condition = initial_condition
|
|
130
141
|
|
|
131
142
|
@classmethod
|
|
132
143
|
def from_str(
|
|
@@ -216,36 +227,56 @@ class DerivativeKeysPDEStatio(eqx.Module):
|
|
|
216
227
|
content of `Params.eq_params`.
|
|
217
228
|
"""
|
|
218
229
|
|
|
219
|
-
dyn_loss: Params[bool]
|
|
220
|
-
observations: Params[bool]
|
|
221
|
-
boundary_loss: Params[bool]
|
|
222
|
-
norm_loss: Params[bool]
|
|
230
|
+
dyn_loss: Params[bool] = eqx.field(kw_only=True, default=None)
|
|
231
|
+
observations: Params[bool] = eqx.field(kw_only=True, default=None)
|
|
232
|
+
boundary_loss: Params[bool] = eqx.field(kw_only=True, default=None)
|
|
233
|
+
norm_loss: Params[bool] = eqx.field(kw_only=True, default=None)
|
|
223
234
|
|
|
224
235
|
params: InitVar[Params[Array] | None] = eqx.field(kw_only=True, default=None)
|
|
225
236
|
|
|
226
|
-
def
|
|
227
|
-
|
|
237
|
+
def __init__(
|
|
238
|
+
self,
|
|
239
|
+
*,
|
|
240
|
+
dyn_loss: Params[bool] | None = None,
|
|
241
|
+
observations: Params[bool] | None = None,
|
|
242
|
+
boundary_loss: Params[bool] | None = None,
|
|
243
|
+
norm_loss: Params[bool] | None = None,
|
|
244
|
+
params: Params[Array] | None = None,
|
|
245
|
+
):
|
|
246
|
+
super().__init__()
|
|
247
|
+
if dyn_loss is None:
|
|
228
248
|
if params is None:
|
|
229
249
|
raise ValueError("self.dyn_loss is None, hence params should be passed")
|
|
230
250
|
self.dyn_loss = _get_masked_parameters("nn_params", params)
|
|
231
|
-
|
|
251
|
+
else:
|
|
252
|
+
self.dyn_loss = dyn_loss
|
|
253
|
+
|
|
254
|
+
if observations is None:
|
|
232
255
|
if params is None:
|
|
233
256
|
raise ValueError(
|
|
234
257
|
"self.observations is None, hence params should be passed"
|
|
235
258
|
)
|
|
236
259
|
self.observations = _get_masked_parameters("nn_params", params)
|
|
237
|
-
|
|
260
|
+
else:
|
|
261
|
+
self.observations = observations
|
|
262
|
+
|
|
263
|
+
if boundary_loss is None:
|
|
238
264
|
if params is None:
|
|
239
265
|
raise ValueError(
|
|
240
266
|
"self.boundary_loss is None, hence params should be passed"
|
|
241
267
|
)
|
|
242
268
|
self.boundary_loss = _get_masked_parameters("nn_params", params)
|
|
243
|
-
|
|
269
|
+
else:
|
|
270
|
+
self.boundary_loss = boundary_loss
|
|
271
|
+
|
|
272
|
+
if norm_loss is None:
|
|
244
273
|
if params is None:
|
|
245
274
|
raise ValueError(
|
|
246
275
|
"self.norm_loss is None, hence params should be passed"
|
|
247
276
|
)
|
|
248
277
|
self.norm_loss = _get_masked_parameters("nn_params", params)
|
|
278
|
+
else:
|
|
279
|
+
self.norm_loss = norm_loss
|
|
249
280
|
|
|
250
281
|
@classmethod
|
|
251
282
|
def from_str(
|
|
@@ -344,16 +375,33 @@ class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
|
|
|
344
375
|
content of `Params.eq_params`.
|
|
345
376
|
"""
|
|
346
377
|
|
|
347
|
-
initial_condition: Params[bool]
|
|
348
|
-
|
|
349
|
-
def
|
|
350
|
-
|
|
351
|
-
|
|
378
|
+
initial_condition: Params[bool] = eqx.field(kw_only=True, default=None)
|
|
379
|
+
|
|
380
|
+
def __init__(
|
|
381
|
+
self,
|
|
382
|
+
*,
|
|
383
|
+
dyn_loss: Params[bool] | None = None,
|
|
384
|
+
observations: Params[bool] | None = None,
|
|
385
|
+
boundary_loss: Params[bool] | None = None,
|
|
386
|
+
norm_loss: Params[bool] | None = None,
|
|
387
|
+
initial_condition: Params[bool] | None = None,
|
|
388
|
+
params: Params[Array] | None = None,
|
|
389
|
+
):
|
|
390
|
+
super().__init__(
|
|
391
|
+
dyn_loss=dyn_loss,
|
|
392
|
+
observations=observations,
|
|
393
|
+
boundary_loss=boundary_loss,
|
|
394
|
+
norm_loss=norm_loss,
|
|
395
|
+
params=params,
|
|
396
|
+
)
|
|
397
|
+
if initial_condition is None:
|
|
352
398
|
if params is None:
|
|
353
399
|
raise ValueError(
|
|
354
400
|
"self.initial_condition is None, hence params should be passed"
|
|
355
401
|
)
|
|
356
402
|
self.initial_condition = _get_masked_parameters("nn_params", params)
|
|
403
|
+
else:
|
|
404
|
+
self.initial_condition = initial_condition
|
|
357
405
|
|
|
358
406
|
@classmethod
|
|
359
407
|
def from_str(
|
|
@@ -432,7 +480,9 @@ class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
|
|
|
432
480
|
)
|
|
433
481
|
|
|
434
482
|
|
|
435
|
-
def _set_derivatives(
|
|
483
|
+
def _set_derivatives(
|
|
484
|
+
params: Params[Array], derivative_keys: Params[bool]
|
|
485
|
+
) -> Params[Array]:
|
|
436
486
|
"""
|
|
437
487
|
We construct an eqx.Module with the fields of derivative_keys, each field
|
|
438
488
|
has a copy of the params with appropriate derivatives set
|
|
@@ -448,13 +498,21 @@ def _set_derivatives(params, derivative_keys):
|
|
|
448
498
|
`Params(nn_params=True | False, eq_params={"alpha":True | False,
|
|
449
499
|
"beta":True | False})`.
|
|
450
500
|
"""
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
501
|
+
|
|
502
|
+
return Params(
|
|
503
|
+
nn_params=jax.lax.cond(
|
|
504
|
+
derivative_mask.nn_params,
|
|
505
|
+
lambda p: p,
|
|
506
|
+
jax.lax.stop_gradient,
|
|
507
|
+
params_.nn_params,
|
|
508
|
+
),
|
|
509
|
+
eq_params=jax.tree.map(
|
|
510
|
+
lambda p, d: jax.lax.cond(d, lambda p: p, jax.lax.stop_gradient, p),
|
|
511
|
+
params_.eq_params,
|
|
512
|
+
derivative_mask.eq_params,
|
|
513
|
+
),
|
|
458
514
|
)
|
|
515
|
+
# NOTE that currently we do not traverse nn_params, more
|
|
516
|
+
# granularity could be imagined here, in the future
|
|
459
517
|
|
|
460
518
|
return _set_derivatives_(params, derivative_keys)
|
jinns/parameters/_params.py
CHANGED
|
@@ -2,10 +2,12 @@
|
|
|
2
2
|
Formalize the data structure for the parameters
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
+
from dataclasses import fields
|
|
5
6
|
from typing import Generic, TypeVar
|
|
6
|
-
import jax
|
|
7
7
|
import equinox as eqx
|
|
8
|
-
from jaxtyping import Array, PyTree
|
|
8
|
+
from jaxtyping import Array, PyTree
|
|
9
|
+
|
|
10
|
+
from jinns.utils._DictToModuleMeta import DictToModuleMeta
|
|
9
11
|
|
|
10
12
|
T = TypeVar("T") # the generic type for what is in the Params PyTree because we
|
|
11
13
|
# have possibly Params of Arrays, boolean, ...
|
|
@@ -19,6 +21,16 @@ T = TypeVar("T") # the generic type for what is in the Params PyTree because we
|
|
|
19
21
|
### see https://github.com/patrick-kidger/equinox/pull/1043/commits/f88e62ab809140334c2f987ed13eff0d80b8be13
|
|
20
22
|
|
|
21
23
|
|
|
24
|
+
class EqParams(metaclass=DictToModuleMeta):
|
|
25
|
+
"""
|
|
26
|
+
Note that this is exposed to the user for the particular case where the
|
|
27
|
+
user, during its work, wants to change the equation parameters. In this
|
|
28
|
+
case, the user must import EqParams and call `EqParams.clear()`
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
pass
|
|
32
|
+
|
|
33
|
+
|
|
22
34
|
class Params(eqx.Module, Generic[T]):
|
|
23
35
|
"""
|
|
24
36
|
The equinox module for the parameters
|
|
@@ -28,37 +40,47 @@ class Params(eqx.Module, Generic[T]):
|
|
|
28
40
|
nn_params : PyTree[T]
|
|
29
41
|
A PyTree of the non-static part of the PINN eqx.Module, i.e., the
|
|
30
42
|
parameters of the PINN
|
|
31
|
-
eq_params :
|
|
32
|
-
A
|
|
33
|
-
values are their corresponding
|
|
43
|
+
eq_params : PyTree[T]
|
|
44
|
+
A PyTree of the equation parameters. For retrocompatibility it is
|
|
45
|
+
provided as a dictionary of the equation parameters where keys are the parameter names, and values are their corresponding values. Internally,
|
|
46
|
+
it will be transformed to a custom instance of `EqParams`.
|
|
34
47
|
"""
|
|
35
48
|
|
|
36
|
-
nn_params: PyTree[T]
|
|
37
|
-
eq_params:
|
|
49
|
+
nn_params: PyTree[T]
|
|
50
|
+
eq_params: PyTree[T]
|
|
51
|
+
|
|
52
|
+
def __init__(
|
|
53
|
+
self,
|
|
54
|
+
nn_params: PyTree[T] | None = None,
|
|
55
|
+
eq_params: dict[str, T] | None = None,
|
|
56
|
+
):
|
|
57
|
+
self.nn_params = nn_params
|
|
58
|
+
if isinstance(eq_params, dict):
|
|
59
|
+
self.eq_params = EqParams(eq_params, "EqParams")
|
|
60
|
+
else:
|
|
61
|
+
self.eq_params = eq_params
|
|
38
62
|
|
|
39
63
|
|
|
40
|
-
def
|
|
64
|
+
def update_eq_params(
|
|
41
65
|
params: Params[Array],
|
|
42
|
-
|
|
43
|
-
) -> Params:
|
|
66
|
+
eq_param_batch: PyTree[Array] | None,
|
|
67
|
+
) -> Params[Array]:
|
|
44
68
|
"""
|
|
45
69
|
Update params.eq_params with a batch of eq_params for given key(s)
|
|
46
70
|
"""
|
|
47
71
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
param_batch_dict_ = param_batch_dict | {
|
|
51
|
-
k: None for k in set(params.eq_params.keys()) - set(param_batch_dict.keys())
|
|
52
|
-
}
|
|
72
|
+
if eq_param_batch is None:
|
|
73
|
+
return params
|
|
53
74
|
|
|
54
|
-
|
|
75
|
+
param_names_to_update = tuple(f.name for f in fields(eq_param_batch))
|
|
55
76
|
params = eqx.tree_at(
|
|
56
77
|
lambda p: p.eq_params,
|
|
57
78
|
params,
|
|
58
|
-
|
|
59
|
-
lambda
|
|
79
|
+
eqx.tree_at(
|
|
80
|
+
lambda pt: tuple(getattr(pt, f) for f in param_names_to_update),
|
|
60
81
|
params.eq_params,
|
|
61
|
-
|
|
82
|
+
tuple(getattr(eq_param_batch, f) for f in param_names_to_update),
|
|
83
|
+
is_leaf=lambda x: x is None or eqx.is_inexact_array(x),
|
|
62
84
|
),
|
|
63
85
|
)
|
|
64
86
|
|
|
@@ -66,7 +88,7 @@ def _update_eq_params_dict(
|
|
|
66
88
|
|
|
67
89
|
|
|
68
90
|
def _get_vmap_in_axes_params(
|
|
69
|
-
|
|
91
|
+
eq_param_batch: eqx.Module | None, params: Params[Array]
|
|
70
92
|
) -> tuple[Params[int | None] | None]:
|
|
71
93
|
"""
|
|
72
94
|
Return the input vmap axes when there is batch(es) of parameters to vmap
|
|
@@ -77,19 +99,22 @@ def _get_vmap_in_axes_params(
|
|
|
77
99
|
Note that we return a Params PyTree with an integer to designate the
|
|
78
100
|
vmapped axis or None if there is not
|
|
79
101
|
"""
|
|
80
|
-
if
|
|
102
|
+
if eq_param_batch is None:
|
|
81
103
|
return (None,)
|
|
82
104
|
# We use pytree indexing of vmapped axes and vmap on axis
|
|
83
105
|
# 0 of the eq_parameters for which we have a batch
|
|
84
106
|
# this is for a fine-grained vmaping
|
|
85
107
|
# scheme over the params
|
|
108
|
+
param_names_to_vmap = tuple(f.name for f in fields(eq_param_batch))
|
|
109
|
+
vmap_axes_dict = {
|
|
110
|
+
k.name: (0 if k.name in param_names_to_vmap else None)
|
|
111
|
+
for k in fields(params.eq_params)
|
|
112
|
+
}
|
|
113
|
+
eq_param_vmap_axes = type(params.eq_params)(**vmap_axes_dict)
|
|
86
114
|
vmap_in_axes_params = (
|
|
87
115
|
Params(
|
|
88
116
|
nn_params=None,
|
|
89
|
-
eq_params=
|
|
90
|
-
k: (0 if k in eq_params_batch_dict.keys() else None)
|
|
91
|
-
for k in params.eq_params.keys()
|
|
92
|
-
},
|
|
117
|
+
eq_params=eq_param_vmap_axes,
|
|
93
118
|
),
|
|
94
119
|
)
|
|
95
120
|
return vmap_in_axes_params
|
jinns/solver/_solve.py
CHANGED
|
@@ -14,7 +14,7 @@ import optax
|
|
|
14
14
|
import jax
|
|
15
15
|
from jax import jit
|
|
16
16
|
import jax.numpy as jnp
|
|
17
|
-
from jaxtyping import Float, Array, PyTree,
|
|
17
|
+
from jaxtyping import Float, Array, PyTree, PRNGKeyArray
|
|
18
18
|
import equinox as eqx
|
|
19
19
|
from jinns.solver._rar import init_rar, trigger_rar
|
|
20
20
|
from jinns.utils._utils import _check_nan_in_pytree
|
|
@@ -47,7 +47,7 @@ if TYPE_CHECKING:
|
|
|
47
47
|
LossContainer,
|
|
48
48
|
StoredObjectContainer,
|
|
49
49
|
Float[Array, " n_iter"] | None,
|
|
50
|
-
|
|
50
|
+
PRNGKeyArray | None,
|
|
51
51
|
]
|
|
52
52
|
|
|
53
53
|
|
|
@@ -66,7 +66,7 @@ def solve(
|
|
|
66
66
|
obs_batch_sharding: jax.sharding.Sharding | None = None,
|
|
67
67
|
verbose: bool = True,
|
|
68
68
|
ahead_of_time: bool = True,
|
|
69
|
-
key:
|
|
69
|
+
key: PRNGKeyArray | None = None,
|
|
70
70
|
) -> tuple[
|
|
71
71
|
Params[Array],
|
|
72
72
|
Float[Array, " n_iter"],
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
import equinox as eqx
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class DictToModuleMeta(type):
|
|
6
|
+
"""
|
|
7
|
+
A Metaclass based solution to handle the fact that we only
|
|
8
|
+
want one type to be created for EqParams.
|
|
9
|
+
If we were to create a new **class type** (despite same name) each time we
|
|
10
|
+
create a new Params object, nothing would be broadcastable in terms of jax
|
|
11
|
+
tree utils operations and this would be useless. The difficulty comes from
|
|
12
|
+
the fact that we need to instantiate from this same class at different
|
|
13
|
+
moments of the jinns workflow eg: parameter creation, derivative keys
|
|
14
|
+
creations, tracked parameter designation, etc. (ie. each time a Params
|
|
15
|
+
class is instantiated whatever its usage, we need the same EqParams class
|
|
16
|
+
to be instantiated)
|
|
17
|
+
|
|
18
|
+
This is inspired by the Singleton pattern in Python
|
|
19
|
+
(https://stackoverflow.com/a/10362179)
|
|
20
|
+
|
|
21
|
+
Here we need the call of a metaclass because as explained in
|
|
22
|
+
https://stackoverflow.com/a/45536640). To quote from the answer
|
|
23
|
+
Metaclasses implement how the class will behave (not the instance). So when you look at the instance creation:
|
|
24
|
+
`x = Foo()`
|
|
25
|
+
This literally "calls" the class Foo. That's why __call__ of the metaclass
|
|
26
|
+
is invoked before the __new__ and
|
|
27
|
+
__init__ methods of your class initialize the instance.
|
|
28
|
+
Other viewpoint: Metaclasses, as well as classes making use of those
|
|
29
|
+
metaclasses, are created when the lines of code containing
|
|
30
|
+
the class statement body is executed
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
def __init__(self, *args, **kwargs):
|
|
34
|
+
super(DictToModuleMeta, self).__init__(*args, **kwargs)
|
|
35
|
+
self._class = None
|
|
36
|
+
|
|
37
|
+
def __call__(self, d: dict[str, Any], class_name: str | None = None) -> eqx.Module:
|
|
38
|
+
"""
|
|
39
|
+
Notably, once the class template is registered (after the first call to
|
|
40
|
+
EqParams()), all calls with different keys in `d` will fail.
|
|
41
|
+
"""
|
|
42
|
+
if self._class is None and class_name is not None:
|
|
43
|
+
self._class = type(
|
|
44
|
+
class_name,
|
|
45
|
+
(eqx.Module,),
|
|
46
|
+
{"__annotations__": {k: type(v) for k, v in d.items()}},
|
|
47
|
+
)
|
|
48
|
+
try:
|
|
49
|
+
return self._class(**d) # type: ignore
|
|
50
|
+
except TypeError as _:
|
|
51
|
+
print(
|
|
52
|
+
"DictToModuleMeta has been created with the fields"
|
|
53
|
+
f"{tuple(k for k in self._class.__annotations__.keys())}"
|
|
54
|
+
f" but an instanciation is resquested with fields={tuple(k for k in d.keys())}"
|
|
55
|
+
" which results in an error"
|
|
56
|
+
)
|
|
57
|
+
raise ValueError
|
|
58
|
+
|
|
59
|
+
def clear(cls) -> None:
|
|
60
|
+
"""
|
|
61
|
+
The current Metaclass implementation freezes the list of equation parameters inside a Python session;
|
|
62
|
+
only one EqParams annotation can exist at a given time. Use `EqParams.clear()` to reset.
|
|
63
|
+
Also useful for pytest where stuff is not completely reset after tests
|
|
64
|
+
Taken from https://stackoverflow.com/a/50065732
|
|
65
|
+
"""
|
|
66
|
+
cls._class = None
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from dataclasses import fields
|
|
2
|
+
from typing import Any, ItemsView
|
|
3
|
+
import equinox as eqx
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ItemizableModule(eqx.Module):
|
|
7
|
+
def items(self) -> ItemsView[str, Any]:
|
|
8
|
+
"""
|
|
9
|
+
For the dataclass to be iterated like a dictionary.
|
|
10
|
+
Practical and retrocompatible with old code when loss components were
|
|
11
|
+
dictionaries
|
|
12
|
+
|
|
13
|
+
About the type hint: https://stackoverflow.com/questions/73022688/type-annotation-for-dict-items
|
|
14
|
+
"""
|
|
15
|
+
return {
|
|
16
|
+
field.name: getattr(self, field.name)
|
|
17
|
+
for field in fields(self)
|
|
18
|
+
if getattr(self, field.name) is not None
|
|
19
|
+
}.items()
|
jinns/utils/__init__.py
CHANGED