jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +116 -189
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +176 -513
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +22 -21
- jinns/loss/_loss_utils.py +98 -173
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -76
- jinns/nn/__init__.py +22 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +434 -0
- jinns/nn/_mlp.py +217 -0
- jinns/nn/_pinn.py +204 -0
- jinns/nn/_ppinn.py +239 -0
- jinns/{utils → nn}/_save_load.py +39 -53
- jinns/nn/_spinn.py +123 -0
- jinns/nn/_spinn_mlp.py +202 -0
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +38 -37
- jinns/solver/_rar.py +82 -65
- jinns/solver/_solve.py +111 -71
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -5
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns/utils/_hyperpinn.py +0 -420
- jinns/utils/_pinn.py +0 -324
- jinns/utils/_ppinn.py +0 -227
- jinns/utils/_spinn.py +0 -249
- jinns-1.2.0.dist-info/RECORD +0 -41
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/nn/_hyperpinn.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implements utility function to create HyperPINNs
|
|
3
|
+
https://arxiv.org/pdf/2111.01008.pdf
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import warnings
|
|
9
|
+
from dataclasses import InitVar
|
|
10
|
+
from typing import Callable, Literal, Self, Union, Any, cast, overload
|
|
11
|
+
from math import prod
|
|
12
|
+
import jax
|
|
13
|
+
import jax.numpy as jnp
|
|
14
|
+
from jaxtyping import Array, Float, PyTree, Key
|
|
15
|
+
import equinox as eqx
|
|
16
|
+
import numpy as onp
|
|
17
|
+
|
|
18
|
+
from jinns.nn._pinn import PINN
|
|
19
|
+
from jinns.nn._mlp import MLP
|
|
20
|
+
from jinns.parameters._params import Params
|
|
21
|
+
from jinns.nn._utils import _PyTree_to_Params
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _get_param_nb(
|
|
25
|
+
params: PyTree[Array],
|
|
26
|
+
) -> tuple[int, list[int]]:
|
|
27
|
+
"""Returns the number of parameters in a Params object and also
|
|
28
|
+
the cumulative sum when parsing the object.
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
Parameters
|
|
32
|
+
----------
|
|
33
|
+
params :
|
|
34
|
+
A Params object.
|
|
35
|
+
"""
|
|
36
|
+
dim_prod_all_arrays = [
|
|
37
|
+
prod(a.shape)
|
|
38
|
+
for a in jax.tree.leaves(params, is_leaf=lambda x: isinstance(x, jnp.ndarray))
|
|
39
|
+
]
|
|
40
|
+
return (
|
|
41
|
+
sum(dim_prod_all_arrays),
|
|
42
|
+
onp.cumsum(dim_prod_all_arrays).tolist(),
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class HyperPINN(PINN):
    r"""
    A HyperPINN object compatible with the rest of jinns.
    Composed of a PINN and a hypernetwork: the hypernetwork maps some equation
    parameters to the weights of the PINN. The HyperPINN is typically
    instantiated with `create`.

    Parameters
    ----------
    hyperparams: list[str] = eqx.field(static=True)
        A list of keys from Params.eq_params that will be considered as
        hyperparameters for metamodeling.
    hypernet_input_size: int
        An integer. The input size of the MLP used for the hypernetwork. Must
        be equal to the flattened concatenations for the array of parameters
        designated by the `hyperparams` argument.
    slice_solution : slice
        A jnp.s\_ object which indicates which axis of the PINN output is
        dedicated to the actual equation solution. Default None
        means that slice_solution = the whole PINN output. This argument is useful
        when the PINN is also used to output equation parameters for example.
        Note that it must be a slice and not an integer (a preprocessing of the
        user provided argument takes care of it).
    eq_type : str
        A string with three possibilities.
        "ODE": the HyperPINN is called with one input `t`.
        "statio_PDE": the HyperPINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the HyperPINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function
    input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
        A function that will be called before entering the PINN. Its output(s)
        must match the PINN inputs (except for the parameters).
        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
        and the parameters. Default is no operation.
    output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
        A function whose arguments are the same input as the PINN, the PINN
        output and the parameters. This function will be called after exiting the PINN.
        Default is no operation.
    eqx_network : eqx.Module
        The actual neural network instantiated as an eqx.Module.
    eqx_hyper_network : eqx.Module
        The actual hyper neural network instantiated as an eqx.Module.
    filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
        Default is `eqx.is_inexact_array`. This tells Jinns what to consider as
        a trainable parameter. Quoting from equinox documentation:
        a PyTree whose structure should be a prefix of the structure of pytree.
        Each of its leaves should either be 1) True, in which case the leaf or
        subtree is kept; 2) False, in which case the leaf or subtree is
        replaced with replace; 3) a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.
    """

    hyperparams: list[str] = eqx.field(static=True, kw_only=True)
    hypernet_input_size: int = eqx.field(kw_only=True)

    eqx_hyper_network: InitVar[eqx.Module] = eqx.field(kw_only=True)

    # Total number of PINN parameters and cumulative sizes of its leaves; the
    # latter are the split indices used to cut the hypernetwork output
    pinn_params_sum: int = eqx.field(init=False, static=True)
    pinn_params_cumsum: list[int] = eqx.field(init=False, static=True)

    # The two halves of `eqx.partition(eqx_hyper_network, filter_spec)`
    init_params_hyper: HyperPINN = eqx.field(init=False)
    static_hyper: HyperPINN = eqx.field(init=False, static=True)

    def __post_init__(self, eqx_network, eqx_hyper_network):
        """
        Run the parent initialization on `eqx_network`, then partition the
        hypernetwork and record the parameter counts of the PINN.
        """
        super().__post_init__(
            eqx_network,
        )
        # In addition, we store the PyTree structure of the hypernetwork as well
        self.init_params_hyper, self.static_hyper = eqx.partition(
            eqx_hyper_network, self.filter_spec
        )
        self.pinn_params_sum, self.pinn_params_cumsum = _get_param_nb(self.init_params)

    def _hyper_to_pinn(self, hyper_output: Float[Array, " output_dim"]) -> PINN:
        """
        From the output of the hypernetwork, transform to well formed
        parameters for the pinn network (i.e. with the same PyTree structure as
        `self.init_params`)
        """
        # The is_leaf=eqx.is_array argument for jax.tree.leaves is not needed
        # in general when working with eqx.nn.Linear for example:
        # jax.tree.leaves already returns the arrays of weights and biases
        # only, since the other stuff (that we do not want to be returned) is
        # marked as static (in eqx.nn.Linear), hence is not part of the leaves.
        # Note that custom layers should then be properly designed to pass
        # this jax.tree.leaves.
        pinn_params_flat = eqx.tree_at(
            jax.tree.leaves,
            self.init_params,
            # one flat chunk per leaf; the last cumulative value is the total
            # size and must not be used as a split index
            jnp.split(hyper_output, self.pinn_params_cumsum[:-1]),
        )

        # give back to each flat chunk the shape of the leaf it replaces
        return jax.tree.map(
            lambda a, b: a.reshape(b.shape),
            pinn_params_flat,
            self.init_params,
            is_leaf=lambda x: isinstance(x, jnp.ndarray),
        )

    @overload
    @_PyTree_to_Params
    def __call__(
        self,
        inputs: Float[Array, " input_dim"],
        params: PyTree,
        *args,
        **kwargs,
    ) -> Float[Array, " output_dim"]: ...

    @_PyTree_to_Params
    def __call__(
        self,
        inputs: Float[Array, " input_dim"],
        params: Params[Array],
        *args,
        **kwargs,
    ) -> Float[Array, " output_dim"]:
        """
        Evaluate the HyperPINN on some inputs with some params.

        The hypernetwork is rebuilt from `params.nn_params`, evaluated on the
        flattened concatenation of the `params.eq_params` entries listed in
        `self.hyperparams`, and its output is reshaped into the weights of the
        PINN which is finally evaluated on the (transformed) inputs.

        Note that thanks to the decorator, params can also directly be the
        PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
        """
        if len(inputs.shape) == 0:
            # This can happen often when the user directly provides some
            # collocation points (eg for plotting, without using
            # DataGenerators)
            inputs = inputs[None]

        hyper = eqx.combine(params.nn_params, self.static_hyper)

        eq_params_batch = jnp.concatenate(
            [params.eq_params[k].flatten() for k in self.hyperparams],
            axis=0,
        )

        hyper_output = hyper(eq_params_batch)  # type: ignore

        pinn_params = self._hyper_to_pinn(hyper_output)

        pinn = eqx.combine(pinn_params, self.static)
        res = self.eval(pinn, self.input_transform(inputs, params), *args, **kwargs)

        res = self.output_transform(inputs, res.squeeze(), params)

        # force (1,) output for non vectorial solution (consistency)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res

    @classmethod
    def create(
        cls,
        eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
        hyperparams: list[str],
        hypernet_input_size: int,
        eqx_network: eqx.nn.MLP | MLP | None = None,
        eqx_hyper_network: eqx.nn.MLP | MLP | None = None,
        key: Key | None = None,
        eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
        eqx_list_hyper: (
            tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None
        ) = None,
        input_transform: (
            Callable[
                [Float[Array, " input_dim"], Params[Array]],
                Float[Array, " output_dim"],
            ]
            | None
        ) = None,
        output_transform: (
            Callable[
                [
                    Float[Array, " input_dim"],
                    Float[Array, " output_dim"],
                    Params[Array],
                ],
                Float[Array, " output_dim"],
            ]
            | None
        ) = None,
        slice_solution: slice | None = None,
        filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
    ) -> tuple[Self, HyperPINN]:
        r"""
        Utility function to create a standard HyperPINN neural network with the
        equinox library.

        The networks are either both given (`eqx_network` and
        `eqx_hyper_network`) or both constructed as jinns MLP from `key`,
        `eqx_list` and `eqx_list_hyper`.

        **Note**: as soon as one of `eqx_network` or `eqx_hyper_network` is
        None, *both* networks are constructed from `eqx_list` and
        `eqx_list_hyper`; a network passed alongside is then not used.

        Parameters
        ----------
        eq_type
            A string with three possibilities.
            "ODE": the HyperPINN is called with one input `t`.
            "statio_PDE": the HyperPINN is called with one input `x`, `x`
            can be high dimensional.
            "nonstatio_PDE": the HyperPINN is called with two inputs `t` and `x`, `x`
            can be high dimensional.
            **Note**: the input dimension as given in eqx_list has to match the sum
            of the dimension of `t` + the dimension of `x` or the output dimension
            after the `input_transform` function
        hyperparams
            A list of keys from Params.eq_params that will be considered as
            hyperparameters for metamodeling.
        hypernet_input_size
            An integer. The input size of the MLP used for the hypernetwork. Must
            be equal to the flattened concatenations for the array of parameters
            designated by the `hyperparams` argument.
        eqx_network
            Default is None. A eqx.nn.MLP for the base network that will be wrapped
            inside our HyperPINN object in order to make it easily jinns compatible.
        eqx_hyper_network
            Default is None. A eqx.nn.MLP for the hyper network that will be wrapped
            inside our HyperPINN object in order to make it easily jinns compatible.
        key
            Default is None. Must be provided with `eqx_list` and
            `eqx_list_hyper` if `eqx_network` or `eqx_hyper_network`
            is not provided. A JAX random key that will be used to initialize the
            network parameters.
        eqx_list
            Default is None. Must be provided if `eqx_network` or
            `eqx_hyper_network` is not provided.
            A tuple of tuples of successive equinox modules and activation functions to
            describe the base network architecture. The inner tuples must have the eqx module or
            activation function as first item, other items represent arguments
            that could be required (eg. the size of the layer).
            The `key` argument need not be given.
            Thus typical example is `eqx_list=
            ((eqx.nn.Linear, 2, 20),
                (jax.nn.tanh,),
                (eqx.nn.Linear, 20, 20),
                (jax.nn.tanh,),
                (eqx.nn.Linear, 20, 20),
                (jax.nn.tanh,),
                (eqx.nn.Linear, 20, 1)
            )`.
        eqx_list_hyper
            Default is None. Must be provided if `eqx_network` or
            `eqx_hyper_network` is not provided.
            Same format as `eqx_list` but for the hyper network. The input
            size of its first parametrized layer is replaced by
            `hypernet_input_size` and the output size of its last parametrized
            layer is replaced by the number of parameters in the pinn network,
            whatever the values given by the user. The list may start and/or
            end with an activation function.
            Thus typical example is `eqx_list_hyper=
            ((eqx.nn.Linear, 2, 20),
                (jax.nn.tanh,),
                (eqx.nn.Linear, 20, 20),
                (jax.nn.tanh,),
                (eqx.nn.Linear, 20, 20),
                (jax.nn.tanh,),
                (eqx.nn.Linear, 20, 1)
            )`.
        input_transform
            A function that will be called before entering the PINN. Its output(s)
            must match the PINN inputs (except for the parameters).
            Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
            and the parameters. Default is no operation.
        output_transform
            A function whose arguments are the same input as the PINN, the PINN
            output and the parameters. This function will be called after exiting the PINN.
            Default is no operation.
        slice_solution
            A jnp.s\_ object which indicates which axis of the PINN output is
            dedicated to the actual equation solution. Default None
            means that slice_solution = the whole PINN output. This argument is useful
            when the PINN is also used to output equation parameters for example.
            Note that it must be a slice and not an integer (a preprocessing of the
            user provided argument takes care of it).
        filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
            Default is None which leads to `eqx.is_inexact_array` in the class
            instantiation. This tells Jinns what to consider as
            a trainable parameter. Quoting from equinox documentation:
            a PyTree whose structure should be a prefix of the structure of pytree.
            Each of its leaves should either be 1) True, in which case the leaf or
            subtree is kept; 2) False, in which case the leaf or subtree is
            replaced with replace; 3) a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.

        Returns
        -------
        hyperpinn
            A HyperPINN instance.
        hyperpinn.init_params_hyper
            The initial set of parameters of the hypernetwork, i.e. the
            trainable parameters of the HyperPINN.

        Raises
        ------
        ValueError
            If one of the networks is not given and `key`, `eqx_list` or
            `eqx_list_hyper` is missing.
        """
        if eqx_network is None or eqx_hyper_network is None:
            if eqx_list is None or key is None or eqx_list_hyper is None:
                raise ValueError(
                    "If eqx_network is None or eqx_hyper_network is None, then"
                    " key and eqx_list and eqx_list_hyper must be provided"
                )

            ### Now we finetune the hypernetwork architecture

            key, subkey = jax.random.split(key, 2)
            eqx_network = MLP(key=subkey, eqx_list=eqx_list)
            # quick partitioning to get the params to get the correct number of neurons
            # for the last layer of hyper network
            params_mlp, _ = eqx.partition(eqx_network, eqx.is_inexact_array)
            pinn_params_sum, _ = _get_param_nb(params_mlp)

            hyper_layers = list(eqx_list_hyper)

            # the number of parameters for the pinn will be the number of outputs
            # for the hyper network: patch the last parametrized layer, which
            # is the second to last item when the list ends with an activation
            # function (tuple of length 1)
            last = -1 if len(hyper_layers[-1]) > 1 else -2
            hyper_layers[last] = hyper_layers[last][:2] + (pinn_params_sum,)

            # similarly the number of inputs of the first parametrized layer
            # is hypernet_input_size; that layer is the second item when the
            # list starts with an activation function. In the latter case we
            # suppose that the second item is of length > 1 since we expect
            # something like eqx.nn.Linear
            first = 0 if len(hyper_layers[0]) > 1 else 1
            hyper_layers[first] = (
                hyper_layers[first][0],
                hypernet_input_size,
                hyper_layers[first][2],  # type: ignore
            )

            # Every item stays a tuple (activations included), which is what
            # MLP expects when it calls len() on each of them
            eqx_list_hyper = cast(
                tuple[tuple[Callable, int, int] | tuple[Callable], ...],
                tuple(hyper_layers),
            )

            key, subkey = jax.random.split(key, 2)
            eqx_hyper_network = cast(MLP, MLP(key=subkey, eqx_list=eqx_list_hyper))

            ### End of finetuning the hypernetwork architecture

        with warnings.catch_warnings():
            # Silence the equinox warning "A JAX array is being set as
            # static!" that can be emitted for the static fields holding the
            # numbers of parameters. It is correct to have them static here
            # because they are used as indices and will never be modified
            warnings.filterwarnings(
                "ignore", message="A JAX array is being set as static!"
            )
            hyperpinn = cls(
                eqx_network=eqx_network,
                eqx_hyper_network=eqx_hyper_network,
                slice_solution=slice_solution,  # type: ignore
                eq_type=eq_type,
                input_transform=input_transform,  # type: ignore
                output_transform=output_transform,  # type: ignore
                hyperparams=hyperparams,
                hypernet_input_size=hypernet_input_size,
                filter_spec=filter_spec,
            )
        return hyperpinn, hyperpinn.init_params_hyper
|
jinns/nn/_mlp.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implements the MLP module and the PINN_MLP class used to create MLP-based PINNs
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import Callable, Literal, Self, Union, Any, TYPE_CHECKING, cast
|
|
8
|
+
from dataclasses import InitVar
|
|
9
|
+
import jax
|
|
10
|
+
import equinox as eqx
|
|
11
|
+
from typing import Protocol
|
|
12
|
+
from jaxtyping import Array, Key, PyTree, Float
|
|
13
|
+
|
|
14
|
+
from jinns.parameters._params import Params
|
|
15
|
+
from jinns.nn._pinn import PINN
|
|
16
|
+
|
|
17
|
+
# Only defined for static type checkers: the name is solely used inside
# annotations, which are not evaluated at runtime thanks to
# `from __future__ import annotations`
if TYPE_CHECKING:

    class CallableMLPModule(Protocol):
        """
        Basically just a way to add a __call__ to an eqx.Module.
        https://github.com/patrick-kidger/equinox/issues/1002
        We chose the structural subtyping of protocols instead of subclassing an
        eqx.Module just to add a __call__ here
        """

        def __call__(self, *_, **__) -> Array: ...
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class MLP(eqx.Module):
    """
    Custom MLP equinox module from a key and a eqx_list

    Parameters
    ----------
    key : InitVar[Key]
        A jax random key for the layer initializations.
    eqx_list : InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]]
        A tuple of tuples of successive equinox modules and activation functions to
        describe the PINN architecture. The inner tuples must have the eqx module or
        activation function as first item, other items represent arguments
        that could be required (eg. the size of the layer).
        The `key` argument need not be given.
        Thus typical example is `eqx_list=
        ((eqx.nn.Linear, 2, 20),
            (jax.nn.tanh,),
            (eqx.nn.Linear, 20, 20),
            (jax.nn.tanh,),
            (eqx.nn.Linear, 20, 20),
            (jax.nn.tanh,),
            (eqx.nn.Linear, 20, 1)
        )`.
    """

    key: InitVar[Key] = eqx.field(kw_only=True)
    eqx_list: InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] = (
        eqx.field(kw_only=True)
    )

    # NOTE that the following should NOT be declared as static otherwise the
    # eqx.partition that we use in the PINN module will misbehave
    layers: list[CallableMLPModule | Callable[[Array], Array]] = eqx.field(init=False)

    def __post_init__(self, key, eqx_list):
        """
        Build the list of layers described by `eqx_list`.

        Items of length 1 are taken as is (activation functions); the first
        element of any other item is called with the remaining elements as
        positional arguments and a fresh random key as `key` keyword.
        """
        layers = []
        for layer_spec in eqx_list:
            if len(layer_spec) == 1:
                layers.append(layer_spec[0])
            else:
                # NOTE the key is split sequentially, once per parametrized
                # layer. According to the original author's note, a single
                # global split before the loop would be needed to get a strict
                # equivalence with eqx.nn.MLP; the sequential scheme is kept so
                # that a given key keeps producing the same initialization
                key, subkey = jax.random.split(key, 2)
                layers.append(layer_spec[0](*layer_spec[1:], key=subkey))
        self.layers = layers

    def __call__(self, t: Float[Array, " input_dim"]) -> Float[Array, " output_dim"]:
        """
        Apply the layers one after the other to the input `t`.
        """
        for layer in self.layers:
            t = layer(t)
        return t
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class PINN_MLP(PINN):
    """
    A concrete PINN whose underlying network is a MLP
    """

    # The parent's __call__ is used as is: the plain PINN_MLP does not need
    # a more elaborate one

    @classmethod
    def create(
        cls,
        eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
        eqx_network: eqx.nn.MLP | MLP | None = None,
        key: Key = None,
        eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
        input_transform: (
            Callable[
                [Float[Array, " input_dim"], Params[Array]],
                Float[Array, " output_dim"],
            ]
            | None
        ) = None,
        output_transform: (
            Callable[
                [
                    Float[Array, " input_dim"],
                    Float[Array, " output_dim"],
                    Params[Array],
                ],
                Float[Array, " output_dim"],
            ]
            | None
        ) = None,
        slice_solution: slice | None = None,
        filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
    ) -> tuple[Self, PINN]:
        r"""
        Build a standard PINN MLP object. The wrapped network is either the
        one given as `eqx_network` (e.g. a eqx.nn.MLP) or, when `eqx_network`
        is None, a custom jinns.nn.MLP constructed from `key` and `eqx_list`.

        Parameters
        ----------
        eq_type
            One of three strings.
            "ODE": the MLP is called with one input `t`.
            "statio_PDE": the MLP is called with one input `x`, `x`
            can be high dimensional.
            "nonstatio_PDE": the MLP is called with two inputs `t` and `x`, `x`
            can be high dimensional.
            **Note**: the input dimension as given in eqx_list has to match the sum
            of the dimension of `t` + the dimension of `x` or the output dimension
            after the `input_transform` function.
        eqx_network
            Default is None. A eqx.nn.MLP object that will be wrapped inside
            our PINN_MLP object in order to make it easily jinns compatible.
        key
            Default is None. A JAX random key used to initialize the network
            parameters. Required, together with `eqx_list`, when
            `eqx_network` is not provided.
        eqx_list
            Default is None. Required when `eqx_network` is not provided.
            A tuple of tuples of successive equinox modules and activation
            functions describing the MLP architecture. Each inner tuple starts
            with the eqx module or the activation function, the other items
            being the arguments it may require (eg. the size of the layer).

            The `key` argument does not need to be given.

            A typical example is `eqx_list = (
                (eqx.nn.Linear, input_dim, 20),
                (jax.nn.tanh,),
                (eqx.nn.Linear, 20, 20),
                (jax.nn.tanh,),
                (eqx.nn.Linear, 20, 20),
                (jax.nn.tanh,),
                (eqx.nn.Linear, 20, output_dim)
            )`.
        input_transform
            A function called before entering the MLP. Its output(s)
            must match the MLP inputs (except for the parameters).
            Its inputs are the MLP inputs (`t` and/or `x` concatenated together)
            and the parameters. Default is no operation.
        output_transform
            A function called after exiting the MLP. Its arguments are the
            same input as the MLP, the MLP output and the parameters.
            Default is no operation.
        slice_solution
            A jnp.s\_ object which indicates which axis of the MLP output is
            dedicated to the actual equation solution. Default None
            means that slice_solution = the whole MLP output. This argument is
            useful when the MLP is also used to output equation parameters for
            example. Note that it must be a slice and not an integer (a
            preprocessing of the user provided argument takes care of it).

        filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
            Default is None which leads to `eqx.is_inexact_array` in the class
            instantiation. This tells Jinns what to consider as
            a trainable parameter. Quoting from equinox documentation:
            a PyTree whose structure should be a prefix of the structure of pytree.
            Each of its leaves should either be 1) True, in which case the leaf or
            subtree is kept; 2) False, in which case the leaf or subtree is
            replaced with replace; 3) a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.

        Returns
        -------
        mlp
            The PINN_MLP instance.
        mlp.init_params
            The initial set of parameters of that instance.

        Raises
        ------
        ValueError
            If `eqx_network` is None and `key` or `eqx_list` is missing.
        """
        network = eqx_network
        if network is None:
            # nothing to wrap: a jinns MLP has to be constructed
            cannot_build = key is None or eqx_list is None
            if cannot_build:
                raise ValueError(
                    "If eqx_network is None, then key and eqx_list must be provided"
                )
            network = cast(MLP, MLP(key=key, eqx_list=eqx_list))

        pinn_mlp = cls(
            eqx_network=network,
            slice_solution=slice_solution,  # type: ignore
            eq_type=eq_type,
            input_transform=input_transform,  # type: ignore
            output_transform=output_transform,  # type: ignore
            filter_spec=filter_spec,
        )
        initial_params = pinn_mlp.init_params
        return pinn_mlp, initial_params
|