jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +953 -1182
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +321 -168
- jinns/loss/_LossODE.py +290 -307
- jinns/loss/_LossPDE.py +628 -1040
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +95 -96
- jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +94 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +193 -45
- jinns/solver/_solve.py +199 -144
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -43
- jinns/utils/_hyperpinn.py +226 -127
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +117 -84
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +52 -144
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
- jinns-1.0.0.dist-info/RECORD +38 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.8.10.dist-info/RECORD +0 -36
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
jinns/utils/_containers.py
CHANGED
|
@@ -1,57 +1,51 @@
|
|
|
1
1
|
"""
|
|
2
|
-
|
|
2
|
+
equinox Modules used as containers
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
import
|
|
10
|
-
from
|
|
11
|
-
from
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class ValidationContainer(NamedTuple):
|
|
31
|
-
loss: Union[
|
|
32
|
-
LossODE, SystemLossODE, LossPDEStatio, LossPDENonStatio, SystemLossPDE, None
|
|
33
|
-
]
|
|
5
|
+
from __future__ import (
|
|
6
|
+
annotations,
|
|
7
|
+
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
|
+
|
|
9
|
+
from typing import TYPE_CHECKING, Dict
|
|
10
|
+
from jaxtyping import PyTree, Array, Float, Bool
|
|
11
|
+
from optax import OptState
|
|
12
|
+
import equinox as eqx
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from jinns.utils._types import *
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class DataGeneratorContainer(eqx.Module):
|
|
19
|
+
data: AnyDataGenerator
|
|
20
|
+
param_data: DataGeneratorParameter | None = None
|
|
21
|
+
obs_data: DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None = (
|
|
22
|
+
None
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ValidationContainer(eqx.Module):
|
|
27
|
+
loss: AnyLoss | None
|
|
34
28
|
data: DataGeneratorContainer
|
|
35
29
|
hyperparams: PyTree = None
|
|
36
|
-
loss_values:
|
|
30
|
+
loss_values: Float[Array, "n_iter"] | None = None
|
|
37
31
|
|
|
38
32
|
|
|
39
|
-
class OptimizationContainer(
|
|
40
|
-
params:
|
|
41
|
-
last_non_nan_params:
|
|
42
|
-
opt_state:
|
|
33
|
+
class OptimizationContainer(eqx.Module):
|
|
34
|
+
params: Params
|
|
35
|
+
last_non_nan_params: Params
|
|
36
|
+
opt_state: OptState
|
|
43
37
|
|
|
44
38
|
|
|
45
|
-
class OptimizationExtraContainer(
|
|
39
|
+
class OptimizationExtraContainer(eqx.Module):
|
|
46
40
|
curr_seq: int
|
|
47
|
-
|
|
48
|
-
early_stopping:
|
|
41
|
+
best_val_params: Params
|
|
42
|
+
early_stopping: Bool = False
|
|
49
43
|
|
|
50
44
|
|
|
51
|
-
class LossContainer(
|
|
52
|
-
stored_loss_terms:
|
|
53
|
-
train_loss_values:
|
|
45
|
+
class LossContainer(eqx.Module):
|
|
46
|
+
stored_loss_terms: Dict[str, Float[Array, "n_iter"]]
|
|
47
|
+
train_loss_values: Float[Array, "n_iter"]
|
|
54
48
|
|
|
55
49
|
|
|
56
|
-
class StoredObjectContainer(
|
|
57
|
-
stored_params:
|
|
50
|
+
class StoredObjectContainer(eqx.Module):
|
|
51
|
+
stored_params: list | None
|
jinns/utils/_hyperpinn.py
CHANGED
|
@@ -3,90 +3,132 @@ Implements utility function to create HYPERPINNs
|
|
|
3
3
|
https://arxiv.org/pdf/2111.01008.pdf
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
+
import warnings
|
|
7
|
+
from dataclasses import InitVar
|
|
8
|
+
from typing import Callable, Literal
|
|
6
9
|
import copy
|
|
7
10
|
from math import prod
|
|
8
|
-
import numpy as onp
|
|
9
11
|
import jax
|
|
10
12
|
import jax.numpy as jnp
|
|
11
13
|
from jax.tree_util import tree_leaves, tree_map
|
|
12
|
-
from
|
|
14
|
+
from jaxtyping import Array, Float, PyTree, Int, Key
|
|
13
15
|
import equinox as eqx
|
|
16
|
+
import numpy as onp
|
|
14
17
|
|
|
15
18
|
from jinns.utils._pinn import PINN, _MLP
|
|
19
|
+
from jinns.parameters._params import Params
|
|
16
20
|
|
|
17
21
|
|
|
18
|
-
def _get_param_nb(
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
sum when parsing the
|
|
23
|
-
|
|
24
|
-
|
|
22
|
+
def _get_param_nb(
|
|
23
|
+
params: Params,
|
|
24
|
+
) -> tuple[Int[onp.ndarray, "1"], Int[onp.ndarray, "n_layers"]]:
|
|
25
|
+
"""Returns the number of parameters in a Params object and also
|
|
26
|
+
the cumulative sum when parsing the object.
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
Parameters
|
|
30
|
+
----------
|
|
31
|
+
params :
|
|
32
|
+
A Params object.
|
|
25
33
|
"""
|
|
26
34
|
dim_prod_all_arrays = [
|
|
27
35
|
prod(a.shape)
|
|
28
36
|
for a in tree_leaves(params, is_leaf=lambda x: isinstance(x, jnp.ndarray))
|
|
29
37
|
]
|
|
30
|
-
return sum(dim_prod_all_arrays), onp.cumsum(dim_prod_all_arrays)
|
|
38
|
+
return onp.asarray(sum(dim_prod_all_arrays)), onp.cumsum(dim_prod_all_arrays)
|
|
31
39
|
|
|
32
40
|
|
|
33
41
|
class HYPERPINN(PINN):
|
|
34
42
|
"""
|
|
35
|
-
|
|
36
|
-
|
|
43
|
+
A HYPERPINN object compatible with the rest of jinns.
|
|
44
|
+
Composed of a PINN and an HYPER network. The HYPERPINN is typically
|
|
45
|
+
instanciated using with `create_HYPERPINN`. However, a user could directly
|
|
46
|
+
creates their HYPERPINN using this
|
|
47
|
+
class by passing an eqx.Module for argument `mlp` (resp. for argument
|
|
48
|
+
`hyper_mlp`) that plays the role of the NN (resp. hyper NN) and that is
|
|
49
|
+
already instanciated.
|
|
37
50
|
|
|
38
|
-
|
|
39
|
-
|
|
51
|
+
Parameters
|
|
52
|
+
----------
|
|
40
53
|
hyperparams: list = eqx.field(static=True)
|
|
54
|
+
A list of keys from Params.eq_params that will be considered as
|
|
55
|
+
hyperparameters for metamodeling.
|
|
41
56
|
hypernet_input_size: int
|
|
42
|
-
|
|
43
|
-
|
|
57
|
+
An integer. The input size of the MLP used for the hypernetwork. Must
|
|
58
|
+
be equal to the flattened concatenations for the array of parameters
|
|
59
|
+
designated by the `hyperparams` argument.
|
|
60
|
+
slice_solution : slice
|
|
61
|
+
A jnp.s\_ object which indicates which axis of the PINN output is
|
|
62
|
+
dedicated to the actual equation solution. Default None
|
|
63
|
+
means that slice_solution = the whole PINN output. This argument is useful
|
|
64
|
+
when the PINN is also used to output equation parameters for example
|
|
65
|
+
Note that it must be a slice and not an integer (a preprocessing of the
|
|
66
|
+
user provided argument takes care of it).
|
|
67
|
+
eq_type : str
|
|
68
|
+
A string with three possibilities.
|
|
69
|
+
"ODE": the HYPERPINN is called with one input `t`.
|
|
70
|
+
"statio_PDE": the HYPERPINN is called with one input `x`, `x`
|
|
71
|
+
can be high dimensional.
|
|
72
|
+
"nonstatio_PDE": the HYPERPINN is called with two inputs `t` and `x`, `x`
|
|
73
|
+
can be high dimensional.
|
|
74
|
+
**Note**: the input dimension as given in eqx_list has to match the sum
|
|
75
|
+
of the dimension of `t` + the dimension of `x` or the output dimension
|
|
76
|
+
after the `input_transform` function
|
|
77
|
+
input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
|
|
78
|
+
A function that will be called before entering the PINN. Its output(s)
|
|
79
|
+
must match the PINN inputs (except for the parameters).
|
|
80
|
+
Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
|
|
81
|
+
and the parameters. Default is no operation.
|
|
82
|
+
output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
|
|
83
|
+
A function with arguments begin the same input as the PINN, the PINN
|
|
84
|
+
output and the parameter. This function will be called after exiting the PINN.
|
|
85
|
+
Default is no operation.
|
|
86
|
+
output_slice : slice, default=None
|
|
87
|
+
A jnp.s\_[] to determine the different dimension for the HYPERPINN.
|
|
88
|
+
See `shared_pinn_outputs` argument of `create_HYPERPINN`.
|
|
89
|
+
mlp : eqx.Module
|
|
90
|
+
The actual neural network instanciated as an eqx.Module.
|
|
91
|
+
hyper_mlp : eqx.Module
|
|
92
|
+
The actual hyper neural network instanciated as an eqx.Module.
|
|
93
|
+
"""
|
|
44
94
|
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
)
|
|
57
|
-
|
|
95
|
+
hyperparams: list[str] = eqx.field(static=True, kw_only=True)
|
|
96
|
+
hypernet_input_size: int = eqx.field(kw_only=True)
|
|
97
|
+
|
|
98
|
+
hyper_mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)
|
|
99
|
+
mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)
|
|
100
|
+
|
|
101
|
+
params_hyper: PyTree = eqx.field(init=False)
|
|
102
|
+
static_hyper: PyTree = eqx.field(init=False, static=True)
|
|
103
|
+
pinn_params_sum: Int[onp.ndarray, "1"] = eqx.field(init=False, static=True)
|
|
104
|
+
pinn_params_cumsum: Int[onp.ndarray, "n_layers"] = eqx.field(
|
|
105
|
+
init=False, static=True
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
def __post_init__(self, mlp, hyper_mlp):
|
|
109
|
+
super().__post_init__(
|
|
58
110
|
mlp,
|
|
59
|
-
slice_solution,
|
|
60
|
-
eq_type,
|
|
61
|
-
input_transform,
|
|
62
|
-
output_transform,
|
|
63
|
-
output_slice,
|
|
64
111
|
)
|
|
65
112
|
self.params_hyper, self.static_hyper = eqx.partition(
|
|
66
113
|
hyper_mlp, eqx.is_inexact_array
|
|
67
114
|
)
|
|
68
|
-
self.hyperparams = hyperparams
|
|
69
|
-
self.hypernet_input_size = hypernet_input_size
|
|
70
115
|
self.pinn_params_sum, self.pinn_params_cumsum = _get_param_nb(self.params)
|
|
71
116
|
|
|
72
|
-
def init_params(self):
|
|
117
|
+
def init_params(self) -> Params:
|
|
118
|
+
"""
|
|
119
|
+
Returns an initial set of parameters
|
|
120
|
+
"""
|
|
73
121
|
return self.params_hyper
|
|
74
122
|
|
|
75
|
-
def
|
|
123
|
+
def _hyper_to_pinn(self, hyper_output: Float[Array, "output_dim"]) -> PyTree:
|
|
76
124
|
"""
|
|
77
125
|
From the output of the hypernetwork we set the well formed
|
|
78
126
|
parameters of the pinn (`self.params`)
|
|
79
127
|
"""
|
|
80
128
|
pinn_params_flat = eqx.tree_at(
|
|
81
|
-
lambda p: tree_leaves(p, is_leaf=
|
|
129
|
+
lambda p: tree_leaves(p, is_leaf=eqx.is_array),
|
|
82
130
|
self.params,
|
|
83
|
-
|
|
84
|
-
+ [
|
|
85
|
-
hyper_output[
|
|
86
|
-
self.pinn_params_cumsum[i] : self.pinn_params_cumsum[i + 1]
|
|
87
|
-
]
|
|
88
|
-
for i in range(len(self.pinn_params_cumsum) - 1)
|
|
89
|
-
],
|
|
131
|
+
jnp.split(hyper_output, self.pinn_params_cumsum[:-1]),
|
|
90
132
|
)
|
|
91
133
|
|
|
92
134
|
return tree_map(
|
|
@@ -96,26 +138,31 @@ class HYPERPINN(PINN):
|
|
|
96
138
|
is_leaf=lambda x: isinstance(x, jnp.ndarray),
|
|
97
139
|
)
|
|
98
140
|
|
|
99
|
-
def
|
|
141
|
+
def eval_nn(
|
|
142
|
+
self,
|
|
143
|
+
inputs: Float[Array, "input_dim"],
|
|
144
|
+
params: Params | PyTree,
|
|
145
|
+
) -> Float[Array, "output_dim"]:
|
|
100
146
|
"""
|
|
101
|
-
|
|
102
|
-
call _eval_nn which always have the same content.
|
|
147
|
+
Evaluate the HYPERPINN on some inputs with some params.
|
|
103
148
|
"""
|
|
104
149
|
try:
|
|
105
|
-
hyper = eqx.combine(params
|
|
106
|
-
except (KeyError, TypeError) as e: # give more flexibility
|
|
150
|
+
hyper = eqx.combine(params.nn_params, self.static_hyper)
|
|
151
|
+
except (KeyError, AttributeError, TypeError) as e: # give more flexibility
|
|
107
152
|
hyper = eqx.combine(params, self.static_hyper)
|
|
108
153
|
|
|
109
154
|
eq_params_batch = jnp.concatenate(
|
|
110
|
-
[params
|
|
155
|
+
[params.eq_params[k].flatten() for k in self.hyperparams], axis=0
|
|
111
156
|
)
|
|
112
157
|
|
|
113
158
|
hyper_output = hyper(eq_params_batch)
|
|
114
159
|
|
|
115
|
-
pinn_params = self.
|
|
160
|
+
pinn_params = self._hyper_to_pinn(hyper_output)
|
|
116
161
|
|
|
117
162
|
pinn = eqx.combine(pinn_params, self.static)
|
|
118
|
-
res = output_transform(
|
|
163
|
+
res = self.output_transform(
|
|
164
|
+
inputs, pinn(self.input_transform(inputs, params)).squeeze(), params
|
|
165
|
+
)
|
|
119
166
|
|
|
120
167
|
if self.output_slice is not None:
|
|
121
168
|
res = res[self.output_slice]
|
|
@@ -127,18 +174,23 @@ class HYPERPINN(PINN):
|
|
|
127
174
|
|
|
128
175
|
|
|
129
176
|
def create_HYPERPINN(
|
|
130
|
-
key,
|
|
131
|
-
eqx_list,
|
|
132
|
-
eq_type,
|
|
133
|
-
hyperparams,
|
|
134
|
-
hypernet_input_size,
|
|
135
|
-
dim_x=0,
|
|
136
|
-
input_transform
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
177
|
+
key: Key,
|
|
178
|
+
eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
|
|
179
|
+
eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
|
|
180
|
+
hyperparams: list[str],
|
|
181
|
+
hypernet_input_size: int,
|
|
182
|
+
dim_x: int = 0,
|
|
183
|
+
input_transform: Callable[
|
|
184
|
+
[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
|
|
185
|
+
] = None,
|
|
186
|
+
output_transform: Callable[
|
|
187
|
+
[Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
|
|
188
|
+
Float[Array, "output_dim"],
|
|
189
|
+
] = None,
|
|
190
|
+
slice_solution: slice = None,
|
|
191
|
+
shared_pinn_outputs: slice = None,
|
|
192
|
+
eqx_list_hyper: tuple[tuple[Callable, int, int] | Callable, ...] = None,
|
|
193
|
+
) -> HYPERPINN | list[HYPERPINN]:
|
|
142
194
|
r"""
|
|
143
195
|
Utility function to create a standard PINN neural network with the equinox
|
|
144
196
|
library.
|
|
@@ -146,59 +198,61 @@ def create_HYPERPINN(
|
|
|
146
198
|
Parameters
|
|
147
199
|
----------
|
|
148
200
|
key
|
|
149
|
-
A
|
|
201
|
+
A JAX random key that will be used to initialize the network
|
|
202
|
+
parameters.
|
|
150
203
|
eqx_list
|
|
151
|
-
A
|
|
152
|
-
describe the PINN architecture. The inner
|
|
153
|
-
|
|
204
|
+
A tuple of tuples of successive equinox modules and activation functions to
|
|
205
|
+
describe the PINN architecture. The inner tuples must have the eqx module or
|
|
206
|
+
activation function as first item, other items represent arguments
|
|
154
207
|
that could be required (eg. the size of the layer).
|
|
155
|
-
|
|
208
|
+
The `key` argument need not be given.
|
|
156
209
|
Thus typical example is `eqx_list=
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
210
|
+
((eqx.nn.Linear, 2, 20),
|
|
211
|
+
jax.nn.tanh,
|
|
212
|
+
(eqx.nn.Linear, 20, 20),
|
|
213
|
+
jax.nn.tanh,
|
|
214
|
+
(eqx.nn.Linear, 20, 20),
|
|
215
|
+
jax.nn.tanh,
|
|
216
|
+
(eqx.nn.Linear, 20, 1)
|
|
217
|
+
)`.
|
|
165
218
|
eq_type
|
|
166
219
|
A string with three possibilities.
|
|
167
|
-
"ODE": the
|
|
168
|
-
"statio_PDE": the
|
|
220
|
+
"ODE": the HYPERPINN is called with one input `t`.
|
|
221
|
+
"statio_PDE": the HYPERPINN is called with one input `x`, `x`
|
|
169
222
|
can be high dimensional.
|
|
170
|
-
"nonstatio_PDE": the
|
|
223
|
+
"nonstatio_PDE": the HYPERPINN is called with two inputs `t` and `x`, `x`
|
|
171
224
|
can be high dimensional.
|
|
172
225
|
**Note**: the input dimension as given in eqx_list has to match the sum
|
|
173
226
|
of the dimension of `t` + the dimension of `x` or the output dimension
|
|
174
227
|
after the `input_transform` function
|
|
175
228
|
hyperparams
|
|
176
|
-
A list of keys from
|
|
177
|
-
hyperparameters for metamodeling
|
|
229
|
+
A list of keys from Params.eq_params that will be considered as
|
|
230
|
+
hyperparameters for metamodeling.
|
|
178
231
|
hypernet_input_size
|
|
179
232
|
An integer. The input size of the MLP used for the hypernetwork. Must
|
|
180
233
|
be equal to the flattened concatenations for the array of parameters
|
|
181
|
-
designated by the `hyperparams` argument
|
|
234
|
+
designated by the `hyperparams` argument.
|
|
182
235
|
dim_x
|
|
183
|
-
An integer. The dimension of `x`. Default `0
|
|
236
|
+
An integer. The dimension of `x`. Default `0`.
|
|
184
237
|
input_transform
|
|
185
238
|
A function that will be called before entering the PINN. Its output(s)
|
|
186
|
-
must match the PINN inputs
|
|
187
|
-
|
|
239
|
+
must match the PINN inputs (except for the parameters).
|
|
240
|
+
Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
|
|
241
|
+
and the parameters. Default is no operation.
|
|
188
242
|
output_transform
|
|
189
|
-
A function with arguments the same input
|
|
190
|
-
output
|
|
191
|
-
operation
|
|
243
|
+
A function with arguments begin the same input as the PINN, the PINN
|
|
244
|
+
output and the parameter. This function will be called after exiting the PINN.
|
|
245
|
+
Default is no operation.
|
|
192
246
|
slice_solution
|
|
193
247
|
A jnp.s\_ object which indicates which axis of the PINN output is
|
|
194
248
|
dedicated to the actual equation solution. Default None
|
|
195
249
|
means that slice_solution = the whole PINN output. This argument is useful
|
|
196
250
|
when the PINN is also used to output equation parameters for example
|
|
197
251
|
Note that it must be a slice and not an integer (a preprocessing of the
|
|
198
|
-
user provided argument takes care of it)
|
|
252
|
+
user provided argument takes care of it).
|
|
199
253
|
shared_pinn_outputs
|
|
200
254
|
Default is None, for a stantard PINN.
|
|
201
|
-
A tuple of jnp.
|
|
255
|
+
A tuple of jnp.s\_[] (slices) to determine the different output for each
|
|
202
256
|
network. In this case we return a list of PINNs, one for each output in
|
|
203
257
|
shared_pinn_outputs. This is useful to create PINNs that share the
|
|
204
258
|
same network and same parameters; **the user must then use the same
|
|
@@ -216,7 +270,11 @@ def create_HYPERPINN(
|
|
|
216
270
|
|
|
217
271
|
Returns
|
|
218
272
|
-------
|
|
219
|
-
|
|
273
|
+
hyperpinn
|
|
274
|
+
A HYPERPINN instance or, when `shared_pinn_ouput` is not None,
|
|
275
|
+
a list of HYPERPINN instances with the same structure is returned,
|
|
276
|
+
only differing by there final slicing of the network output.
|
|
277
|
+
|
|
220
278
|
|
|
221
279
|
Raises
|
|
222
280
|
------
|
|
@@ -259,53 +317,94 @@ def create_HYPERPINN(
|
|
|
259
317
|
|
|
260
318
|
if output_transform is None:
|
|
261
319
|
|
|
262
|
-
def output_transform(_in_pinn, _out_pinn):
|
|
320
|
+
def output_transform(_in_pinn, _out_pinn, _params):
|
|
263
321
|
return _out_pinn
|
|
264
322
|
|
|
265
323
|
key, subkey = jax.random.split(key, 2)
|
|
266
|
-
mlp = _MLP(subkey, eqx_list)
|
|
324
|
+
mlp = _MLP(key=subkey, eqx_list=eqx_list)
|
|
267
325
|
# quick partitioning to get the params to get the correct number of neurons
|
|
268
326
|
# for the last layer of hyper network
|
|
269
327
|
params_mlp, _ = eqx.partition(mlp, eqx.is_inexact_array)
|
|
270
328
|
pinn_params_sum, _ = _get_param_nb(params_mlp)
|
|
271
329
|
# the number of parameters for the pinn will be the number of ouputs
|
|
272
330
|
# for the hyper network
|
|
273
|
-
|
|
274
|
-
eqx_list_hyper[
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
eqx_list_hyper
|
|
279
|
-
|
|
280
|
-
|
|
331
|
+
if len(eqx_list_hyper[-1]) > 1:
|
|
332
|
+
eqx_list_hyper = eqx_list_hyper[:-1] + (
|
|
333
|
+
(eqx_list_hyper[-1][:2] + (pinn_params_sum,)),
|
|
334
|
+
)
|
|
335
|
+
else:
|
|
336
|
+
eqx_list_hyper = (
|
|
337
|
+
eqx_list_hyper[:-2]
|
|
338
|
+
+ ((eqx_list_hyper[-2][:2] + (pinn_params_sum,)),)
|
|
339
|
+
+ eqx_list_hyper[-1]
|
|
340
|
+
)
|
|
341
|
+
if len(eqx_list_hyper[0]) > 1:
|
|
342
|
+
eqx_list_hyper = (
|
|
343
|
+
(
|
|
344
|
+
(eqx_list_hyper[0][0],)
|
|
345
|
+
+ (hypernet_input_size,)
|
|
346
|
+
+ (eqx_list_hyper[0][2],)
|
|
347
|
+
),
|
|
348
|
+
) + eqx_list_hyper[1:]
|
|
349
|
+
else:
|
|
350
|
+
eqx_list_hyper = (
|
|
351
|
+
eqx_list_hyper[0]
|
|
352
|
+
+ (
|
|
353
|
+
(
|
|
354
|
+
(eqx_list_hyper[1][0],)
|
|
355
|
+
+ (hypernet_input_size,)
|
|
356
|
+
+ (eqx_list_hyper[1][2],)
|
|
357
|
+
),
|
|
358
|
+
)
|
|
359
|
+
+ eqx_list_hyper[2:]
|
|
360
|
+
)
|
|
281
361
|
key, subkey = jax.random.split(key, 2)
|
|
282
|
-
|
|
362
|
+
|
|
363
|
+
with warnings.catch_warnings():
|
|
364
|
+
# TODO check why this warning is raised here and not in the PINN
|
|
365
|
+
# context ?
|
|
366
|
+
warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
|
|
367
|
+
hyper_mlp = _MLP(key=subkey, eqx_list=eqx_list_hyper)
|
|
283
368
|
|
|
284
369
|
if shared_pinn_outputs is not None:
|
|
285
370
|
hyperpinns = []
|
|
286
371
|
for output_slice in shared_pinn_outputs:
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
372
|
+
with warnings.catch_warnings():
|
|
373
|
+
# Catch the equinox warning because we put the number of
|
|
374
|
+
# parameters as static while being jnp.Array. This this time
|
|
375
|
+
# this is correct to do so, because they are used as indices
|
|
376
|
+
# and will never be modified
|
|
377
|
+
warnings.filterwarnings(
|
|
378
|
+
"ignore", message="A JAX array is being set as static!"
|
|
379
|
+
)
|
|
380
|
+
hyperpinn = HYPERPINN(
|
|
381
|
+
mlp=mlp,
|
|
382
|
+
hyper_mlp=hyper_mlp,
|
|
383
|
+
slice_solution=slice_solution,
|
|
384
|
+
eq_type=eq_type,
|
|
385
|
+
input_transform=input_transform,
|
|
386
|
+
output_transform=output_transform,
|
|
387
|
+
hyperparams=hyperparams,
|
|
388
|
+
hypernet_input_size=hypernet_input_size,
|
|
389
|
+
output_slice=output_slice,
|
|
390
|
+
)
|
|
298
391
|
hyperpinns.append(hyperpinn)
|
|
299
392
|
return hyperpinns
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
393
|
+
with warnings.catch_warnings():
|
|
394
|
+
# Catch the equinox warning because we put the number of
|
|
395
|
+
# parameters as static while being jnp.Array. This this time
|
|
396
|
+
# this is correct to do so, because they are used as indices
|
|
397
|
+
# and will never be modified
|
|
398
|
+
warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
|
|
399
|
+
hyperpinn = HYPERPINN(
|
|
400
|
+
mlp=mlp,
|
|
401
|
+
hyper_mlp=hyper_mlp,
|
|
402
|
+
slice_solution=slice_solution,
|
|
403
|
+
eq_type=eq_type,
|
|
404
|
+
input_transform=input_transform,
|
|
405
|
+
output_transform=output_transform,
|
|
406
|
+
hyperparams=hyperparams,
|
|
407
|
+
hypernet_input_size=hypernet_input_size,
|
|
408
|
+
output_slice=None,
|
|
409
|
+
)
|
|
311
410
|
return hyperpinn
|