jinns 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +532 -341
- jinns/loss/_DynamicLoss.py +150 -173
- jinns/loss/_DynamicLossAbstract.py +27 -73
- jinns/loss/_LossODE.py +45 -26
- jinns/loss/_LossPDE.py +85 -84
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +148 -279
- jinns/loss/_loss_utils.py +85 -58
- jinns/loss/_operators.py +441 -184
- jinns/parameters/_derivative_keys.py +487 -60
- jinns/plot/_plot.py +111 -98
- jinns/solver/_rar.py +102 -407
- jinns/solver/_solve.py +73 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +3 -1
- jinns/utils/_hyperpinn.py +17 -7
- jinns/utils/_pinn.py +17 -27
- jinns/utils/_ppinn.py +227 -0
- jinns/utils/_save_load.py +13 -13
- jinns/utils/_spinn.py +24 -43
- jinns/utils/_types.py +1 -0
- jinns/utils/_utils.py +40 -12
- jinns-1.2.0.dist-info/AUTHORS +2 -0
- jinns-1.2.0.dist-info/METADATA +127 -0
- jinns-1.2.0.dist-info/RECORD +41 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/WHEEL +1 -1
- jinns-1.0.0.dist-info/METADATA +0 -84
- jinns-1.0.0.dist-info/RECORD +0 -38
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/LICENSE +0 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/top_level.txt +0 -0
jinns/utils/_ppinn.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implements utility function to create PINNs
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Callable, Literal
|
|
6
|
+
from dataclasses import InitVar
|
|
7
|
+
import jax
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
import equinox as eqx
|
|
10
|
+
|
|
11
|
+
from jaxtyping import Array, Key, PyTree, Float
|
|
12
|
+
|
|
13
|
+
from jinns.parameters._params import Params, ParamsDict
|
|
14
|
+
|
|
15
|
+
from jinns.utils._pinn import PINN, _MLP
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class PPINN(PINN):
    r"""
    A PPINN (Parallel PINN) object which mimics the PFNN architecture from
    DeepXDE. This is in fact a PINN that encompasses several PINNs internally:
    every sub-network receives the same (transformed) input and their outputs
    are concatenated along the first axis.

    Parameters
    ----------
    slice_solution : slice
        A jnp.s\_ object which indicates which axis of the PPINN output is
        dedicated to the actual equation solution. With `create_PPINN`, the
        default None means that slice_solution = the whole PPINN output. This
        argument is useful when the PPINN is also used to output equation
        parameters for example. Note that it must be a slice and not an
        integer (a preprocessing of the user provided argument takes care of
        it).
    eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
        A string with three possibilities.
        "ODE": the PPINN is called with one input `t`.
        "statio_PDE": the PPINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PPINN is called with `t` and `x` concatenated
        together, `x` can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
    input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
        A function that will be called before entering the parallel networks.
        Its output(s) must match the networks' inputs (except for the
        parameters). Its inputs are the PPINN inputs (`t` and/or `x`
        concatenated together) and the parameters. Default (in
        `create_PPINN`) is no operation.
    output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
        A function whose arguments are the same input as the PPINN, the
        concatenated outputs of all the parallel networks and the parameters.
        It is called once after exiting the networks (global output
        transform). Default (in `create_PPINN`) is no operation.
    mlp_list : list[eqx.Module]
        The actual neural networks instantiated as eqx.Modules. Only used at
        initialization: each network is partitioned into its inexact-array
        leaves (stored in `params`) and the rest (stored in `static`).
    """

    slice_solution: slice = eqx.field(static=True, kw_only=True)
    output_slice: slice = eqx.field(static=True, kw_only=True, default=None)

    mlp_list: InitVar[list[eqx.Module]] = eqx.field(kw_only=True)

    # Tuples with one entry per parallel network, in the order of `mlp_list`
    params: PyTree = eqx.field(init=False)
    static: PyTree = eqx.field(init=False, static=True)

    def __post_init__(self, mlp, mlp_list):
        # Run the parent initialization with the first network; the `params`
        # and `static` attributes are then rebuilt below to cover every
        # parallel network. NOTE(review): assumes PINN.__post_init__ only
        # needs a representative `mlp` — confirm against PINN.
        super().__post_init__(
            mlp=mlp_list[0],
        )
        partitions = [eqx.partition(net, eqx.is_inexact_array) for net in mlp_list]
        self.params = tuple(net_params for net_params, _ in partitions)
        self.static = tuple(net_static for _, net_static in partitions)

    @property
    def init_params(self) -> PyTree:
        """
        Returns an initial set of parameters: a tuple holding the parameters
        of each parallel network, in the order of `mlp_list`.
        """
        return self.params

    def __call__(
        self,
        inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
        params: PyTree,
    ) -> Float[Array, "output_dim"]:
        """
        Evaluate the PPINN on some inputs with some params.

        `params` is either an object exposing the networks parameters in an
        `nn_params` attribute (e.g. `Params`) or directly the tuple of
        networks parameters (e.g. the one returned by `init_params`).
        """
        if len(inputs.shape) == 0:
            # This can happen often when the user directly provides some
            # collocation points (eg for plotting, without using
            # DataGenerators)
            inputs = inputs[None]
        transformed_inputs = self.input_transform(inputs, params)

        # Same flexibility as SPINN: fall back to the raw parameters pytree
        # when no `nn_params` attribute is available
        try:
            nn_params = params.nn_params
        except AttributeError:
            nn_params = params

        # strict=True: a mismatch between the number of parameter sets and the
        # number of networks must fail loudly instead of silently dropping
        # outputs
        outs = [
            eqx.combine(net_params, net_static)(transformed_inputs)
            for net_params, net_static in zip(nn_params, self.static, strict=True)
        ]
        # Note that below is then a global output transform
        res = self.output_transform(inputs, jnp.concatenate(outs, axis=0), params)

        ## force (1,) output for non vectorial solution (consistency)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def create_PPINN(
    key: Key,
    eqx_list_list: list[tuple[tuple[Callable, int, int] | Callable, ...]],
    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
    dim_x: int = 0,
    input_transform: Callable[
        [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
    ] = None,
    output_transform: Callable[
        [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
        Float[Array, "output_dim"],
    ] = None,
    slice_solution: slice | int | None = None,
) -> tuple[PPINN, PyTree]:
    r"""
    Utility function to create a PPINN (parallel PINN) neural network with the
    equinox library.

    Parameters
    ----------
    key
        A JAX random key that will be used to initialize the network
        parameters. It is split so that each parallel network gets its own
        independent initialization.
    eqx_list_list
        A list of `eqx_list` (see `create_PINN`). The input dimension must be the
        same for each sub-`eqx_list`. Then the parallel subnetworks can be
        different. Their respective outputs are concatenated.
    eq_type
        A string with three possibilities.
        "ODE": the PPINN is called with one input `t`.
        "statio_PDE": the PPINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PPINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
    dim_x
        An integer. The dimension of `x`. Default `0`.
    input_transform
        A function that will be called before entering the PPINN. Its output(s)
        must match the PPINN inputs (except for the parameters).
        Its inputs are the PPINN inputs (`t` and/or `x` concatenated together)
        and the parameters. Default is no operation.
    output_transform
        This function will be called after exiting the PPINN, i.e., on the
        concatenated outputs of all parallel networks. Default is no
        operation.
    slice_solution
        A jnp.s\_ object which indicates which axis of the PPINN output is
        dedicated to the actual equation solution. Default None
        means that slice_solution = the whole PPINN output. This argument is
        useful when the PPINN is also used to output equation parameters for
        example. An integer `i` is converted into the slice `i:i+1` so that
        the axis does not disappear when indexing.

    Returns
    -------
    ppinn
        A PPINN instance
    ppinn.init_params
        An initial set of parameters for the PPINN

    Raises
    ------
    RuntimeError
        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
        "nonstatio_PDE"]`
    RuntimeError
        If we have a `dim_x > 0` and `eq_type == "ODE"`
        or if we have a `dim_x = 0` and `eq_type != "ODE"`
    ValueError
        If `eqx_list_list` is empty.
    """
    if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
        raise RuntimeError("Wrong parameter value for eq_type")

    if eq_type == "ODE" and dim_x != 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    if eq_type != "ODE" and dim_x == 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    if len(eqx_list_list) == 0:
        # Fail early with a clear message instead of an IndexError when the
        # PPINN picks its first network
        raise ValueError("eqx_list_list must contain at least one eqx_list")

    nb_outputs_declared = 0
    for eqx_list in eqx_list_list:
        try:
            # normally the output size is the 3rd element of the last layer
            nb_outputs = eqx_list[-1][2]
        except (IndexError, TypeError):
            # the last element is a final activation (e.g. `(jax.nn.tanh,)`
            # or a bare callable): the output size is in the previous layer
            nb_outputs = eqx_list[-2][2]
        nb_outputs_declared += nb_outputs

    if slice_solution is None:
        slice_solution = jnp.s_[0:nb_outputs_declared]
    if isinstance(slice_solution, int):
        # rewrite it as a slice to ensure that axis does not disappear when
        # indexing
        slice_solution = jnp.s_[slice_solution : slice_solution + 1]

    if input_transform is None:

        def input_transform(_in, _params):
            return _in

    if output_transform is None:

        def output_transform(_in_pinn, _out_pinn, _params):
            return _out_pinn

    # One independent key per parallel network: reusing the same key would
    # give identical initial weights to sub-networks sharing an architecture,
    # turning them into redundant copies of each other.
    subkeys = jax.random.split(key, len(eqx_list_list))
    mlp_list = [
        _MLP(key=subkey, eqx_list=eqx_list)
        for subkey, eqx_list in zip(subkeys, eqx_list_list)
    ]

    ppinn = PPINN(
        mlp=None,
        mlp_list=mlp_list,
        slice_solution=slice_solution,
        eq_type=eq_type,
        input_transform=input_transform,
        output_transform=output_transform,
    )
    return ppinn, ppinn.init_params
|
jinns/utils/_save_load.py
CHANGED
|
@@ -20,18 +20,18 @@ def function_to_string(
|
|
|
20
20
|
We need this transformation for eqx_list to be pickled
|
|
21
21
|
|
|
22
22
|
From `((eqx.nn.Linear, 2, 20),
|
|
23
|
-
(jax.nn.tanh),
|
|
23
|
+
(jax.nn.tanh,),
|
|
24
24
|
(eqx.nn.Linear, 20, 20),
|
|
25
|
-
(jax.nn.tanh),
|
|
25
|
+
(jax.nn.tanh,),
|
|
26
26
|
(eqx.nn.Linear, 20, 20),
|
|
27
|
-
(jax.nn.tanh),
|
|
27
|
+
(jax.nn.tanh,),
|
|
28
28
|
(eqx.nn.Linear, 20, 1))` to
|
|
29
29
|
`(("Linear", 2, 20),
|
|
30
|
-
("tanh"),
|
|
30
|
+
("tanh",),
|
|
31
31
|
("Linear", 20, 20),
|
|
32
|
-
("tanh"),
|
|
32
|
+
("tanh",),
|
|
33
33
|
("Linear", 20, 20),
|
|
34
|
-
("tanh"),
|
|
34
|
+
("tanh",),
|
|
35
35
|
("Linear", 20, 1))`
|
|
36
36
|
"""
|
|
37
37
|
return jax.tree_util.tree_map(
|
|
@@ -210,14 +210,16 @@ def load_pinn(
|
|
|
210
210
|
if type_ == "pinn":
|
|
211
211
|
# next line creates a shallow model, the jax arrays are just shapes and
|
|
212
212
|
# not populated, this just recreates the correct pytree structure
|
|
213
|
-
u_reloaded_shallow = eqx.filter_eval_shape(create_PINN, **kwargs_reloaded)
|
|
213
|
+
u_reloaded_shallow, _ = eqx.filter_eval_shape(create_PINN, **kwargs_reloaded)
|
|
214
214
|
elif type_ == "spinn":
|
|
215
|
-
u_reloaded_shallow = eqx.filter_eval_shape(create_SPINN, **kwargs_reloaded)
|
|
215
|
+
u_reloaded_shallow, _ = eqx.filter_eval_shape(create_SPINN, **kwargs_reloaded)
|
|
216
216
|
elif type_ == "hyperpinn":
|
|
217
217
|
kwargs_reloaded["eqx_list_hyper"] = string_to_function(
|
|
218
218
|
kwargs_reloaded["eqx_list_hyper"]
|
|
219
219
|
)
|
|
220
|
-
u_reloaded_shallow = eqx.filter_eval_shape(
|
|
220
|
+
u_reloaded_shallow, _ = eqx.filter_eval_shape(
|
|
221
|
+
create_HYPERPINN, **kwargs_reloaded
|
|
222
|
+
)
|
|
221
223
|
else:
|
|
222
224
|
raise ValueError(f"{type_} is not valid")
|
|
223
225
|
if key_list_for_paramsdict is None:
|
|
@@ -226,15 +228,13 @@ def load_pinn(
|
|
|
226
228
|
u_reloaded = eqx.tree_deserialise_leaves(
|
|
227
229
|
filename + "-module.eqx", u_reloaded_shallow
|
|
228
230
|
)
|
|
229
|
-
params = Params(
|
|
230
|
-
nn_params=u_reloaded.init_params(), eq_params=eq_params_reloaded
|
|
231
|
-
)
|
|
231
|
+
params = Params(nn_params=u_reloaded.init_params, eq_params=eq_params_reloaded)
|
|
232
232
|
else:
|
|
233
233
|
nn_params_dict = {}
|
|
234
234
|
for key in key_list_for_paramsdict:
|
|
235
235
|
u_reloaded = eqx.tree_deserialise_leaves(
|
|
236
236
|
filename + f"-module_{key}.eqx", u_reloaded_shallow
|
|
237
237
|
)
|
|
238
|
-
nn_params_dict[key] = u_reloaded.init_params
|
|
238
|
+
nn_params_dict[key] = u_reloaded.init_params
|
|
239
239
|
params = ParamsDict(nn_params=nn_params_dict, eq_params=eq_params_reloaded)
|
|
240
240
|
return u_reloaded, params
|
jinns/utils/_spinn.py
CHANGED
|
@@ -10,6 +10,8 @@ import jax.numpy as jnp
|
|
|
10
10
|
import equinox as eqx
|
|
11
11
|
from jaxtyping import Key, Array, Float, PyTree
|
|
12
12
|
|
|
13
|
+
from jinns.parameters._params import Params, ParamsDict
|
|
14
|
+
|
|
13
15
|
|
|
14
16
|
class _SPINN(eqx.Module):
|
|
15
17
|
"""
|
|
@@ -21,7 +23,8 @@ class _SPINN(eqx.Module):
|
|
|
21
23
|
key : InitVar[Key]
|
|
22
24
|
A jax random key for the layer initializations.
|
|
23
25
|
d : int
|
|
24
|
-
The number of dimensions to treat separately
|
|
26
|
+
The number of dimensions to treat separately, including time `t` if
|
|
27
|
+
used for non-stationnary equations.
|
|
25
28
|
eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
|
|
26
29
|
A tuple of tuples of successive equinox modules and activation functions to
|
|
27
30
|
describe the PINN architecture. The inner tuples must have the eqx module or
|
|
@@ -62,15 +65,11 @@ class _SPINN(eqx.Module):
|
|
|
62
65
|
self.separated_mlp.append(self.layers)
|
|
63
66
|
|
|
64
67
|
def __call__(
|
|
65
|
-
self,
|
|
68
|
+
self, inputs: Float[Array, "dim"] | Float[Array, "dim+1"]
|
|
66
69
|
) -> Float[Array, "d embed_dim*output_dim"]:
|
|
67
|
-
if t is not None:
|
|
68
|
-
dimensions = jnp.concatenate([t, x.flatten()], axis=0)
|
|
69
|
-
else:
|
|
70
|
-
dimensions = jnp.concatenate([x.flatten()], axis=0)
|
|
71
70
|
outputs = []
|
|
72
71
|
for d in range(self.d):
|
|
73
|
-
t_ =
|
|
72
|
+
t_ = inputs[d : d + 1]
|
|
74
73
|
for layer in self.separated_mlp[d]:
|
|
75
74
|
t_ = layer(t_)
|
|
76
75
|
outputs += [t_]
|
|
@@ -82,13 +81,11 @@ class SPINN(eqx.Module):
|
|
|
82
81
|
A SPINN object compatible with the rest of jinns.
|
|
83
82
|
This is typically created with `create_SPINN`.
|
|
84
83
|
|
|
85
|
-
**NOTE**: SPINNs with `t` and `x` as inputs are best used with a
|
|
86
|
-
DataGenerator with `self.cartesian_product=False` for memory consideration
|
|
87
|
-
|
|
88
84
|
Parameters
|
|
89
85
|
----------
|
|
90
86
|
d : int
|
|
91
|
-
The number of dimensions to treat separately
|
|
87
|
+
The number of dimensions to treat separately, including time `t` if
|
|
88
|
+
used for non-stationnary equations.
|
|
92
89
|
|
|
93
90
|
"""
|
|
94
91
|
|
|
@@ -105,42 +102,28 @@ class SPINN(eqx.Module):
|
|
|
105
102
|
def __post_init__(self, spinn_mlp):
|
|
106
103
|
self.params, self.static = eqx.partition(spinn_mlp, eqx.is_inexact_array)
|
|
107
104
|
|
|
105
|
+
@property
|
|
108
106
|
def init_params(self) -> PyTree:
|
|
109
107
|
"""
|
|
110
108
|
Returns an initial set of parameters
|
|
111
109
|
"""
|
|
112
110
|
return self.params
|
|
113
111
|
|
|
114
|
-
def __call__(
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
if self.eq_type == "statio_PDE":
|
|
119
|
-
(x, params) = args
|
|
120
|
-
try:
|
|
121
|
-
spinn = eqx.combine(params.nn_params, self.static)
|
|
122
|
-
except (KeyError, AttributeError, TypeError) as e:
|
|
123
|
-
spinn = eqx.combine(params, self.static)
|
|
124
|
-
v_model = jax.vmap(spinn, (0))
|
|
125
|
-
res = v_model(t=None, x=x)
|
|
126
|
-
return self.eval_nn(res)
|
|
127
|
-
if self.eq_type == "nonstatio_PDE":
|
|
128
|
-
(t, x, params) = args
|
|
129
|
-
try:
|
|
130
|
-
spinn = eqx.combine(params.nn_params, self.static)
|
|
131
|
-
except (KeyError, AttributeError, TypeError) as e:
|
|
132
|
-
spinn = eqx.combine(params, self.static)
|
|
133
|
-
v_model = jax.vmap(spinn, ((0, 0)))
|
|
134
|
-
res = v_model(t, x)
|
|
135
|
-
return self.eval_nn(res)
|
|
136
|
-
raise RuntimeError("Wrong parameter value for eq_type")
|
|
137
|
-
|
|
138
|
-
def eval_nn(
|
|
139
|
-
self, res: Float[Array, "d embed_dim*output_dim"]
|
|
112
|
+
def __call__(
|
|
113
|
+
self,
|
|
114
|
+
t_x: Float[Array, "batch_size 1+dim"],
|
|
115
|
+
params: Params | ParamsDict | PyTree,
|
|
140
116
|
) -> Float[Array, "output_dim"]:
|
|
141
117
|
"""
|
|
142
118
|
Evaluate the SPINN on some inputs with some params.
|
|
143
119
|
"""
|
|
120
|
+
try:
|
|
121
|
+
spinn = eqx.combine(params.nn_params, self.static)
|
|
122
|
+
except (KeyError, AttributeError, TypeError) as e:
|
|
123
|
+
spinn = eqx.combine(params, self.static)
|
|
124
|
+
v_model = jax.vmap(spinn)
|
|
125
|
+
res = v_model(t_x)
|
|
126
|
+
|
|
144
127
|
a = ", ".join([f"{chr(97 + d)}z" for d in range(res.shape[1])])
|
|
145
128
|
b = "".join([f"{chr(97 + d)}" for d in range(res.shape[1])])
|
|
146
129
|
res = jnp.stack(
|
|
@@ -170,7 +153,7 @@ def create_SPINN(
|
|
|
170
153
|
eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
|
|
171
154
|
eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
|
|
172
155
|
m: int = 1,
|
|
173
|
-
) -> SPINN:
|
|
156
|
+
) -> tuple[SPINN, PyTree]:
|
|
174
157
|
"""
|
|
175
158
|
Utility function to create a SPINN neural network with the equinox
|
|
176
159
|
library.
|
|
@@ -218,16 +201,14 @@ def create_SPINN(
|
|
|
218
201
|
then sum groups of `r` embedding dimensions to compute each output.
|
|
219
202
|
Default is 1.
|
|
220
203
|
|
|
221
|
-
!!! note
|
|
222
|
-
SPINNs with `t` and `x` as inputs are best used with a
|
|
223
|
-
DataGenerator with `self.cartesian_product=False` for memory
|
|
224
|
-
consideration
|
|
225
204
|
|
|
226
205
|
|
|
227
206
|
Returns
|
|
228
207
|
-------
|
|
229
208
|
spinn
|
|
230
209
|
An instanciated SPINN
|
|
210
|
+
spinn.init_params
|
|
211
|
+
The initial set of parameters of the model
|
|
231
212
|
|
|
232
213
|
Raises
|
|
233
214
|
------
|
|
@@ -265,4 +246,4 @@ def create_SPINN(
|
|
|
265
246
|
spinn_mlp = _SPINN(key=key, d=d, eqx_list=eqx_list)
|
|
266
247
|
spinn = SPINN(spinn_mlp=spinn_mlp, d=d, r=r, eq_type=eq_type, m=m)
|
|
267
248
|
|
|
268
|
-
return spinn
|
|
249
|
+
return spinn, spinn.init_params
|
jinns/utils/_types.py
CHANGED
jinns/utils/_utils.py
CHANGED
|
@@ -2,13 +2,18 @@
|
|
|
2
2
|
Implements various utility functions
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
5
|
+
from math import prod
|
|
6
|
+
import warnings
|
|
8
7
|
import jax
|
|
9
8
|
import jax.numpy as jnp
|
|
10
9
|
from jaxtyping import PyTree, Array
|
|
11
10
|
|
|
11
|
+
from jinns.data._DataGenerators import (
|
|
12
|
+
DataGeneratorODE,
|
|
13
|
+
CubicMeshPDEStatio,
|
|
14
|
+
CubicMeshPDENonStatio,
|
|
15
|
+
)
|
|
16
|
+
|
|
12
17
|
|
|
13
18
|
def _check_nan_in_pytree(pytree: PyTree) -> bool:
|
|
14
19
|
"""
|
|
@@ -33,7 +38,7 @@ def _check_nan_in_pytree(pytree: PyTree) -> bool:
|
|
|
33
38
|
)
|
|
34
39
|
|
|
35
40
|
|
|
36
|
-
def
|
|
41
|
+
def get_grid(in_array: Array) -> Array:
|
|
37
42
|
"""
|
|
38
43
|
From an array of shape (B, D), D > 1, get the grid array, i.e., an array of
|
|
39
44
|
shape (B, B, ...(D times)..., B, D): along the last axis we have the array
|
|
@@ -49,10 +54,14 @@ def _get_grid(in_array: Array) -> Array:
|
|
|
49
54
|
return in_array
|
|
50
55
|
|
|
51
56
|
|
|
52
|
-
def
|
|
57
|
+
def _check_shape_and_type(
|
|
58
|
+
r: Array | int, expected_shape: tuple, cause: str = "", binop: str = ""
|
|
59
|
+
) -> Array | float:
|
|
53
60
|
"""
|
|
54
|
-
|
|
55
|
-
|
|
61
|
+
Ensures float type and correct shapes for broadcasting when performing a
|
|
62
|
+
binary operation (like -, + or *) between two arrays.
|
|
63
|
+
First array is a custom user (observation data or output of initial/BC
|
|
64
|
+
functions), the expected shape is the same as the PINN's.
|
|
56
65
|
"""
|
|
57
66
|
if isinstance(r, (int, float)):
|
|
58
67
|
# if we have a scalar cast it to float
|
|
@@ -60,9 +69,28 @@ def _check_user_func_return(r: Array | int, shape: tuple) -> Array | int:
|
|
|
60
69
|
if r.shape == ():
|
|
61
70
|
# if we have a scalar inside a ndarray
|
|
62
71
|
return r.astype(float)
|
|
63
|
-
if r.shape[-1] ==
|
|
64
|
-
#
|
|
72
|
+
if r.shape[-1] == expected_shape[-1]:
|
|
73
|
+
# broadcasting will be OK
|
|
65
74
|
return r.astype(float)
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
75
|
+
|
|
76
|
+
if r.shape != expected_shape:
|
|
77
|
+
# Usually, the reshape below adds a missing (1,) final axis to ensure # the PINN output and the other function (initial/boundary condition)
|
|
78
|
+
# have the correct shape, depending on how the user has coded the
|
|
79
|
+
# initial/boundary condition.
|
|
80
|
+
warnings.warn(
|
|
81
|
+
f"[{cause}] Performing operation `{binop}` between arrays"
|
|
82
|
+
f" of different shapes: got {r.shape} for the custom array and"
|
|
83
|
+
f" {expected_shape} for the PINN."
|
|
84
|
+
f" This can cause unexpected and wrong broadcasting."
|
|
85
|
+
f" Reshaping {r.shape} into {expected_shape}. Reshape your"
|
|
86
|
+
f" custom array to math the {expected_shape=} to prevent this"
|
|
87
|
+
f" warning."
|
|
88
|
+
)
|
|
89
|
+
return r.reshape(expected_shape)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _subtract_with_check(
    a: Array | int, b: Array | int, cause: str = ""
) -> Array | float:
    """
    Return `a - b` once `a` has been checked against the shape of `b`.

    `a` is the custom (user-provided) operand, e.g. observation data or the
    output of an initial/boundary condition function. `b` is the PINN output,
    whose shape is the reference. `_check_shape_and_type` casts `a` to float
    and may reshape it, warning when it does. `cause` is forwarded so that
    the warning is labelled with its origin.

    NOTE(review): `b.shape` is read directly, so despite the annotation `b`
    is expected to be an array rather than a Python int.
    """
    reference_shape = b.shape
    checked_a = _check_shape_and_type(a, reference_shape, cause=cause, binop="-")
    return checked_a - b
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: jinns
|
|
3
|
+
Version: 1.2.0
|
|
4
|
+
Summary: Physics Informed Neural Network with JAX
|
|
5
|
+
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
|
+
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
7
|
+
License: Apache License 2.0
|
|
8
|
+
Project-URL: Repository, https://gitlab.com/mia_jinns/jinns
|
|
9
|
+
Project-URL: Documentation, https://mia_jinns.gitlab.io/jinns/index.html
|
|
10
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
11
|
+
Classifier: Development Status :: 4 - Beta
|
|
12
|
+
Classifier: Programming Language :: Python
|
|
13
|
+
Requires-Python: >=3.10
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
License-File: LICENSE
|
|
16
|
+
License-File: AUTHORS
|
|
17
|
+
Requires-Dist: numpy
|
|
18
|
+
Requires-Dist: jax
|
|
19
|
+
Requires-Dist: jaxopt
|
|
20
|
+
Requires-Dist: optax
|
|
21
|
+
Requires-Dist: equinox>0.11.3
|
|
22
|
+
Requires-Dist: jax-tqdm
|
|
23
|
+
Requires-Dist: diffrax
|
|
24
|
+
Requires-Dist: matplotlib
|
|
25
|
+
Provides-Extra: notebook
|
|
26
|
+
Requires-Dist: jupyter; extra == "notebook"
|
|
27
|
+
Requires-Dist: seaborn; extra == "notebook"
|
|
28
|
+
|
|
29
|
+
jinns
|
|
30
|
+
=====
|
|
31
|
+
|
|
32
|
+
 
|
|
33
|
+
|
|
34
|
+
Physics Informed Neural Networks with JAX. **jinns** is developed to estimate solutions of ODE and PDE problems using neural networks, with a strong focus on
|
|
35
|
+
|
|
36
|
+
1. inverse problems: find equation parameters given noisy/indirect observations
|
|
37
|
+
2. meta-modeling: solve for a parametric family of differential equations
|
|
38
|
+
|
|
39
|
+
It can also be used for forward problems and hybrid-modeling.
|
|
40
|
+
|
|
41
|
+
**jinns** specific points:
|
|
42
|
+
|
|
43
|
+
- **jinns uses JAX** - It is directed to JAX users: forward and backward autodiff, vmapping, jitting and more! No reinventing the wheel: it relies on the JAX ecosystem whenever possible, such as [equinox](https://github.com/patrick-kidger/equinox/) for neural networks or [optax](https://optax.readthedocs.io/) for optimization.
|
|
44
|
+
|
|
45
|
+
- **jinns is highly modular** - It gives users maximum control for defining their problems, and extending the package. The maths and computations are visible and not hidden behind layers of code!
|
|
46
|
+
|
|
47
|
+
- **jinns is efficient** - It compares favorably to other existing Python packages for PINNs on the [PINNacle benchmarks](https://github.com/i207M/PINNacle/), as demonstrated in the table below. For more details on the benchmarks, check out the [PINN multi-library benchmark](https://gitlab.com/mia_jinns/pinn-multi-library-benchmark)
|
|
48
|
+
|
|
49
|
+
- Implemented PINN architectures
|
|
50
|
+
    - Vanilla Multi-Layer Perceptron, popular across the PINN literature.
|
|
51
|
+
|
|
52
|
+
    - [Separable PINNs](https://openreview.net/pdf?id=dEySGIcDnI): leverage forward-mode autodiff for computational speed.
|
|
53
|
+
|
|
54
|
+
- [Hyper PINNs](https://arxiv.org/pdf/2111.01008.pdf): useful for meta-modeling
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
- **Get started**: check out our various notebooks on the [documentation](https://mia_jinns.gitlab.io/jinns/index.html).
|
|
58
|
+
|
|
59
|
+
| | jinns | DeepXDE - JAX | DeepXDE - Pytorch | PINA | Nvidia Modulus |
|
|
60
|
+
|---|:---:|:---:|:---:|:---:|:---:|
|
|
61
|
+
| Burgers1D | **445** | 723 | 671 | 1977 | 646 |
|
|
62
|
+
| NS2d-C | **265** | 278 | 441 | 1600 | 275 |
|
|
63
|
+
| PInv | 149 | 218 | *CC* | 1509 | **135** |
|
|
64
|
+
| Diffusion-Reaction-Inv | **284** | *NI* | 3424 | 4061 | 2541 |
|
|
65
|
+
| Navier-Stokes-Inv | **175** | *NI* | 1511 | 1403 | 498 |
|
|
66
|
+
|
|
67
|
+
*Training time in seconds on an Nvidia T600 GPU. NI means problem cannot be implemented in the backend, CC means the code crashed.*
|
|
68
|
+
|
|
69
|
+

|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# Installation
|
|
73
|
+
|
|
74
|
+
Install the latest version with pip
|
|
75
|
+
|
|
76
|
+
```bash
|
|
77
|
+
pip install jinns
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
# Documentation
|
|
81
|
+
|
|
82
|
+
The project's documentation is hosted on GitLab Pages and available at [https://mia_jinns.gitlab.io/jinns/index.html](https://mia_jinns.gitlab.io/jinns/index.html).
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Found a bug / want a feature?
|
|
86
|
+
|
|
87
|
+
Open an issue on the [Gitlab repo](https://gitlab.com/mia_jinns/jinns/-/issues).
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# Contributing
|
|
91
|
+
|
|
92
|
+
Here are the contributors guidelines:
|
|
93
|
+
|
|
94
|
+
1. First fork the library on Gitlab.
|
|
95
|
+
|
|
96
|
+
2. Then clone and install the library in development mode with
|
|
97
|
+
|
|
98
|
+
```bash
|
|
99
|
+
pip install -e .
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
3. Install pre-commit and run it.
|
|
103
|
+
|
|
104
|
+
```bash
|
|
105
|
+
pip install pre-commit
|
|
106
|
+
pre-commit install
|
|
107
|
+
```
|
|
108
|
+
|
|
109
|
+
4. Open a merge request once you are done with your changes, the review will be done via Gitlab.
|
|
110
|
+
|
|
111
|
+
# Contributors
|
|
112
|
+
|
|
113
|
+
Don't hesitate to contribute and get your name on the list here!
|
|
114
|
+
|
|
115
|
+
**List of contributors:** Hugo Gangloff, Nicolas Jouvin
|
|
116
|
+
|
|
117
|
+
# Cite us
|
|
118
|
+
|
|
119
|
+
Please consider citing our work if you found it useful for yours, using the following lines:
|
|
120
|
+
```
|
|
121
|
+
@software{jinns2024,
|
|
122
|
+
title={\texttt{jinns}: Physics-Informed Neural Networks with JAX},
|
|
123
|
+
author={Gangloff, Hugo and Jouvin, Nicolas},
|
|
124
|
+
url={https://gitlab.com/mia_jinns},
|
|
125
|
+
year={2024}
|
|
126
|
+
}
|
|
127
|
+
```
|