jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +116 -189
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +176 -513
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +22 -21
- jinns/loss/_loss_utils.py +98 -173
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -76
- jinns/nn/__init__.py +22 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +434 -0
- jinns/nn/_mlp.py +217 -0
- jinns/nn/_pinn.py +204 -0
- jinns/nn/_ppinn.py +239 -0
- jinns/{utils → nn}/_save_load.py +39 -53
- jinns/nn/_spinn.py +123 -0
- jinns/nn/_spinn_mlp.py +202 -0
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +38 -37
- jinns/solver/_rar.py +82 -65
- jinns/solver/_solve.py +111 -71
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -5
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns/utils/_hyperpinn.py +0 -420
- jinns/utils/_pinn.py +0 -324
- jinns/utils/_ppinn.py +0 -227
- jinns/utils/_spinn.py +0 -249
- jinns-1.2.0.dist-info/RECORD +0 -41
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/utils/_pinn.py
DELETED
|
@@ -1,324 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Implements utility function to create PINNs
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from typing import Callable, Literal
|
|
6
|
-
from dataclasses import InitVar
|
|
7
|
-
import jax
|
|
8
|
-
import jax.numpy as jnp
|
|
9
|
-
import equinox as eqx
|
|
10
|
-
|
|
11
|
-
from jaxtyping import Array, Key, PyTree, Float
|
|
12
|
-
|
|
13
|
-
from jinns.parameters._params import Params, ParamsDict
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class _MLP(eqx.Module):
    """
    Class to construct an equinox module from a key and a eqx_list. To be used
    in pair with the function `create_PINN`.

    Parameters
    ----------
    key : InitVar[Key]
        A jax random key for the layer initializations. It is split once per
        parametrized item of `eqx_list`, so every such layer gets its own
        subkey.
    eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
        A tuple describing, in order, the successive equinox modules and
        activation functions of the PINN architecture. Each item is either:

        - a tuple whose first item is an eqx module class (or a function) and
          whose other items are the positional arguments required to build it
          (eg. the size of the layer). The `key` argument need not be given:
          a fresh subkey is passed as `key=` automatically;
        - a 1-tuple `(activation,)`, whose single item is used as is;
        - a bare callable `activation` (not wrapped in a tuple), used as is.

        A typical example is `eqx_list=
        ((eqx.nn.Linear, 2, 20),
        jax.nn.tanh,
        (eqx.nn.Linear, 20, 20),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 20),
        jax.nn.tanh,
        (eqx.nn.Linear, 20, 1)
        )`.
    """

    key: InitVar[Key] = eqx.field(kw_only=True)
    eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] = eqx.field(
        kw_only=True
    )

    # NOTE that the following should NOT be declared as static otherwise the
    # eqx.partition that we use in the PINN module will misbehave
    layers: list[eqx.Module] = eqx.field(init=False)

    def __post_init__(self, key, eqx_list):
        self.layers = []
        for layer_spec in eqx_list:
            if callable(layer_spec):
                # Bare callable (eg. `jax.nn.tanh`), as allowed by the type
                # annotation of `eqx_list`; calling `len()` on it would fail.
                self.layers.append(layer_spec)
            elif len(layer_spec) == 1:
                # 1-tuple: parameter-free layer / activation, used as is.
                self.layers.append(layer_spec[0])
            else:
                # Constructor followed by its positional arguments: consume a
                # fresh subkey so each parametrized layer is initialized
                # independently.
                key, subkey = jax.random.split(key, 2)
                self.layers.append(layer_spec[0](*layer_spec[1:], key=subkey))

    def __call__(self, t: Float[Array, "input_dim"]) -> Float[Array, "output_dim"]:
        """Apply every layer of `self.layers` in sequence to `t`."""
        for layer in self.layers:
            t = layer(t)
        return t
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
class PINN(eqx.Module):
    r"""
    A PINN object, i.e., a neural network compatible with the rest of jinns.
    This is typically created with `create_PINN`, which internally creates a
    `_MLP` object. However, a user can also directly create their PINN with
    this class by passing an already instantiated eqx.Module (argument `mlp`)
    that plays the role of the NN.

    Parameters
    ----------
    slice_solution : slice
        A jnp.s\_ object which indicates which axis of the PINN output is
        dedicated to the actual equation solution. Default None
        means that slice_solution = the whole PINN output. This argument is useful
        when the PINN is also used to output equation parameters for example.
        Note that it must be a slice and not an integer (a preprocessing of the
        user provided argument takes care of it).
    eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
        A string with three possibilities.
        "ODE": the PINN is called with one input `t`.
        "statio_PDE": the PINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
    input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
        A function that will be called before entering the PINN. Its output(s)
        must match the PINN inputs (except for the parameters).
        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
        and the parameters.
    output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
        A function whose arguments are the same input as the PINN, the
        (squeezed) network output and the parameters. This function is called
        after exiting the network.
    output_slice : slice | None, default=None
        If not None, a jnp.s\_[] applied to the transformed output to select
        the dimensions returned by this PINN.
        See the `shared_pinn_outputs` argument of `create_PINN`.
    mlp : eqx.Module
        The actual neural network, instantiated as an eqx.Module. It is not
        stored as is: `__post_init__` splits it with `eqx.partition` into its
        inexact-array leaves (`params`) and the remainder (`static`).
    """

    slice_solution: slice = eqx.field(static=True, kw_only=True)
    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
        static=True, kw_only=True
    )
    input_transform: Callable[
        [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
    ] = eqx.field(static=True, kw_only=True)
    output_transform: Callable[
        [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
        Float[Array, "output_dim"],
    ] = eqx.field(static=True, kw_only=True)
    output_slice: slice | None = eqx.field(static=True, kw_only=True, default=None)

    mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)

    params: PyTree = eqx.field(init=False)
    static: PyTree = eqx.field(init=False, static=True)

    def __post_init__(self, mlp):
        # Keep the inexact arrays (the trainable weights) apart from the rest
        # of the module; both are recombined at each call with `eqx.combine`.
        self.params, self.static = eqx.partition(mlp, eqx.is_inexact_array)

    @property
    def init_params(self) -> PyTree:
        """
        Returns an initial set of parameters (the inexact-array part of the
        `mlp` given at construction).
        """
        return self.params

    def __call__(
        self,
        inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
        params: Params | ParamsDict | PyTree,
    ) -> Float[Array, "output_dim"]:
        """
        Evaluate the PINN on some inputs with some params.

        `params` is either an object exposing the network weights as
        `params.nn_params`, or directly the PyTree of network weights (eg. as
        returned by `init_params`). The full `params` object is forwarded to
        `input_transform` and `output_transform`.
        """
        if inputs.ndim == 0:
            # This can happen often when the user directly provides some
            # collocation points (eg for plotting, without using
            # DataGenerators)
            inputs = inputs[None]

        try:
            model = eqx.combine(params.nn_params, self.static)
        except (KeyError, AttributeError, TypeError):  # give more flexibility
            # `params` does not expose usable `nn_params`: treat it directly
            # as the PyTree of network weights.
            model = eqx.combine(params, self.static)
        res = self.output_transform(
            inputs, model(self.input_transform(inputs, params)).squeeze(), params
        )

        if self.output_slice is not None:
            res = res[self.output_slice]

        ## force (1,) output for non vectorial solution (consistency)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
def create_PINN(
    key: Key,
    eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
    dim_x: int = 0,
    input_transform: Callable[
        [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
    ]
    | None = None,
    output_transform: Callable[
        [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
        Float[Array, "output_dim"],
    ]
    | None = None,
    shared_pinn_outputs: tuple[slice, ...] | None = None,
    slice_solution: slice | int | None = None,
) -> tuple[PINN | list[PINN], PyTree | list[PyTree]]:
    r"""
    Utility function to create a standard PINN neural network with the equinox
    library.

    Parameters
    ----------
    key
        A JAX random key that will be used to initialize the network
        parameters.
    eqx_list
        A tuple of tuples of successive equinox modules and activation
        functions to describe the PINN architecture. The inner tuples must have
        the eqx module or activation function as first item, other items
        represent arguments that could be required (eg. the size of the layer).

        The `key` argument does not need to be given.

        A typical example is `eqx_list = (
        (eqx.nn.Linear, input_dim, 20),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 20),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 20),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, output_dim)
        )`.
    eq_type
        A string with three possibilities.
        "ODE": the PINN is called with one input `t`.
        "statio_PDE": the PINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
    dim_x
        An integer. The dimension of `x`. Default `0`.
    input_transform
        A function that will be called before entering the PINN. Its output(s)
        must match the PINN inputs (except for the parameters).
        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
        and the parameters. Default is no operation.
    output_transform
        A function whose arguments are the same input as the PINN, the PINN
        output and the parameters. This function will be called after exiting
        the PINN.
        Default is no operation.
    shared_pinn_outputs
        Default is None, for a standard PINN.
        A tuple of jnp.s\_[] (slices) to determine the different output for each
        network. In this case we return a list of PINNs, one for each output in
        shared_pinn_outputs. This is useful to create PINNs that share the
        same network and same parameters; **the user must then use the same
        parameter set in their manipulation**.
        See the notebook 2D Navier Stokes in pipeflow with metamodel for an
        example using this option.
    slice_solution
        A jnp.s\_ object which indicates which axis of the PINN output is
        dedicated to the actual equation solution. Default None
        means that slice_solution = the whole PINN output (inferred from the
        output size of the last linear layer of `eqx_list`). This argument is
        useful when the PINN is also used to output equation parameters for
        example. An integer `i` is accepted and converted to the slice
        `jnp.s_[i:i+1]` so that the axis does not disappear when indexing.


    Returns
    -------
    pinn
        A PINN instance or, when `shared_pinn_outputs` is not None,
        a list of PINN instances with the same structure is returned,
        only differing by their final slicing of the network output.
    pinn.init_params
        An initial set of parameters for the PINN or a list of the latter
        when `shared_pinn_outputs` is not None.

    Raises
    ------
    RuntimeError
        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
        "nonstatio_PDE"]`
    RuntimeError
        If we have a `dim_x > 0` and `eq_type == "ODE"`
        or if we have a `dim_x = 0` and `eq_type != "ODE"`
    """
    if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
        raise RuntimeError("Wrong parameter value for eq_type")

    if eq_type == "ODE" and dim_x != 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    if eq_type != "ODE" and dim_x == 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    try:
        # normally we look for the 3rd element of the last layer, eg.
        # `(eqx.nn.Linear, 20, output_dim)`
        nb_outputs_declared = eqx_list[-1][2]
    except (IndexError, TypeError):
        # The last item is an activation: either a 1-tuple `(act,)`
        # (IndexError) or a bare callable, which is not subscriptable
        # (TypeError). Look at the previous layer instead.
        nb_outputs_declared = eqx_list[-2][2]

    if slice_solution is None:
        slice_solution = jnp.s_[0:nb_outputs_declared]
    if isinstance(slice_solution, int):
        # rewrite it as a slice to ensure that axis does not disappear when
        # indexing
        slice_solution = jnp.s_[slice_solution : slice_solution + 1]

    if input_transform is None:

        def input_transform(_in, _params):
            return _in

    if output_transform is None:

        def output_transform(_in_pinn, _out_pinn, _params):
            return _out_pinn

    mlp = _MLP(key=key, eqx_list=eqx_list)

    def _make_pinn(output_slice):
        # All PINNs built here wrap the very same `mlp`, hence share weights.
        return PINN(
            mlp=mlp,
            slice_solution=slice_solution,
            eq_type=eq_type,
            input_transform=input_transform,
            output_transform=output_transform,
            output_slice=output_slice,
        )

    if shared_pinn_outputs is not None:
        pinns = [_make_pinn(output_slice) for output_slice in shared_pinn_outputs]
        return pinns, [p.init_params for p in pinns]
    pinn = _make_pinn(None)
    return pinn, pinn.init_params
|
jinns/utils/_ppinn.py
DELETED
|
@@ -1,227 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Implements utility function to create PPINNs (parallel PINNs)
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from typing import Callable, Literal
|
|
6
|
-
from dataclasses import InitVar
|
|
7
|
-
import jax
|
|
8
|
-
import jax.numpy as jnp
|
|
9
|
-
import equinox as eqx
|
|
10
|
-
|
|
11
|
-
from jaxtyping import Array, Key, PyTree, Float
|
|
12
|
-
|
|
13
|
-
from jinns.parameters._params import Params, ParamsDict
|
|
14
|
-
|
|
15
|
-
from jinns.utils._pinn import PINN, _MLP
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class PPINN(PINN):
    r"""
    A PPINN (Parallel PINN) object which mimics the PFNN architecture from
    DeepXDE. This is in fact a PINN that encompasses several PINNs internally:
    the same (transformed) input is fed to every sub-network and their outputs
    are concatenated before the global `output_transform`.

    Parameters
    ----------
    slice_solution : slice
        A jnp.s\_ object which indicates which axis of the PINN output is
        dedicated to the actual equation solution. Default None
        means that slice_solution = the whole PINN output. This argument is useful
        when the PINN is also used to output equation parameters for example.
        Note that it must be a slice and not an integer (a preprocessing of the
        user provided argument takes care of it).
    eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
        A string with three possibilities.
        "ODE": the PINN is called with one input `t`.
        "statio_PDE": the PINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
    input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
        A function that will be called before entering the PINN. Its output(s)
        must match the PINN inputs (except for the parameters).
        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
        and the parameters.
    output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
        A function whose arguments are the same input as the PINN, the
        concatenated outputs of all parallel networks and the parameters.
        This function is called after exiting the parallel networks.
    output_slice : slice | None, default=None
        If not None, a jnp.s\_[] applied to the transformed output, as done
        by `PINN`.
    mlp_list : list[eqx.Module]
        The actual neural networks, instantiated as eqx.Modules. Each one is
        split with `eqx.partition` into its inexact-array leaves (gathered in
        the tuple `params`) and the remainder (gathered in the tuple `static`).
    """

    slice_solution: slice = eqx.field(static=True, kw_only=True)
    output_slice: slice | None = eqx.field(static=True, kw_only=True, default=None)

    mlp_list: InitVar[list[eqx.Module]] = eqx.field(kw_only=True)

    params: PyTree = eqx.field(init=False)
    static: PyTree = eqx.field(init=False, static=True)

    def __post_init__(self, mlp, mlp_list):
        # `mlp` is the InitVar inherited from PINN; it is not used here since
        # every network of `mlp_list` is partitioned separately.
        if len(mlp_list) == 0:
            raise ValueError("PPINN requires a non-empty mlp_list")
        partitions = [eqx.partition(m, eqx.is_inexact_array) for m in mlp_list]
        self.params = tuple(p for p, _ in partitions)
        self.static = tuple(s for _, s in partitions)

    @property
    def init_params(self) -> PyTree:
        """
        Returns an initial set of parameters: a tuple with one PyTree of
        weights per parallel network.
        """
        return self.params

    def __call__(
        self,
        inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
        params: PyTree,
    ) -> Float[Array, "output_dim"]:
        """
        Evaluate the PPINN on some inputs with some params.

        `params` is either an object exposing the per-network weights as
        `params.nn_params`, or directly that tuple of weights (eg. as returned
        by `init_params`), mirroring the flexibility of `PINN.__call__`. The
        full `params` object is forwarded to the input/output transforms.
        """
        if inputs.ndim == 0:
            # This can happen often when the user directly provides some
            # collocation points (eg for plotting, without using
            # DataGenerators)
            inputs = inputs[None]
        transformed_inputs = self.input_transform(inputs, params)

        try:
            nn_params = params.nn_params
        except AttributeError:
            nn_params = params

        outs = []
        # strict=True: a mismatch between the given weights and the parallel
        # networks must not silently drop networks.
        for sub_params, sub_static in zip(nn_params, self.static, strict=True):
            model = eqx.combine(sub_params, sub_static)
            # atleast_1d lets sub-networks with a scalar output be concatenated
            outs.append(jnp.atleast_1d(model(transformed_inputs)))
        # Note that below is then a global output transform
        res = self.output_transform(inputs, jnp.concatenate(outs, axis=0), params)

        if self.output_slice is not None:
            res = res[self.output_slice]

        ## force (1,) output for non vectorial solution (consistency)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
def create_PPINN(
    key: Key,
    eqx_list_list: list[tuple[tuple[Callable, int, int] | Callable, ...]],
    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
    dim_x: int = 0,
    input_transform: Callable[
        [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
    ]
    | None = None,
    output_transform: Callable[
        [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
        Float[Array, "output_dim"],
    ]
    | None = None,
    slice_solution: slice | int | None = None,
) -> tuple[PPINN, PyTree]:
    r"""
    Utility function to create a PPINN (parallel PINN) neural network with the
    equinox library.

    Parameters
    ----------
    key
        A JAX random key that will be used to initialize the network
        parameters. It is split into one subkey per parallel network so that
        the networks are initialized independently.
    eqx_list_list
        A list of `eqx_list` (see `create_PINN`). The input dimension must be the
        same for each sub-`eqx_list`. Then the parallel subnetworks can be
        different. Their respective outputs are concatenated.
    eq_type
        A string with three possibilities.
        "ODE": the PPINN is called with one input `t`.
        "statio_PDE": the PPINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PPINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
    dim_x
        An integer. The dimension of `x`. Default `0`.
    input_transform
        A function that will be called before entering the PPINN. Its output(s)
        must match the PPINN inputs (except for the parameters).
        Its inputs are the PPINN inputs (`t` and/or `x` concatenated together)
        and the parameters. Default is no operation.
    output_transform
        This function will be called after exiting
        the PPINN, i.e., on the concatenated outputs of all parallel networks.
        Default is no operation.
    slice_solution
        A jnp.s\_ object which indicates which axis of the PPINN output is
        dedicated to the actual equation solution. Default None
        means that slice_solution = the whole PPINN output (the sum of the
        output sizes of all parallel networks). This argument is
        useful when the PPINN is also used to output equation parameters for
        example. An integer `i` is accepted and converted to the slice
        `jnp.s_[i:i+1]` so that the axis does not disappear when indexing.


    Returns
    -------
    ppinn
        A PPINN instance
    ppinn.init_params
        An initial set of parameters for the PPINN (a tuple with one PyTree
        per parallel network)

    Raises
    ------
    RuntimeError
        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
        "nonstatio_PDE"]`
    RuntimeError
        If we have a `dim_x > 0` and `eq_type == "ODE"`
        or if we have a `dim_x = 0` and `eq_type != "ODE"`
    """
    if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
        raise RuntimeError("Wrong parameter value for eq_type")

    if eq_type == "ODE" and dim_x != 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    if eq_type != "ODE" and dim_x == 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    nb_outputs_declared = 0
    for eqx_list in eqx_list_list:
        try:
            # normally we look for the 3rd element of the last layer
            nb_outputs_declared += eqx_list[-1][2]
        except (IndexError, TypeError):
            # The last item is an activation: either a 1-tuple `(act,)`
            # (IndexError) or a bare callable, which is not subscriptable
            # (TypeError). Look at the previous layer instead.
            nb_outputs_declared += eqx_list[-2][2]

    if slice_solution is None:
        slice_solution = jnp.s_[0:nb_outputs_declared]
    if isinstance(slice_solution, int):
        # rewrite it as a slice to ensure that axis does not disappear when
        # indexing
        slice_solution = jnp.s_[slice_solution : slice_solution + 1]

    if input_transform is None:

        def input_transform(_in, _params):
            return _in

    if output_transform is None:

        def output_transform(_in_pinn, _out_pinn, _params):
            return _out_pinn

    # One independent subkey per parallel network: reusing `key` for all of
    # them would give identical initial weights to identical architectures.
    subkeys = jax.random.split(key, len(eqx_list_list))
    mlp_list = [
        _MLP(key=subkey, eqx_list=eqx_list)
        for subkey, eqx_list in zip(subkeys, eqx_list_list)
    ]

    ppinn = PPINN(
        mlp=None,
        mlp_list=mlp_list,
        slice_solution=slice_solution,
        eq_type=eq_type,
        input_transform=input_transform,
        output_transform=output_transform,
    )
    return ppinn, ppinn.init_params
|