jinns 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +74 -69
- jinns/loss/_LossODE.py +132 -348
- jinns/loss/_LossPDE.py +262 -549
- jinns/loss/__init__.py +32 -6
- jinns/loss/_abstract_loss.py +128 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_components.py +43 -0
- jinns/loss/_loss_utils.py +85 -179
- jinns/loss/_loss_weight_updates.py +202 -0
- jinns/loss/_loss_weights.py +64 -40
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +207 -92
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +16 -10
- jinns/utils/_types.py +20 -54
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
- jinns-1.5.0.dist-info/RECORD +55 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/nn/_ppinn.py
CHANGED
|
@@ -2,17 +2,20 @@
|
|
|
2
2
|
Implements utility function to create PINNs
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import Callable, Literal, Self, cast, overload
|
|
6
8
|
from dataclasses import InitVar
|
|
7
9
|
import jax
|
|
8
10
|
import jax.numpy as jnp
|
|
9
11
|
import equinox as eqx
|
|
10
12
|
|
|
11
|
-
from jaxtyping import Array, Key,
|
|
13
|
+
from jaxtyping import Array, Key, Float, PyTree
|
|
12
14
|
|
|
13
|
-
from jinns.parameters._params import Params
|
|
15
|
+
from jinns.parameters._params import Params
|
|
14
16
|
from jinns.nn._pinn import PINN
|
|
15
17
|
from jinns.nn._mlp import MLP
|
|
18
|
+
from jinns.nn._utils import _PyTree_to_Params
|
|
16
19
|
|
|
17
20
|
|
|
18
21
|
class PPINN_MLP(PINN):
|
|
@@ -39,12 +42,12 @@ class PPINN_MLP(PINN):
|
|
|
39
42
|
**Note**: the input dimension as given in eqx_list has to match the sum
|
|
40
43
|
of the dimension of `t` + the dimension of `x` or the output dimension
|
|
41
44
|
after the `input_transform` function.
|
|
42
|
-
input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
|
|
45
|
+
input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
|
|
43
46
|
A function that will be called before entering the PPINN. Its output(s)
|
|
44
47
|
must match the PPINN inputs (except for the parameters).
|
|
45
48
|
Its inputs are the PPINN inputs (`t` and/or `x` concatenated together)
|
|
46
49
|
and the parameters. Default is no operation.
|
|
47
|
-
output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
|
|
50
|
+
output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
|
|
48
51
|
A function with arguments begin the same input as the PPINN, the PPINN
|
|
49
52
|
output and the parameter. This function will be called after exiting
|
|
50
53
|
the PPINN.
|
|
@@ -63,25 +66,46 @@ class PPINN_MLP(PINN):
|
|
|
63
66
|
"""
|
|
64
67
|
|
|
65
68
|
eqx_network_list: InitVar[list[eqx.Module]] = eqx.field(kw_only=True)
|
|
69
|
+
init_params: tuple[PINN, ...] = eqx.field(
|
|
70
|
+
init=False
|
|
71
|
+
) # overriding parent attribute type
|
|
72
|
+
static: tuple[PINN, ...] = eqx.field(
|
|
73
|
+
init=False, static=True
|
|
74
|
+
) # overriding parent attribute type
|
|
66
75
|
|
|
67
76
|
def __post_init__(self, eqx_network, eqx_network_list):
|
|
68
77
|
super().__post_init__(
|
|
69
78
|
eqx_network=eqx_network_list[0], # this is not used since it is
|
|
70
79
|
# overwritten just below
|
|
71
80
|
)
|
|
72
|
-
|
|
73
|
-
|
|
81
|
+
params, static = eqx.partition(eqx_network_list[0], self.filter_spec)
|
|
82
|
+
self.init_params, self.static = (params,), (static,)
|
|
83
|
+
for eqx_network_ in eqx_network_list[1:]:
|
|
74
84
|
params, static = eqx.partition(eqx_network_, self.filter_spec)
|
|
75
85
|
self.init_params = self.init_params + (params,)
|
|
76
86
|
self.static = self.static + (static,)
|
|
77
87
|
|
|
88
|
+
@overload
|
|
89
|
+
@_PyTree_to_Params
|
|
78
90
|
def __call__(
|
|
79
91
|
self,
|
|
80
|
-
inputs: Float[Array, "
|
|
92
|
+
inputs: Float[Array, " input_dim"],
|
|
81
93
|
params: PyTree,
|
|
82
|
-
|
|
94
|
+
*args,
|
|
95
|
+
**kwargs,
|
|
96
|
+
) -> Float[Array, " output_dim"]: ...
|
|
97
|
+
|
|
98
|
+
@_PyTree_to_Params
|
|
99
|
+
def __call__(
|
|
100
|
+
self,
|
|
101
|
+
inputs: Float[Array, " 1"] | Float[Array, " dim"] | Float[Array, " 1+dim"],
|
|
102
|
+
params: Params[Array],
|
|
103
|
+
) -> Float[Array, " output_dim"]:
|
|
83
104
|
"""
|
|
84
105
|
Evaluate the PPINN on some inputs with some params.
|
|
106
|
+
|
|
107
|
+
Note that that thanks to the decorator, params can also directly be the
|
|
108
|
+
PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
|
|
85
109
|
"""
|
|
86
110
|
if len(inputs.shape) == 0:
|
|
87
111
|
# This can happen often when the user directly provides some
|
|
@@ -92,14 +116,14 @@ class PPINN_MLP(PINN):
|
|
|
92
116
|
|
|
93
117
|
outs = []
|
|
94
118
|
|
|
95
|
-
try:
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
except (KeyError, AttributeError, TypeError) as e:
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
119
|
+
# try:
|
|
120
|
+
for params_, static in zip(params.nn_params, self.static):
|
|
121
|
+
model = eqx.combine(params_, static)
|
|
122
|
+
outs += [model(transformed_inputs)] # type: ignore
|
|
123
|
+
# except (KeyError, AttributeError, TypeError) as e:
|
|
124
|
+
# for params_, static in zip(params, self.static):
|
|
125
|
+
# model = eqx.combine(params_, static)
|
|
126
|
+
# outs += [model(transformed_inputs)]
|
|
103
127
|
# Note that below is then a global output transform
|
|
104
128
|
res = self.output_transform(inputs, jnp.concatenate(outs, axis=0), params)
|
|
105
129
|
|
|
@@ -112,18 +136,31 @@ class PPINN_MLP(PINN):
|
|
|
112
136
|
def create(
|
|
113
137
|
cls,
|
|
114
138
|
eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
|
|
115
|
-
eqx_network_list: list[eqx.nn.MLP] = None,
|
|
139
|
+
eqx_network_list: list[eqx.nn.MLP | MLP] | None = None,
|
|
116
140
|
key: Key = None,
|
|
117
|
-
eqx_list_list:
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
141
|
+
eqx_list_list: (
|
|
142
|
+
list[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] | None
|
|
143
|
+
) = None,
|
|
144
|
+
input_transform: (
|
|
145
|
+
Callable[
|
|
146
|
+
[Float[Array, " input_dim"], Params[Array]],
|
|
147
|
+
Float[Array, " output_dim"],
|
|
148
|
+
]
|
|
149
|
+
| None
|
|
150
|
+
) = None,
|
|
151
|
+
output_transform: (
|
|
152
|
+
Callable[
|
|
153
|
+
[
|
|
154
|
+
Float[Array, " input_dim"],
|
|
155
|
+
Float[Array, " output_dim"],
|
|
156
|
+
Params[Array],
|
|
157
|
+
],
|
|
158
|
+
Float[Array, " output_dim"],
|
|
159
|
+
]
|
|
160
|
+
| None
|
|
161
|
+
) = None,
|
|
162
|
+
slice_solution: slice | None = None,
|
|
163
|
+
) -> tuple[Self, tuple[PINN, ...]]:
|
|
127
164
|
r"""
|
|
128
165
|
Utility function to create a Parrallel PINN neural network for Jinns.
|
|
129
166
|
|
|
@@ -189,15 +226,14 @@ class PPINN_MLP(PINN):
|
|
|
189
226
|
eqx_network_list = []
|
|
190
227
|
for eqx_list in eqx_list_list:
|
|
191
228
|
key, subkey = jax.random.split(key, 2)
|
|
192
|
-
print(subkey)
|
|
193
229
|
eqx_network_list.append(MLP(key=subkey, eqx_list=eqx_list))
|
|
194
230
|
|
|
195
231
|
ppinn = cls(
|
|
196
|
-
eqx_network=None,
|
|
197
|
-
eqx_network_list=eqx_network_list,
|
|
198
|
-
slice_solution=slice_solution,
|
|
232
|
+
eqx_network=None, # type: ignore
|
|
233
|
+
eqx_network_list=cast(list[eqx.Module], eqx_network_list),
|
|
234
|
+
slice_solution=slice_solution, # type: ignore
|
|
199
235
|
eq_type=eq_type,
|
|
200
|
-
input_transform=input_transform,
|
|
201
|
-
output_transform=output_transform,
|
|
236
|
+
input_transform=input_transform, # type: ignore
|
|
237
|
+
output_transform=output_transform, # type: ignore
|
|
202
238
|
)
|
|
203
239
|
return ppinn, ppinn.init_params
|
jinns/nn/_save_load.py
CHANGED
|
@@ -12,7 +12,7 @@ from jinns.nn._spinn import SPINN
|
|
|
12
12
|
from jinns.nn._mlp import PINN_MLP
|
|
13
13
|
from jinns.nn._spinn_mlp import SPINN_MLP
|
|
14
14
|
from jinns.nn._hyperpinn import HyperPINN
|
|
15
|
-
from jinns.parameters._params import Params
|
|
15
|
+
from jinns.parameters._params import Params
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
def function_to_string(
|
|
@@ -87,7 +87,7 @@ def string_to_function(
|
|
|
87
87
|
def save_pinn(
|
|
88
88
|
filename: str,
|
|
89
89
|
u: PINN | HyperPINN | SPINN,
|
|
90
|
-
params: Params
|
|
90
|
+
params: Params,
|
|
91
91
|
kwargs_creation,
|
|
92
92
|
):
|
|
93
93
|
"""
|
|
@@ -105,15 +105,12 @@ def save_pinn(
|
|
|
105
105
|
tree_serialise_leaves`).
|
|
106
106
|
|
|
107
107
|
Equation parameters are saved apart because the initial type of attribute
|
|
108
|
-
`params` in PINN / HyperPINN / SPINN is not `Params`
|
|
108
|
+
`params` in PINN / HyperPINN / SPINN is not `Params`
|
|
109
109
|
but `PyTree` as inherited from `eqx.partition`.
|
|
110
110
|
Therefore, if we want to ensure a proper serialization/deserialization:
|
|
111
111
|
- we cannot save a `Params` object at this
|
|
112
112
|
attribute field ; the `Params` object must be split into `Params.nn_params`
|
|
113
113
|
(type `PyTree`) and `Params.eq_params` (type `dict`).
|
|
114
|
-
- in the case of a `ParamsDict` we cannot save `ParamsDict.nn_params` at
|
|
115
|
-
the attribute field `params` because it is not a `PyTree` (as expected in
|
|
116
|
-
the PINN / HyperPINN / SPINN signature) but it is still a dictionary.
|
|
117
114
|
|
|
118
115
|
Parameters
|
|
119
116
|
----------
|
|
@@ -122,28 +119,16 @@ def save_pinn(
|
|
|
122
119
|
u
|
|
123
120
|
The PINN
|
|
124
121
|
params
|
|
125
|
-
Params
|
|
122
|
+
Params to be saved
|
|
126
123
|
kwargs_creation
|
|
127
124
|
The dictionary of arguments that were used to create the PINN, e.g.
|
|
128
125
|
the layers list, O/PDE type, etc.
|
|
129
126
|
"""
|
|
130
|
-
if isinstance(
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
eqx.tree_serialise_leaves(filename + "-module.eqx", u)
|
|
136
|
-
|
|
137
|
-
elif isinstance(params, ParamsDict):
|
|
138
|
-
for key, params_ in params.nn_params.items():
|
|
139
|
-
if isinstance(u, HyperPINN):
|
|
140
|
-
u = eqx.tree_at(lambda m: m.init_params_hyper, u, params_)
|
|
141
|
-
elif isinstance(u, (PINN, SPINN)):
|
|
142
|
-
u = eqx.tree_at(lambda m: m.init_params, u, params_)
|
|
143
|
-
eqx.tree_serialise_leaves(filename + f"-module_{key}.eqx", u)
|
|
144
|
-
|
|
145
|
-
else:
|
|
146
|
-
raise ValueError("The parameters to be saved must be a Params or a ParamsDict")
|
|
127
|
+
if isinstance(u, HyperPINN):
|
|
128
|
+
u = eqx.tree_at(lambda m: m.init_params_hyper, u, params)
|
|
129
|
+
elif isinstance(u, (PINN, SPINN)):
|
|
130
|
+
u = eqx.tree_at(lambda m: m.init_params, u, params)
|
|
131
|
+
eqx.tree_serialise_leaves(filename + "-module.eqx", u)
|
|
147
132
|
|
|
148
133
|
with open(filename + "-eq_params.pkl", "wb") as f:
|
|
149
134
|
pickle.dump(params.eq_params, f)
|
|
@@ -170,8 +155,7 @@ def save_pinn(
|
|
|
170
155
|
def load_pinn(
|
|
171
156
|
filename: str,
|
|
172
157
|
type_: Literal["pinn_mlp", "hyperpinn", "spinn_mlp"],
|
|
173
|
-
|
|
174
|
-
) -> tuple[eqx.Module, Params | ParamsDict]:
|
|
158
|
+
) -> tuple[eqx.Module, Params]:
|
|
175
159
|
"""
|
|
176
160
|
Load a PINN model. This function needs to access 3 files :
|
|
177
161
|
`{filename}-module.eqx`, `{filename}-parameters.pkl` and
|
|
@@ -190,8 +174,6 @@ def load_pinn(
|
|
|
190
174
|
Filename (prefix) without extension.
|
|
191
175
|
type_
|
|
192
176
|
Type of model to load. Must be in ["pinn_mlp", "hyperpinn", "spinn"].
|
|
193
|
-
key_list_for_paramsdict
|
|
194
|
-
Pass the name of the keys of the dictionnary `ParamsDict.nn_params`. Default is None. In this case, we expect to retrieve a ParamsDict.
|
|
195
177
|
|
|
196
178
|
Returns
|
|
197
179
|
-------
|
|
@@ -228,29 +210,17 @@ def load_pinn(
|
|
|
228
210
|
)
|
|
229
211
|
else:
|
|
230
212
|
raise ValueError(f"{type_} is not valid")
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
213
|
+
# now the empty structure is populated with the actual saved array values
|
|
214
|
+
# stored in the eqx file
|
|
215
|
+
u_reloaded = eqx.tree_deserialise_leaves(
|
|
216
|
+
filename + "-module.eqx", u_reloaded_shallow
|
|
217
|
+
)
|
|
218
|
+
if isinstance(u_reloaded, HyperPINN):
|
|
219
|
+
params = Params(
|
|
220
|
+
nn_params=u_reloaded.init_params_hyper, eq_params=eq_params_reloaded
|
|
236
221
|
)
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
nn_params=u_reloaded.init_params_hyper, eq_params=eq_params_reloaded
|
|
240
|
-
)
|
|
241
|
-
elif isinstance(u_reloaded, (PINN, SPINN)):
|
|
242
|
-
params = Params(
|
|
243
|
-
nn_params=u_reloaded.init_params, eq_params=eq_params_reloaded
|
|
244
|
-
)
|
|
222
|
+
elif isinstance(u_reloaded, (PINN, SPINN)):
|
|
223
|
+
params = Params(nn_params=u_reloaded.init_params, eq_params=eq_params_reloaded)
|
|
245
224
|
else:
|
|
246
|
-
|
|
247
|
-
for key in key_list_for_paramsdict:
|
|
248
|
-
u_reloaded = eqx.tree_deserialise_leaves(
|
|
249
|
-
filename + f"-module_{key}.eqx", u_reloaded_shallow
|
|
250
|
-
)
|
|
251
|
-
if isinstance(u_reloaded, HyperPINN):
|
|
252
|
-
nn_params_dict[key] = u_reloaded.init_params_hyper
|
|
253
|
-
elif isinstance(u_reloaded, (PINN, SPINN)):
|
|
254
|
-
nn_params_dict[key] = u_reloaded.init_params
|
|
255
|
-
params = ParamsDict(nn_params=nn_params_dict, eq_params=eq_params_reloaded)
|
|
225
|
+
raise ValueError("Wrong type for u_reloaded")
|
|
256
226
|
return u_reloaded, params
|
jinns/nn/_spinn.py
CHANGED
|
@@ -1,14 +1,17 @@
|
|
|
1
|
-
from
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Union, Callable, Any, Literal, overload
|
|
2
3
|
from dataclasses import InitVar
|
|
3
4
|
from jaxtyping import PyTree, Float, Array
|
|
4
5
|
import jax
|
|
5
6
|
import jax.numpy as jnp
|
|
6
7
|
import equinox as eqx
|
|
7
8
|
|
|
8
|
-
from jinns.parameters._params import Params
|
|
9
|
+
from jinns.parameters._params import Params
|
|
10
|
+
from jinns.nn._abstract_pinn import AbstractPINN
|
|
11
|
+
from jinns.nn._utils import _PyTree_to_Params
|
|
9
12
|
|
|
10
13
|
|
|
11
|
-
class SPINN(
|
|
14
|
+
class SPINN(AbstractPINN):
|
|
12
15
|
"""
|
|
13
16
|
A Separable PINN object compatible with the rest of jinns.
|
|
14
17
|
|
|
@@ -47,21 +50,21 @@ class SPINN(eqx.Module):
|
|
|
47
50
|
|
|
48
51
|
"""
|
|
49
52
|
|
|
53
|
+
eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
|
|
54
|
+
static=True, kw_only=True
|
|
55
|
+
)
|
|
50
56
|
d: int = eqx.field(static=True, kw_only=True)
|
|
51
57
|
r: int = eqx.field(static=True, kw_only=True)
|
|
52
|
-
eq_type: str = eqx.field(static=True, kw_only=True)
|
|
53
58
|
m: int = eqx.field(static=True, kw_only=True, default=1)
|
|
54
|
-
|
|
55
59
|
filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = eqx.field(
|
|
56
60
|
static=True, kw_only=True, default=None
|
|
57
61
|
)
|
|
58
62
|
eqx_spinn_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
|
|
59
63
|
|
|
60
|
-
init_params:
|
|
61
|
-
static:
|
|
64
|
+
init_params: SPINN = eqx.field(init=False)
|
|
65
|
+
static: SPINN = eqx.field(init=False, static=True)
|
|
62
66
|
|
|
63
67
|
def __post_init__(self, eqx_spinn_network):
|
|
64
|
-
|
|
65
68
|
if self.filter_spec is None:
|
|
66
69
|
self.filter_spec = eqx.is_inexact_array
|
|
67
70
|
|
|
@@ -69,20 +72,34 @@ class SPINN(eqx.Module):
|
|
|
69
72
|
eqx_spinn_network, self.filter_spec
|
|
70
73
|
)
|
|
71
74
|
|
|
75
|
+
@overload
|
|
76
|
+
@_PyTree_to_Params
|
|
72
77
|
def __call__(
|
|
73
78
|
self,
|
|
74
|
-
|
|
75
|
-
params:
|
|
76
|
-
|
|
79
|
+
inputs: Float[Array, " input_dim"],
|
|
80
|
+
params: PyTree,
|
|
81
|
+
*args,
|
|
82
|
+
**kwargs,
|
|
83
|
+
) -> Float[Array, " output_dim"]: ...
|
|
84
|
+
|
|
85
|
+
@_PyTree_to_Params
|
|
86
|
+
def __call__(
|
|
87
|
+
self,
|
|
88
|
+
t_x: Float[Array, " batch_size 1+dim"],
|
|
89
|
+
params: Params[Array],
|
|
90
|
+
) -> Float[Array, " output_dim"]:
|
|
77
91
|
"""
|
|
78
92
|
Evaluate the SPINN on some inputs with some params.
|
|
93
|
+
|
|
94
|
+
Note that that thanks to the decorator, params can also directly be the
|
|
95
|
+
PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
|
|
79
96
|
"""
|
|
80
|
-
try:
|
|
81
|
-
|
|
82
|
-
except (KeyError, AttributeError, TypeError) as e:
|
|
83
|
-
|
|
97
|
+
# try:
|
|
98
|
+
spinn = eqx.combine(params.nn_params, self.static)
|
|
99
|
+
# except (KeyError, AttributeError, TypeError) as e:
|
|
100
|
+
# spinn = eqx.combine(params, self.static)
|
|
84
101
|
v_model = jax.vmap(spinn)
|
|
85
|
-
res = v_model(t_x)
|
|
102
|
+
res = v_model(t_x) # type: ignore
|
|
86
103
|
|
|
87
104
|
a = ", ".join([f"{chr(97 + d)}z" for d in range(res.shape[1])])
|
|
88
105
|
b = "".join([f"{chr(97 + d)}" for d in range(res.shape[1])])
|
jinns/nn/_spinn_mlp.py
CHANGED
|
@@ -4,13 +4,12 @@ https://arxiv.org/abs/2211.08761
|
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
from dataclasses import InitVar
|
|
7
|
-
from typing import Callable, Literal, Self, Union, Any
|
|
7
|
+
from typing import Callable, Literal, Self, Union, Any, TypeGuard
|
|
8
8
|
import jax
|
|
9
9
|
import jax.numpy as jnp
|
|
10
10
|
import equinox as eqx
|
|
11
11
|
from jaxtyping import Key, Array, Float, PyTree
|
|
12
12
|
|
|
13
|
-
from jinns.parameters._params import Params, ParamsDict
|
|
14
13
|
from jinns.nn._mlp import MLP
|
|
15
14
|
from jinns.nn._spinn import SPINN
|
|
16
15
|
|
|
@@ -26,7 +25,7 @@ class SMLP(eqx.Module):
|
|
|
26
25
|
d : int
|
|
27
26
|
The number of dimensions to treat separately, including time `t` if
|
|
28
27
|
used for non-stationnary equations.
|
|
29
|
-
eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
|
|
28
|
+
eqx_list : InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]]
|
|
30
29
|
A tuple of tuples of successive equinox modules and activation functions to
|
|
31
30
|
describe the PINN architecture. The inner tuples must have the eqx module or
|
|
32
31
|
activation function as first item, other items represents arguments
|
|
@@ -34,18 +33,18 @@ class SMLP(eqx.Module):
|
|
|
34
33
|
The `key` argument need not be given.
|
|
35
34
|
Thus typical example is `eqx_list=
|
|
36
35
|
((eqx.nn.Linear, 1, 20),
|
|
37
|
-
jax.nn.tanh,
|
|
36
|
+
(jax.nn.tanh,),
|
|
38
37
|
(eqx.nn.Linear, 20, 20),
|
|
39
|
-
jax.nn.tanh,
|
|
38
|
+
(jax.nn.tanh,),
|
|
40
39
|
(eqx.nn.Linear, 20, 20),
|
|
41
|
-
jax.nn.tanh,
|
|
40
|
+
(jax.nn.tanh,),
|
|
42
41
|
(eqx.nn.Linear, 20, r * m)
|
|
43
42
|
)`.
|
|
44
43
|
"""
|
|
45
44
|
|
|
46
45
|
key: InitVar[Key] = eqx.field(kw_only=True)
|
|
47
|
-
eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] =
|
|
48
|
-
kw_only=True
|
|
46
|
+
eqx_list: InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] = (
|
|
47
|
+
eqx.field(kw_only=True)
|
|
49
48
|
)
|
|
50
49
|
d: int = eqx.field(static=True, kw_only=True)
|
|
51
50
|
|
|
@@ -58,8 +57,8 @@ class SMLP(eqx.Module):
|
|
|
58
57
|
]
|
|
59
58
|
|
|
60
59
|
def __call__(
|
|
61
|
-
self, inputs: Float[Array, "dim"] | Float[Array, "dim+1"]
|
|
62
|
-
) -> Float[Array, "d embed_dim*output_dim"]:
|
|
60
|
+
self, inputs: Float[Array, " dim"] | Float[Array, " dim+1"]
|
|
61
|
+
) -> Float[Array, " d embed_dim*output_dim"]:
|
|
63
62
|
outputs = []
|
|
64
63
|
for d in range(self.d):
|
|
65
64
|
x_i = inputs[d : d + 1]
|
|
@@ -78,11 +77,11 @@ class SPINN_MLP(SPINN):
|
|
|
78
77
|
key: Key,
|
|
79
78
|
d: int,
|
|
80
79
|
r: int,
|
|
81
|
-
eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
|
|
80
|
+
eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...],
|
|
82
81
|
eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
|
|
83
82
|
m: int = 1,
|
|
84
83
|
filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
|
|
85
|
-
) -> tuple[Self,
|
|
84
|
+
) -> tuple[Self, SPINN]:
|
|
86
85
|
"""
|
|
87
86
|
Utility function to create a SPINN neural network with the equinox
|
|
88
87
|
library.
|
|
@@ -108,11 +107,11 @@ class SPINN_MLP(SPINN):
|
|
|
108
107
|
The `key` argument need not be given.
|
|
109
108
|
Thus typical example is
|
|
110
109
|
`eqx_list=((eqx.nn.Linear, 1, 20),
|
|
111
|
-
jax.nn.tanh,
|
|
110
|
+
(jax.nn.tanh,),
|
|
111
|
+
(eqx.nn.Linea)r, 20, 20),
|
|
112
|
+
(jax.nn.tanh,),
|
|
112
113
|
(eqx.nn.Linear, 20, 20),
|
|
113
|
-
jax.nn.tanh,
|
|
114
|
-
(eqx.nn.Linear, 20, 20),
|
|
115
|
-
jax.nn.tanh,
|
|
114
|
+
(jax.nn.tanh,),
|
|
116
115
|
(eqx.nn.Linear, 20, r * m)
|
|
117
116
|
)`.
|
|
118
117
|
eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
|
|
@@ -158,24 +157,31 @@ class SPINN_MLP(SPINN):
|
|
|
158
157
|
if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
|
|
159
158
|
raise RuntimeError("Wrong parameter value for eq_type")
|
|
160
159
|
|
|
161
|
-
|
|
160
|
+
def element_is_layer(element: tuple) -> TypeGuard[tuple[Callable, int, int]]:
|
|
161
|
+
return len(element) > 1
|
|
162
|
+
|
|
163
|
+
if element_is_layer(eqx_list[0]):
|
|
162
164
|
nb_inputs_declared = eqx_list[0][
|
|
163
165
|
1
|
|
164
166
|
] # normally we look for 2nd ele of 1st layer
|
|
165
|
-
|
|
167
|
+
elif element_is_layer(eqx_list[1]):
|
|
166
168
|
nb_inputs_declared = eqx_list[1][
|
|
167
169
|
1
|
|
168
170
|
] # but we can have, eg, a flatten first layer
|
|
169
|
-
|
|
171
|
+
else:
|
|
172
|
+
nb_inputs_declared = None
|
|
173
|
+
if nb_inputs_declared is None or nb_inputs_declared != 1:
|
|
170
174
|
raise ValueError("Input dim must be set to 1 in SPINN!")
|
|
171
175
|
|
|
172
|
-
|
|
176
|
+
if element_is_layer(eqx_list[-1]):
|
|
173
177
|
nb_outputs_declared = eqx_list[-1][2] # normally we look for 3rd ele of
|
|
174
178
|
# last layer
|
|
175
|
-
|
|
179
|
+
elif element_is_layer(eqx_list[-2]):
|
|
176
180
|
nb_outputs_declared = eqx_list[-2][2]
|
|
177
181
|
# but we can have, eg, a `jnp.exp` last layer
|
|
178
|
-
|
|
182
|
+
else:
|
|
183
|
+
nb_outputs_declared = None
|
|
184
|
+
if nb_outputs_declared is None or nb_outputs_declared != r * m:
|
|
179
185
|
raise ValueError("Output dim must be set to r * m in SPINN!")
|
|
180
186
|
|
|
181
187
|
if d > 24:
|
jinns/nn/_utils.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from typing import Any, ParamSpec, Callable, Concatenate
|
|
2
|
+
from jaxtyping import PyTree, Array
|
|
3
|
+
from jinns.parameters._params import Params
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
P = ParamSpec("P")
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _PyTree_to_Params(
|
|
10
|
+
call_fun: Callable[
|
|
11
|
+
Concatenate[Any, Any, PyTree | Params[Array], P],
|
|
12
|
+
Any,
|
|
13
|
+
],
|
|
14
|
+
) -> Callable[
|
|
15
|
+
Concatenate[Any, Any, PyTree | Params[Array], P],
|
|
16
|
+
Any,
|
|
17
|
+
]:
|
|
18
|
+
"""
|
|
19
|
+
Decorator to be used around __call__ functions of PINNs, SPINNs, etc. It
|
|
20
|
+
authorizes the __call__ with `params` being directly be the
|
|
21
|
+
PyTree (SPINN, PINN_MLP, ...) that we get out of `eqx.combine`
|
|
22
|
+
|
|
23
|
+
This generic approach enables to cleanly handle type hints, up to the small
|
|
24
|
+
effort required to understand type hints for decorators (ie ParamSpec).
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
def wrapper(
|
|
28
|
+
self: Any,
|
|
29
|
+
inputs: Any,
|
|
30
|
+
params: PyTree | Params[Array],
|
|
31
|
+
*args: P.args,
|
|
32
|
+
**kwargs: P.kwargs,
|
|
33
|
+
):
|
|
34
|
+
if isinstance(params, PyTree) and not isinstance(params, Params):
|
|
35
|
+
params = Params(nn_params=params, eq_params={})
|
|
36
|
+
return call_fun(self, inputs, params, *args, **kwargs)
|
|
37
|
+
|
|
38
|
+
return wrapper
|
jinns/parameters/__init__.py
CHANGED
|
@@ -1,6 +1,13 @@
|
|
|
1
|
-
from ._params import Params
|
|
1
|
+
from ._params import Params
|
|
2
2
|
from ._derivative_keys import (
|
|
3
3
|
DerivativeKeysODE,
|
|
4
4
|
DerivativeKeysPDEStatio,
|
|
5
5
|
DerivativeKeysPDENonStatio,
|
|
6
6
|
)
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"Params",
|
|
10
|
+
"DerivativeKeysODE",
|
|
11
|
+
"DerivativeKeysPDEStatio",
|
|
12
|
+
"DerivativeKeysPDENonStatio",
|
|
13
|
+
]
|