jinns 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +74 -69
- jinns/loss/_LossODE.py +132 -348
- jinns/loss/_LossPDE.py +262 -549
- jinns/loss/__init__.py +32 -6
- jinns/loss/_abstract_loss.py +128 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_components.py +43 -0
- jinns/loss/_loss_utils.py +85 -179
- jinns/loss/_loss_weight_updates.py +202 -0
- jinns/loss/_loss_weights.py +64 -40
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +207 -92
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +16 -10
- jinns/utils/_types.py +20 -54
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
- jinns-1.5.0.dist-info/RECORD +55 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/nn/_hyperpinn.py
CHANGED
|
@@ -3,9 +3,11 @@ Implements utility function to create HyperPINNs
|
|
|
3
3
|
https://arxiv.org/pdf/2111.01008.pdf
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
6
8
|
import warnings
|
|
7
9
|
from dataclasses import InitVar
|
|
8
|
-
from typing import Callable, Literal, Self, Union, Any
|
|
10
|
+
from typing import Callable, Literal, Self, Union, Any, cast, overload
|
|
9
11
|
from math import prod
|
|
10
12
|
import jax
|
|
11
13
|
import jax.numpy as jnp
|
|
@@ -15,12 +17,13 @@ import numpy as onp
|
|
|
15
17
|
|
|
16
18
|
from jinns.nn._pinn import PINN
|
|
17
19
|
from jinns.nn._mlp import MLP
|
|
18
|
-
from jinns.parameters._params import Params
|
|
20
|
+
from jinns.parameters._params import Params
|
|
21
|
+
from jinns.nn._utils import _PyTree_to_Params
|
|
19
22
|
|
|
20
23
|
|
|
21
24
|
def _get_param_nb(
|
|
22
|
-
params:
|
|
23
|
-
) -> tuple[int, list]:
|
|
25
|
+
params: PyTree[Array],
|
|
26
|
+
) -> tuple[int, list[int]]:
|
|
24
27
|
"""Returns the number of parameters in a Params object and also
|
|
25
28
|
the cumulative sum when parsing the object.
|
|
26
29
|
|
|
@@ -48,7 +51,7 @@ class HyperPINN(PINN):
|
|
|
48
51
|
|
|
49
52
|
Parameters
|
|
50
53
|
----------
|
|
51
|
-
hyperparams: list = eqx.field(static=True)
|
|
54
|
+
hyperparams: list[str] = eqx.field(static=True)
|
|
52
55
|
A list of keys from Params.eq_params that will be considered as
|
|
53
56
|
hyperparameters for metamodeling.
|
|
54
57
|
hypernet_input_size: int
|
|
@@ -72,12 +75,12 @@ class HyperPINN(PINN):
|
|
|
72
75
|
**Note**: the input dimension as given in eqx_list has to match the sum
|
|
73
76
|
of the dimension of `t` + the dimension of `x` or the output dimension
|
|
74
77
|
after the `input_transform` function
|
|
75
|
-
input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
|
|
78
|
+
input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
|
|
76
79
|
A function that will be called before entering the PINN. Its output(s)
|
|
77
80
|
must match the PINN inputs (except for the parameters).
|
|
78
81
|
Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
|
|
79
82
|
and the parameters. Default is no operation.
|
|
80
|
-
output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
|
|
83
|
+
output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
|
|
81
84
|
A function with arguments begin the same input as the PINN, the PINN
|
|
82
85
|
output and the parameter. This function will be called after exiting the PINN.
|
|
83
86
|
Default is no operation.
|
|
@@ -100,10 +103,10 @@ class HyperPINN(PINN):
|
|
|
100
103
|
eqx_hyper_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
|
|
101
104
|
|
|
102
105
|
pinn_params_sum: int = eqx.field(init=False, static=True)
|
|
103
|
-
pinn_params_cumsum: list = eqx.field(init=False, static=True)
|
|
106
|
+
pinn_params_cumsum: list[int] = eqx.field(init=False, static=True)
|
|
104
107
|
|
|
105
|
-
init_params_hyper:
|
|
106
|
-
static_hyper:
|
|
108
|
+
init_params_hyper: HyperPINN = eqx.field(init=False)
|
|
109
|
+
static_hyper: HyperPINN = eqx.field(init=False, static=True)
|
|
107
110
|
|
|
108
111
|
def __post_init__(self, eqx_network, eqx_hyper_network):
|
|
109
112
|
super().__post_init__(
|
|
@@ -115,7 +118,7 @@ class HyperPINN(PINN):
|
|
|
115
118
|
)
|
|
116
119
|
self.pinn_params_sum, self.pinn_params_cumsum = _get_param_nb(self.init_params)
|
|
117
120
|
|
|
118
|
-
def _hyper_to_pinn(self, hyper_output: Float[Array, "output_dim"]) ->
|
|
121
|
+
def _hyper_to_pinn(self, hyper_output: Float[Array, " output_dim"]) -> PINN:
|
|
119
122
|
"""
|
|
120
123
|
From the output of the hypernetwork, transform to a well formed
|
|
121
124
|
parameters for the pinn network (i.e. with the same PyTree structure as
|
|
@@ -142,15 +145,29 @@ class HyperPINN(PINN):
|
|
|
142
145
|
is_leaf=lambda x: isinstance(x, jnp.ndarray),
|
|
143
146
|
)
|
|
144
147
|
|
|
148
|
+
@overload
|
|
149
|
+
@_PyTree_to_Params
|
|
145
150
|
def __call__(
|
|
146
151
|
self,
|
|
147
|
-
inputs: Float[Array, "input_dim"],
|
|
148
|
-
params:
|
|
152
|
+
inputs: Float[Array, " input_dim"],
|
|
153
|
+
params: PyTree,
|
|
149
154
|
*args,
|
|
150
155
|
**kwargs,
|
|
151
|
-
) -> Float[Array, "output_dim"]:
|
|
156
|
+
) -> Float[Array, " output_dim"]: ...
|
|
157
|
+
|
|
158
|
+
@_PyTree_to_Params
|
|
159
|
+
def __call__(
|
|
160
|
+
self,
|
|
161
|
+
inputs: Float[Array, " input_dim"],
|
|
162
|
+
params: Params[Array],
|
|
163
|
+
*args,
|
|
164
|
+
**kwargs,
|
|
165
|
+
) -> Float[Array, " output_dim"]:
|
|
152
166
|
"""
|
|
153
167
|
Evaluate the HyperPINN on some inputs with some params.
|
|
168
|
+
|
|
169
|
+
Note that that thanks to the decorator, params can also directly be the
|
|
170
|
+
PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
|
|
154
171
|
"""
|
|
155
172
|
if len(inputs.shape) == 0:
|
|
156
173
|
# This can happen often when the user directly provides some
|
|
@@ -158,16 +175,17 @@ class HyperPINN(PINN):
|
|
|
158
175
|
# DataGenerators)
|
|
159
176
|
inputs = inputs[None]
|
|
160
177
|
|
|
161
|
-
try:
|
|
162
|
-
|
|
163
|
-
except (KeyError, AttributeError, TypeError) as e: # give more flexibility
|
|
164
|
-
|
|
178
|
+
# try:
|
|
179
|
+
hyper = eqx.combine(params.nn_params, self.static_hyper)
|
|
180
|
+
# except (KeyError, AttributeError, TypeError) as e: # give more flexibility
|
|
181
|
+
# hyper = eqx.combine(params, self.static_hyper)
|
|
165
182
|
|
|
166
183
|
eq_params_batch = jnp.concatenate(
|
|
167
|
-
[params.eq_params[k].flatten() for k in self.hyperparams],
|
|
184
|
+
[params.eq_params[k].flatten() for k in self.hyperparams],
|
|
185
|
+
axis=0,
|
|
168
186
|
)
|
|
169
187
|
|
|
170
|
-
hyper_output = hyper(eq_params_batch)
|
|
188
|
+
hyper_output = hyper(eq_params_batch) # type: ignore
|
|
171
189
|
|
|
172
190
|
pinn_params = self._hyper_to_pinn(hyper_output)
|
|
173
191
|
|
|
@@ -187,21 +205,34 @@ class HyperPINN(PINN):
|
|
|
187
205
|
eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
|
|
188
206
|
hyperparams: list[str],
|
|
189
207
|
hypernet_input_size: int,
|
|
190
|
-
eqx_network: eqx.nn.MLP = None,
|
|
191
|
-
eqx_hyper_network: eqx.nn.MLP = None,
|
|
208
|
+
eqx_network: eqx.nn.MLP | MLP | None = None,
|
|
209
|
+
eqx_hyper_network: eqx.nn.MLP | MLP | None = None,
|
|
192
210
|
key: Key = None,
|
|
193
|
-
eqx_list: tuple[tuple[Callable, int, int] | Callable, ...] = None,
|
|
194
|
-
eqx_list_hyper:
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
211
|
+
eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
|
|
212
|
+
eqx_list_hyper: (
|
|
213
|
+
tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None
|
|
214
|
+
) = None,
|
|
215
|
+
input_transform: (
|
|
216
|
+
Callable[
|
|
217
|
+
[Float[Array, " input_dim"], Params[Array]],
|
|
218
|
+
Float[Array, " output_dim"],
|
|
219
|
+
]
|
|
220
|
+
| None
|
|
221
|
+
) = None,
|
|
222
|
+
output_transform: (
|
|
223
|
+
Callable[
|
|
224
|
+
[
|
|
225
|
+
Float[Array, " input_dim"],
|
|
226
|
+
Float[Array, " output_dim"],
|
|
227
|
+
Params[Array],
|
|
228
|
+
],
|
|
229
|
+
Float[Array, " output_dim"],
|
|
230
|
+
]
|
|
231
|
+
| None
|
|
232
|
+
) = None,
|
|
233
|
+
slice_solution: slice | None = None,
|
|
203
234
|
filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
|
|
204
|
-
) -> tuple[Self,
|
|
235
|
+
) -> tuple[Self, HyperPINN]:
|
|
205
236
|
r"""
|
|
206
237
|
Utility function to create a standard PINN neural network with the equinox
|
|
207
238
|
library.
|
|
@@ -250,11 +281,11 @@ class HyperPINN(PINN):
|
|
|
250
281
|
The `key` argument need not be given.
|
|
251
282
|
Thus typical example is `eqx_list=
|
|
252
283
|
((eqx.nn.Linear, 2, 20),
|
|
253
|
-
jax.nn.tanh,
|
|
284
|
+
(jax.nn.tanh,),
|
|
254
285
|
(eqx.nn.Linear, 20, 20),
|
|
255
|
-
jax.nn.tanh,
|
|
286
|
+
(jax.nn.tanh,),
|
|
256
287
|
(eqx.nn.Linear, 20, 20),
|
|
257
|
-
jax.nn.tanh,
|
|
288
|
+
(jax.nn.tanh,),
|
|
258
289
|
(eqx.nn.Linear, 20, 1)
|
|
259
290
|
)`.
|
|
260
291
|
eqx_list_hyper
|
|
@@ -268,11 +299,11 @@ class HyperPINN(PINN):
|
|
|
268
299
|
The `key` argument need not be given.
|
|
269
300
|
Thus typical example is `eqx_list=
|
|
270
301
|
((eqx.nn.Linear, 2, 20),
|
|
271
|
-
jax.nn.tanh,
|
|
302
|
+
(jax.nn.tanh,),
|
|
272
303
|
(eqx.nn.Linear, 20, 20),
|
|
273
|
-
jax.nn.tanh,
|
|
304
|
+
(jax.nn.tanh,),
|
|
274
305
|
(eqx.nn.Linear, 20, 20),
|
|
275
|
-
jax.nn.tanh,
|
|
306
|
+
(jax.nn.tanh,),
|
|
276
307
|
(eqx.nn.Linear, 20, 1)
|
|
277
308
|
)`.
|
|
278
309
|
input_transform
|
|
@@ -343,10 +374,13 @@ class HyperPINN(PINN):
|
|
|
343
374
|
(eqx_list_hyper[-1][:2] + (pinn_params_sum,)),
|
|
344
375
|
)
|
|
345
376
|
else:
|
|
346
|
-
eqx_list_hyper = (
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
377
|
+
eqx_list_hyper = cast(
|
|
378
|
+
tuple[tuple[Callable, int, int] | tuple[Callable], ...],
|
|
379
|
+
(
|
|
380
|
+
eqx_list_hyper[:-2]
|
|
381
|
+
+ ((eqx_list_hyper[-2][:2] + (pinn_params_sum,)),)
|
|
382
|
+
+ eqx_list_hyper[-1]
|
|
383
|
+
),
|
|
350
384
|
)
|
|
351
385
|
if len(eqx_list_hyper[0]) > 1:
|
|
352
386
|
eqx_list_hyper = (
|
|
@@ -357,21 +391,24 @@ class HyperPINN(PINN):
|
|
|
357
391
|
),
|
|
358
392
|
) + eqx_list_hyper[1:]
|
|
359
393
|
else:
|
|
360
|
-
eqx_list_hyper = (
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
394
|
+
eqx_list_hyper = cast(
|
|
395
|
+
tuple[tuple[Callable, int, int] | tuple[Callable], ...],
|
|
396
|
+
(
|
|
397
|
+
eqx_list_hyper[0]
|
|
398
|
+
+ (
|
|
399
|
+
(
|
|
400
|
+
(eqx_list_hyper[1][0],)
|
|
401
|
+
+ (hypernet_input_size,)
|
|
402
|
+
+ (eqx_list_hyper[1][2],) # type: ignore because we suppose that the second element of tuple is nec.of length > 1 since we expect smth like eqx.nn.Linear
|
|
403
|
+
),
|
|
404
|
+
)
|
|
405
|
+
+ eqx_list_hyper[2:]
|
|
406
|
+
),
|
|
370
407
|
)
|
|
371
408
|
key, subkey = jax.random.split(key, 2)
|
|
372
409
|
# with warnings.catch_warnings():
|
|
373
410
|
# warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
|
|
374
|
-
eqx_hyper_network = MLP(key=subkey, eqx_list=eqx_list_hyper)
|
|
411
|
+
eqx_hyper_network = cast(MLP, MLP(key=subkey, eqx_list=eqx_list_hyper))
|
|
375
412
|
|
|
376
413
|
### End of finetuning the hypernetwork architecture
|
|
377
414
|
|
|
@@ -386,10 +423,10 @@ class HyperPINN(PINN):
|
|
|
386
423
|
hyperpinn = cls(
|
|
387
424
|
eqx_network=eqx_network,
|
|
388
425
|
eqx_hyper_network=eqx_hyper_network,
|
|
389
|
-
slice_solution=slice_solution,
|
|
426
|
+
slice_solution=slice_solution, # type: ignore
|
|
390
427
|
eq_type=eq_type,
|
|
391
|
-
input_transform=input_transform,
|
|
392
|
-
output_transform=output_transform,
|
|
428
|
+
input_transform=input_transform, # type: ignore
|
|
429
|
+
output_transform=output_transform, # type: ignore
|
|
393
430
|
hyperparams=hyperparams,
|
|
394
431
|
hypernet_input_size=hypernet_input_size,
|
|
395
432
|
filter_spec=filter_spec,
|
jinns/nn/_mlp.py
CHANGED
|
@@ -2,16 +2,30 @@
|
|
|
2
2
|
Implements utility function to create PINNs
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import Callable, Literal, Self, Union, Any, TYPE_CHECKING, cast
|
|
6
8
|
from dataclasses import InitVar
|
|
7
9
|
import jax
|
|
8
10
|
import equinox as eqx
|
|
9
|
-
|
|
11
|
+
from typing import Protocol
|
|
10
12
|
from jaxtyping import Array, Key, PyTree, Float
|
|
11
13
|
|
|
12
14
|
from jinns.parameters._params import Params
|
|
13
15
|
from jinns.nn._pinn import PINN
|
|
14
16
|
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
|
|
19
|
+
class CallableMLPModule(Protocol):
|
|
20
|
+
"""
|
|
21
|
+
Basically just a way to add a __call__ to an eqx.Module.
|
|
22
|
+
https://github.com/patrick-kidger/equinox/issues/1002
|
|
23
|
+
We chose the strutural subtyping of protocols instead of subclassing an
|
|
24
|
+
eqx.Module just to add a __call__ here
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
def __call__(self, *_, **__) -> Array: ...
|
|
28
|
+
|
|
15
29
|
|
|
16
30
|
class MLP(eqx.Module):
|
|
17
31
|
"""
|
|
@@ -21,7 +35,7 @@ class MLP(eqx.Module):
|
|
|
21
35
|
----------
|
|
22
36
|
key : InitVar[Key]
|
|
23
37
|
A jax random key for the layer initializations.
|
|
24
|
-
eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
|
|
38
|
+
eqx_list : InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]]
|
|
25
39
|
A tuple of tuples of successive equinox modules and activation functions to
|
|
26
40
|
describe the PINN architecture. The inner tuples must have the eqx module or
|
|
27
41
|
activation function as first item, other items represents arguments
|
|
@@ -29,23 +43,23 @@ class MLP(eqx.Module):
|
|
|
29
43
|
The `key` argument need not be given.
|
|
30
44
|
Thus typical example is `eqx_list=
|
|
31
45
|
((eqx.nn.Linear, 2, 20),
|
|
32
|
-
jax.nn.tanh,
|
|
46
|
+
(jax.nn.tanh,),
|
|
33
47
|
(eqx.nn.Linear, 20, 20),
|
|
34
|
-
jax.nn.tanh,
|
|
48
|
+
(jax.nn.tanh,),
|
|
35
49
|
(eqx.nn.Linear, 20, 20),
|
|
36
|
-
jax.nn.tanh,
|
|
50
|
+
(jax.nn.tanh,),
|
|
37
51
|
(eqx.nn.Linear, 20, 1)
|
|
38
52
|
)`.
|
|
39
53
|
"""
|
|
40
54
|
|
|
41
55
|
key: InitVar[Key] = eqx.field(kw_only=True)
|
|
42
|
-
eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] =
|
|
43
|
-
kw_only=True
|
|
56
|
+
eqx_list: InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] = (
|
|
57
|
+
eqx.field(kw_only=True)
|
|
44
58
|
)
|
|
45
59
|
|
|
46
60
|
# NOTE that the following should NOT be declared as static otherwise the
|
|
47
61
|
# eqx.partition that we use in the PINN module will misbehave
|
|
48
|
-
layers: list[
|
|
62
|
+
layers: list[CallableMLPModule | Callable[[Array], Array]] = eqx.field(init=False)
|
|
49
63
|
|
|
50
64
|
def __post_init__(self, key, eqx_list):
|
|
51
65
|
self.layers = []
|
|
@@ -63,7 +77,7 @@ class MLP(eqx.Module):
|
|
|
63
77
|
self.layers.append(l[0](*l[1:], key=subkey))
|
|
64
78
|
k += 1
|
|
65
79
|
|
|
66
|
-
def __call__(self, t: Float[Array, "input_dim"]) -> Float[Array, "output_dim"]:
|
|
80
|
+
def __call__(self, t: Float[Array, " input_dim"]) -> Float[Array, " output_dim"]:
|
|
67
81
|
for layer in self.layers:
|
|
68
82
|
t = layer(t)
|
|
69
83
|
return t
|
|
@@ -81,19 +95,30 @@ class PINN_MLP(PINN):
|
|
|
81
95
|
def create(
|
|
82
96
|
cls,
|
|
83
97
|
eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
|
|
84
|
-
eqx_network: eqx.nn.MLP = None,
|
|
98
|
+
eqx_network: eqx.nn.MLP | MLP | None = None,
|
|
85
99
|
key: Key = None,
|
|
86
|
-
eqx_list: tuple[tuple[Callable, int, int] | Callable, ...] = None,
|
|
87
|
-
input_transform:
|
|
88
|
-
[
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
100
|
+
eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
|
|
101
|
+
input_transform: (
|
|
102
|
+
Callable[
|
|
103
|
+
[Float[Array, " input_dim"], Params[Array]],
|
|
104
|
+
Float[Array, " output_dim"],
|
|
105
|
+
]
|
|
106
|
+
| None
|
|
107
|
+
) = None,
|
|
108
|
+
output_transform: (
|
|
109
|
+
Callable[
|
|
110
|
+
[
|
|
111
|
+
Float[Array, " input_dim"],
|
|
112
|
+
Float[Array, " output_dim"],
|
|
113
|
+
Params[Array],
|
|
114
|
+
],
|
|
115
|
+
Float[Array, " output_dim"],
|
|
116
|
+
]
|
|
117
|
+
| None
|
|
118
|
+
) = None,
|
|
119
|
+
slice_solution: slice | None = None,
|
|
95
120
|
filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
|
|
96
|
-
) -> tuple[Self,
|
|
121
|
+
) -> tuple[Self, PINN]:
|
|
97
122
|
r"""
|
|
98
123
|
Instanciate standard PINN MLP object. The actual NN is either passed as
|
|
99
124
|
a eqx.nn.MLP (`eqx_network` argument) or constructed as a custom
|
|
@@ -179,14 +204,14 @@ class PINN_MLP(PINN):
|
|
|
179
204
|
raise ValueError(
|
|
180
205
|
"If eqx_network is None, then key and eqx_list must be provided"
|
|
181
206
|
)
|
|
182
|
-
eqx_network = MLP(key=key, eqx_list=eqx_list)
|
|
207
|
+
eqx_network = cast(MLP, MLP(key=key, eqx_list=eqx_list))
|
|
183
208
|
|
|
184
209
|
mlp = cls(
|
|
185
210
|
eqx_network=eqx_network,
|
|
186
|
-
slice_solution=slice_solution,
|
|
211
|
+
slice_solution=slice_solution, # type: ignore
|
|
187
212
|
eq_type=eq_type,
|
|
188
|
-
input_transform=input_transform,
|
|
189
|
-
output_transform=output_transform,
|
|
213
|
+
input_transform=input_transform, # type: ignore
|
|
214
|
+
output_transform=output_transform, # type: ignore
|
|
190
215
|
filter_spec=filter_spec,
|
|
191
216
|
)
|
|
192
217
|
return mlp, mlp.init_params
|
jinns/nn/_pinn.py
CHANGED
|
@@ -2,15 +2,19 @@
|
|
|
2
2
|
Implement abstract class for PINN architectures
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import Callable, Union, Any, Literal, overload
|
|
6
8
|
from dataclasses import InitVar
|
|
7
9
|
import equinox as eqx
|
|
8
10
|
from jaxtyping import Float, Array, PyTree
|
|
9
11
|
import jax.numpy as jnp
|
|
10
|
-
from jinns.parameters._params import Params
|
|
12
|
+
from jinns.parameters._params import Params
|
|
13
|
+
from jinns.nn._abstract_pinn import AbstractPINN
|
|
14
|
+
from jinns.nn._utils import _PyTree_to_Params
|
|
11
15
|
|
|
12
16
|
|
|
13
|
-
class PINN(
|
|
17
|
+
class PINN(AbstractPINN):
|
|
14
18
|
r"""
|
|
15
19
|
Base class for PINN objects. It can be seen as a wrapper on
|
|
16
20
|
an `eqx.Module` which actually implement the NN architectures, with extra
|
|
@@ -57,12 +61,12 @@ class PINN(eqx.Module):
|
|
|
57
61
|
**Note**: the input dimension as given in eqx_list has to match the sum
|
|
58
62
|
of the dimension of `t` + the dimension of `x` or the output dimension
|
|
59
63
|
after the `input_transform` function.
|
|
60
|
-
input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
|
|
64
|
+
input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
|
|
61
65
|
A function that will be called before entering the PINN. Its output(s)
|
|
62
66
|
must match the PINN inputs (except for the parameters).
|
|
63
67
|
Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
|
|
64
68
|
and the parameters. Default is no operation.
|
|
65
|
-
output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
|
|
69
|
+
output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
|
|
66
70
|
A function with arguments begin the same input as the PINN, the PINN
|
|
67
71
|
output and the parameter. This function will be called after exiting the PINN.
|
|
68
72
|
Default is no operation.
|
|
@@ -84,16 +88,16 @@ class PINN(eqx.Module):
|
|
|
84
88
|
"nonstatio_PDE"]`
|
|
85
89
|
"""
|
|
86
90
|
|
|
87
|
-
slice_solution: slice = eqx.field(static=True, kw_only=True, default=None)
|
|
88
91
|
eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
|
|
89
92
|
static=True, kw_only=True
|
|
90
93
|
)
|
|
94
|
+
slice_solution: slice = eqx.field(static=True, kw_only=True, default=None)
|
|
91
95
|
input_transform: Callable[
|
|
92
|
-
[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
|
|
96
|
+
[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]
|
|
93
97
|
] = eqx.field(static=True, kw_only=True, default=None)
|
|
94
98
|
output_transform: Callable[
|
|
95
|
-
[Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
|
|
96
|
-
Float[Array, "output_dim"],
|
|
99
|
+
[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]],
|
|
100
|
+
Float[Array, " output_dim"],
|
|
97
101
|
] = eqx.field(static=True, kw_only=True, default=None)
|
|
98
102
|
|
|
99
103
|
eqx_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
|
|
@@ -101,11 +105,10 @@ class PINN(eqx.Module):
|
|
|
101
105
|
static=True, kw_only=True, default=eqx.is_inexact_array
|
|
102
106
|
)
|
|
103
107
|
|
|
104
|
-
init_params:
|
|
105
|
-
static:
|
|
108
|
+
init_params: PINN = eqx.field(init=False)
|
|
109
|
+
static: PINN = eqx.field(init=False, static=True)
|
|
106
110
|
|
|
107
111
|
def __post_init__(self, eqx_network):
|
|
108
|
-
|
|
109
112
|
if self.eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
|
|
110
113
|
raise RuntimeError("Wrong parameter value for eq_type")
|
|
111
114
|
# saving the static part of the model and initial parameters
|
|
@@ -154,18 +157,32 @@ class PINN(eqx.Module):
|
|
|
154
157
|
|
|
155
158
|
return network(inputs)
|
|
156
159
|
|
|
160
|
+
@overload
|
|
161
|
+
@_PyTree_to_Params
|
|
162
|
+
def __call__(
|
|
163
|
+
self,
|
|
164
|
+
inputs: Float[Array, " input_dim"],
|
|
165
|
+
params: PyTree,
|
|
166
|
+
*args,
|
|
167
|
+
**kwargs,
|
|
168
|
+
) -> Float[Array, " output_dim"]: ...
|
|
169
|
+
|
|
170
|
+
@_PyTree_to_Params
|
|
157
171
|
def __call__(
|
|
158
172
|
self,
|
|
159
|
-
inputs: Float[Array, "input_dim"],
|
|
160
|
-
params: Params
|
|
173
|
+
inputs: Float[Array, " input_dim"],
|
|
174
|
+
params: Params[Array],
|
|
161
175
|
*args,
|
|
162
176
|
**kwargs,
|
|
163
|
-
) -> Float[Array, "output_dim"]:
|
|
177
|
+
) -> Float[Array, " output_dim"]:
|
|
164
178
|
"""
|
|
165
179
|
A proper __call__ implementation performs an eqx.combine here with
|
|
166
180
|
`params` and `self.static` to recreate the callable eqx.Module
|
|
167
181
|
architecture. The rest of the content of this function is dependent on
|
|
168
182
|
the network.
|
|
183
|
+
|
|
184
|
+
Note that that thanks to the decorator, params can also directly be the
|
|
185
|
+
PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
|
|
169
186
|
"""
|
|
170
187
|
|
|
171
188
|
if len(inputs.shape) == 0:
|
|
@@ -174,10 +191,7 @@ class PINN(eqx.Module):
|
|
|
174
191
|
# DataGenerators)
|
|
175
192
|
inputs = inputs[None]
|
|
176
193
|
|
|
177
|
-
|
|
178
|
-
model = eqx.combine(params.nn_params, self.static)
|
|
179
|
-
except (KeyError, AttributeError, TypeError) as e: # give more flexibility
|
|
180
|
-
model = eqx.combine(params, self.static)
|
|
194
|
+
model = eqx.combine(params.nn_params, self.static)
|
|
181
195
|
|
|
182
196
|
# evaluate the model
|
|
183
197
|
res = self.eval(model, self.input_transform(inputs, params), *args, **kwargs)
|