jinns 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/utils/_pinn.py ADDED
@@ -0,0 +1,308 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import equinox as eqx
4
+
5
+
6
class _MLP(eqx.Module):
    """
    Equinox module assembling a plain feed-forward network from a random key
    and an `eqx_list` specification. Meant to be used together with the
    function `create_PINN`.
    """

    layers: list  # callables applied one after the other in __call__

    def __init__(self, key, eqx_list):
        """
        Parameters
        ----------
        key
            A jax random key. It is split once for every entry of `eqx_list`
            that carries constructor arguments.
        eqx_list
            A list of lists describing the successive equinox modules and
            activation functions of the PINN architecture. Each inner list
            holds the eqx module or activation function as its first item;
            the remaining items are the arguments needed to build it (e.g.
            the size of the layer).
            __Note:__ the `key` argument must not be given, it is supplied
            automatically.
            A typical example is `eqx_list=
                [[eqx.nn.Linear, 2, 20],
                [jax.nn.tanh],
                [eqx.nn.Linear, 20, 20],
                [jax.nn.tanh],
                [eqx.nn.Linear, 20, 20],
                [jax.nn.tanh],
                [eqx.nn.Linear, 20, 1]
            ]`
        """
        # TODO we are limited currently in the number of layer type we can
        # parse and we lack some safety checks
        built_layers = []
        for spec in eqx_list:
            if len(spec) != 1:
                # Entries with arguments are constructed with a fresh subkey,
                # passed through the keyword-only `key` argument.
                key, layer_key = jax.random.split(key, 2)
                built_layers.append(spec[0](*spec[1:], key=layer_key))
            else:
                # A lone callable (e.g. an activation) is used as-is.
                built_layers.append(spec[0])
        self.layers = built_layers

    def __call__(self, t):
        """Feed `t` through every layer in order and return the result."""
        out = t
        for layer in self.layers:
            out = layer(out)
        return out
54
+
55
+
56
class PINN:
    """
    Thin wrapper exposing a `__call__` so that the network (our former
    `self.u`) can be given a type. The evaluation function, `apply_fn`, is
    attached to the instance by `create_PINN`.
    """

    def __init__(self, key, eqx_list, output_slice=None):
        # Split the freshly built module into its trainable float arrays
        # (`params`) and the remaining static structure (`static`).
        self.params, self.static = eqx.partition(
            _MLP(key, eqx_list), eqx.is_inexact_array
        )
        self.output_slice = output_slice

    def init_params(self):
        """Return the parameters created at construction time."""
        return self.params

    def __call__(self, *args, **kwargs):
        # `apply_fn` is stored on the instance as a plain function, so `self`
        # has to be forwarded explicitly.
        return self.apply_fn(self, *args, **kwargs)

    def _eval_nn(self, inputs, u_params, eq_params, input_transform, output_transform):
        """
        Shared evaluation routine called by every variant of `apply_fn`.

        Rebuilds the network from `u_params` and the static part, runs it on
        `input_transform(inputs)`, squeezes the raw output and hands it to
        `output_transform` together with the untransformed `inputs`. The
        result is then restricted to `self.output_slice` when one is set, and
        a 0-d result is promoted to shape (1,). `eq_params` is not used here.
        """
        network = eqx.combine(u_params, self.static)
        raw_output = network(input_transform(inputs)).squeeze()
        out = output_transform(inputs, raw_output)

        if self.output_slice is not None:
            out = out[self.output_slice]

        # force (1,) output for non vectorial solution (consistency)
        return out if out.shape else jnp.expand_dims(out, axis=-1)
90
+
91
+
92
def create_PINN(
    key,
    eqx_list,
    eq_type,
    dim_x=0,
    with_eq_params=None,
    input_transform=None,
    output_transform=None,
    shared_pinn_outputs=None,
):
    """
    Utility function to create a standard PINN neural network with the equinox
    library.

    Parameters
    ----------
    key
        A jax random key that will be used to initialize the network parameters
    eqx_list
        A list of lists of successive equinox modules and activation functions
        describing the PINN architecture. Each inner list has the eqx module or
        activation function as first item; the other items are the arguments
        it requires (e.g. the size of the layer).
        __Note:__ the `key` argument must not be given, it is supplied
        automatically.
        A typical example is `eqx_list=
            [[eqx.nn.Linear, 2, 20],
            [jax.nn.tanh],
            [eqx.nn.Linear, 20, 20],
            [jax.nn.tanh],
            [eqx.nn.Linear, 20, 20],
            [jax.nn.tanh],
            [eqx.nn.Linear, 20, 1]
        ]`
    eq_type
        A string with three possibilities.
        "ODE": the PINN is called with one input `t`.
        "statio_PDE": the PINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note: unless an `input_transform` changes it, the input dimension
        given in eqx_list has to match the dimension of `t` + the dimension
        of `x` + the total size of the `eq_params` entries selected by
        `with_eq_params` (see below). This is not checked.**
    dim_x
        An integer. The dimension of `x`. Default `0`. Must be `0` for "ODE"
        and non-zero otherwise.
    with_eq_params
        Default is None. Otherwise a list of keys of the dict `eq_params`
        whose values are also fed to the network as inputs. The selected
        values are raveled and concatenated after the other inputs, in the
        iteration order of `eq_params` (not the order of `with_eq_params`).
    input_transform
        A function applied to the network inputs before entering the PINN.
        Its output must match the PINN inputs. Default is the identity.
    output_transform
        A function called after exiting the PINN with two arguments: the PINN
        inputs (before `input_transform`) and the squeezed PINN output.
        Default returns the PINN output unchanged.
    shared_pinn_outputs
        A tuple of jnp.s_[] (slices) selecting a different part of the output
        for each network. In this case a list of PINNs is returned, one for
        each slice in shared_pinn_outputs. This is useful to create PINNs that
        share the same network and same parameters. Default is None: a single
        PINN is returned.

    Returns
    -------
    pinn
        A `PINN` object (or a list of them when `shared_pinn_outputs` is
        given) whose `apply_fn` is set according to `eq_type`. A typical call
        is `u(t, u_params)` for "ODE", `u(x, u_params)` for "statio_PDE" or
        `u(t, x, u_params)` for "nonstatio_PDE", followed by an `eq_params`
        argument that is required when `with_eq_params` is given. The initial
        parameters are returned by `u.init_params()`.
        For "ODE" without `with_eq_params`, the output is squeezed.

    Raises
    ------
    RuntimeError
        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
        "nonstatio_PDE"]`
    RuntimeError
        If we have a `dim_x != 0` and `eq_type == "ODE"`
        or if we have a `dim_x == 0` and `eq_type != "ODE"`
    """
    if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
        raise RuntimeError("Wrong parameter value for eq_type")

    if eq_type == "ODE" and dim_x != 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    if eq_type != "ODE" and dim_x == 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    # NOTE: the input size declared in eqx_list is deliberately not compared
    # with dim_t + dim_x + number of eq_params, because an `input_transform`
    # may change the input dimension. No declared layer size is looked up at
    # all: such lookups were unused and raised IndexError on valid lists.

    if input_transform is None:

        def input_transform(_in):
            return _in

    if output_transform is None:

        def output_transform(_in_pinn, _out_pinn):
            return _out_pinn

    def _flatten_eq_params(eq_params):
        # Ravel and concatenate the `eq_params` entries whose key is listed in
        # `with_eq_params`, in the iteration order of `eq_params`.
        return jnp.concatenate(
            [e.ravel() for k, e in eq_params.items() if k in with_eq_params]
        )

    if eq_type == "ODE":
        if with_eq_params is None:

            def apply_fn(self, t, u_params, eq_params=None):
                # Add the trailing dimension that is lacking for ODE batches
                t = t[None]
                return self._eval_nn(
                    t, u_params, eq_params, input_transform, output_transform
                ).squeeze()

        else:

            def apply_fn(self, t, u_params, eq_params):
                # Add the trailing dimension that is lacking for ODE batches
                t = t[None]
                t_eq_params = jnp.concatenate(
                    [t, _flatten_eq_params(eq_params)], axis=-1
                )
                return self._eval_nn(
                    t_eq_params, u_params, eq_params, input_transform, output_transform
                )

    elif eq_type == "statio_PDE":
        # Here the input is `x`, which can be high dimensional
        if with_eq_params is None:

            def apply_fn(self, x, u_params, eq_params=None):
                return self._eval_nn(
                    x, u_params, eq_params, input_transform, output_transform
                )

        else:

            def apply_fn(self, x, u_params, eq_params):
                x_eq_params = jnp.concatenate(
                    [x, _flatten_eq_params(eq_params)], axis=-1
                )
                return self._eval_nn(
                    x_eq_params, u_params, eq_params, input_transform, output_transform
                )

    else:  # "nonstatio_PDE", guaranteed by the validation above
        # Here `t` is concatenated with `x`, which can be high dimensional
        if with_eq_params is None:

            def apply_fn(self, t, x, u_params, eq_params=None):
                t_x = jnp.concatenate([t, x], axis=-1)
                return self._eval_nn(
                    t_x, u_params, eq_params, input_transform, output_transform
                )

        else:

            def apply_fn(self, t, x, u_params, eq_params):
                t_x = jnp.concatenate([t, x], axis=-1)
                t_x_eq_params = jnp.concatenate(
                    [t_x, _flatten_eq_params(eq_params)], axis=-1
                )
                return self._eval_nn(
                    t_x_eq_params,
                    u_params,
                    eq_params,
                    input_transform,
                    output_transform,
                )

    if shared_pinn_outputs is None:
        pinn = PINN(key, eqx_list)
        pinn.apply_fn = apply_fn
        return pinn

    pinns = []
    static = None
    for output_slice in shared_pinn_outputs:
        # Same key and eqx_list, so every PINN starts from identical parameters
        pinn = PINN(key, eqx_list, output_slice)
        pinn.apply_fn = apply_fn
        # all the pinns are in fact the same so we share the same static
        if static is None:
            static = pinn.static
        else:
            pinn.static = static
        pinns.append(pinn)
    return pinns
jinns/utils/_spinn.py ADDED
@@ -0,0 +1,237 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import equinox as eqx
4
+
5
+
6
class _SPINN(eqx.Module):
    """
    Construct a Separable PINN as proposed in
    Cho et al., _Separable Physics-Informed Neural Networks_, NeurIPS, 2023
    """

    layers: list  # alias of the last per-dimension MLP (see NOTE in __init__)
    separated_mlp: list  # one list of layers per separated dimension
    d: int
    r: int
    m: int

    def __init__(self, key, d, r, eqx_list, m=1):
        """
        Parameters
        ----------
        key
            A jax random key, split once for every layer built with arguments
        d
            An integer. The number of dimensions to treat separately, i.e. the
            number of independent MLPs that are built
        r
            An integer. The dimension of the embedding
        eqx_list
            A list of lists of successive equinox modules and activation
            functions describing *each separable PINN architecture*.
            Each inner list has the eqx module or activation function as
            first item; the other items are the arguments it requires (e.g.
            the size of the layer).
            __Note:__ the `key` argument must not be given, it is supplied
            automatically.
            Each MLP receives a single coordinate and outputs an embedding of
            size `r * m` (sizes enforced by `create_SPINN`), so a typical
            example is `eqx_list=
                [[eqx.nn.Linear, 1, 20],
                [jax.nn.tanh],
                [eqx.nn.Linear, 20, 20],
                [jax.nn.tanh],
                [eqx.nn.Linear, 20, 20],
                [jax.nn.tanh],
                [eqx.nn.Linear, 20, r * m]
            ]`
        m
            An integer. The number of outputs of the network. Default is 1.
        """
        self.d = d
        self.r = r
        self.m = m

        self.separated_mlp = []
        for _ in range(self.d):
            mlp_layers = []
            for spec in eqx_list:
                if len(spec) == 1:
                    mlp_layers.append(spec[0])
                else:
                    # A fresh subkey per parametrized layer, so every MLP gets
                    # its own initialization
                    key, subkey = jax.random.split(key, 2)
                    mlp_layers.append(spec[0](*spec[1:], key=subkey))
            self.separated_mlp.append(mlp_layers)
            # NOTE(review): `layers` is a declared field and must be set. As in
            # earlier versions it aliases the last MLP, so those parameters
            # appear twice in the partitioned pytree although __call__ only
            # uses `separated_mlp`. Kept as-is for pytree compatibility.
            self.layers = mlp_layers

    def __call__(self, t, x):
        """
        Evaluate every separated MLP on its own coordinate.

        The coordinates are `t` first (when not None) followed by the
        flattened `x`; coordinate i is fed, with shape (1,), to the i-th MLP.
        Returns the stacked MLP outputs, one row per separated dimension.
        """
        if t is not None:
            dimensions = jnp.concatenate([t, x.flatten()], axis=0)
        else:
            dimensions = jnp.ravel(x)
        outputs = []
        for i, mlp_layers in enumerate(self.separated_mlp):
            h = dimensions[i][None]
            for layer in mlp_layers:
                h = layer(h)
            outputs.append(h)
        return jnp.asarray(outputs)
74
+
75
+
76
class SPINN:
    """
    Thin wrapper exposing a `__call__` so that the separable network (our
    former `self.u`) can be given a type. The evaluation function, `apply_fn`,
    is attached to the instance by `create_SPINN`.
    """

    def __init__(self, key, d, r, eqx_list, m=1):
        self.d, self.r, self.m = d, r, m
        _spinn = _SPINN(key, d, r, eqx_list, m)
        # Trainable float arrays go to `params`, everything else to `static`.
        self.params, self.static = eqx.partition(_spinn, eqx.is_inexact_array)

    def init_params(self):
        """Return the parameters created at construction time."""
        return self.params

    def __call__(self, *args, **kwargs):
        # `apply_fn` is stored on the instance as a plain function, so `self`
        # has to be forwarded explicitly.
        return self.apply_fn(self, *args, **kwargs)

    def _eval_nn(self, res):
        """
        Combine the per-dimension embeddings into the solution on the grid.

        Parameters
        ----------
        res
            Array of the batched `_SPINN` outputs, indexed as
            `res[batch, dimension, embedding]`, with an embedding axis of at
            least `self.r * self.m` entries.

        Returns
        -------
        An array of shape `(batch,) * n_dims + (self.m,)`, with
        `n_dims = res.shape[1]`. Output `k` sums, over the embedding entries
        `k * r` to `(k + 1) * r - 1`, the outer product over the batch axes
        of the per-dimension embeddings.
        """
        n_dims = res.shape[1]
        # One letter per separated dimension ('a', 'b', ...); 'z' is the
        # summed embedding axis, e.g. "az, bz -> ab" for two dimensions.
        letters = [chr(97 + i) for i in range(n_dims)]
        subscripts = ", ".join(f"{c}z" for c in letters) + " -> " + "".join(letters)

        per_output = []
        for k in range(self.m):
            group = slice(k * self.r, (k + 1) * self.r)
            per_output.append(
                jnp.einsum(subscripts, *(res[:, i, group] for i in range(n_dims)))
            )
        # The trailing axis of size m already gives non vectorial solutions a
        # (..., 1) shape, so no extra expand_dims is needed for consistency.
        return jnp.stack(per_output, axis=-1)
119
+
120
+
121
def _declared_layer_size(layer_specs, position, what):
    """Return item `position` of the first spec in `layer_specs` long enough to have it."""
    for spec in layer_specs:
        if len(spec) > position:
            return spec[position]
    raise ValueError(f"Could not find the declared {what} dimension in eqx_list")


def create_SPINN(key, d, r, eqx_list, eq_type, m=1):
    """
    Utility function to create a SPINN neural network with the equinox
    library.

    *Note* that a SPINN is not vmapped from the outside and expects batches of
    the same size for each input. It directly outputs the solution on the grid
    formed by the batches, of shape (batchsize, ..., batchsize, m) with one
    batchsize axis per separated dimension. See the paper for more details.

    Parameters
    ----------
    key
        A jax random key that will be used to initialize the network parameters
    d
        An integer. The number of dimensions to treat separately, i.e. the
        number of separated networks. The i-th network receives the i-th
        input coordinate (`t` first when present, then the flattened `x`), so
        `d` should match the number of coordinates. At most 25.
    r
        An integer. The dimension of the embedding
    eqx_list
        A list of lists of successive equinox modules and activation functions
        describing *each separable PINN architecture*.
        Each inner list has the eqx module or activation function as first
        item; the other items are the arguments it requires (e.g. the size of
        the layer).
        __Note:__ the `key` argument must not be given, it is supplied
        automatically.
        The first sized layer must take 1 input and the last sized layer must
        output `r * m` values, thus a typical example is `eqx_list=
            [[eqx.nn.Linear, 1, 20],
            [jax.nn.tanh],
            [eqx.nn.Linear, 20, 20],
            [jax.nn.tanh],
            [eqx.nn.Linear, 20, 20],
            [jax.nn.tanh],
            [eqx.nn.Linear, 20, r * m]
        ]`
    eq_type
        A string with two supported possibilities.
        "statio_PDE": the SPINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the SPINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        "ODE" is not supported by SPINN.
    m
        An integer. The output dimension of the neural network. According to
        the SPINN article, a total embedding dimension of `r*m` is defined. We
        then sum groups of `r` embedding dimensions to compute each output.
        Default is 1.

    Returns
    -------
    spinn
        A `SPINN` object whose `apply_fn` is set according to `eq_type`. A
        typical call is `u(x, u_params)` for "statio_PDE" or
        `u(t, x, u_params)` for "nonstatio_PDE"; an `eq_params` argument is
        accepted but unused. The initial parameters are returned by
        `u.init_params()`.

    Raises
    ------
    RuntimeError
        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
        "nonstatio_PDE"]`, or if it is "ODE" (not supported)
    ValueError
        If the declared input size is not 1, if the declared output size is
        not `r * m`, if these sizes cannot be found in `eqx_list`, or if
        `d > 25`
    """

    if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
        raise RuntimeError("Wrong parameter value for eq_type")

    # The declared sizes are the 2nd item of the first spec carrying
    # arguments and the 3rd item of the last such spec. Argument-less entries
    # (e.g. a leading flatten or trailing activations) are skipped.
    nb_inputs_declared = _declared_layer_size(eqx_list, 1, "input")
    if nb_inputs_declared != 1:
        raise ValueError("Input dim must be set to 1 in SPINN!")

    nb_outputs_declared = _declared_layer_size(reversed(eqx_list), 2, "output")
    if nb_outputs_declared != r * m:
        raise ValueError("Output dim must be set to r * m in SPINN!")

    # SPINN._eval_nn names the separated axes chr(97 + i) and reserves 'z'
    # for the embedding axis: the 25 letters 'a' to 'y' are available.
    if d > 25:
        raise ValueError(
            "Too many dimensions, not enough letters" " available in jnp.einsum"
        )

    if eq_type == "statio_PDE":

        def apply_fn(self, x, u_params, eq_params=None):
            spinn = eqx.combine(u_params, self.static)
            v_model = jax.vmap(spinn, (0))
            res = v_model(t=None, x=x)
            return self._eval_nn(res)

    elif eq_type == "nonstatio_PDE":

        def apply_fn(self, t, x, u_params, eq_params=None):
            spinn = eqx.combine(u_params, self.static)
            v_model = jax.vmap(spinn, ((0, 0)))
            res = v_model(t, x)
            return self._eval_nn(res)

    else:
        # Only reachable for "ODE", which passed the first check.
        raise RuntimeError("SPINN is not available for eq_type='ODE'")

    spinn = SPINN(key, d, r, eqx_list, m)
    spinn.apply_fn = apply_fn

    return spinn