jinns 1.1.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/utils/_hyperpinn.py DELETED
@@ -1,410 +0,0 @@
1
- """
2
- Implements utility function to create HYPERPINNs
3
- https://arxiv.org/pdf/2111.01008.pdf
4
- """
5
-
6
- import warnings
7
- from dataclasses import InitVar
8
- from typing import Callable, Literal
9
- import copy
10
- from math import prod
11
- import jax
12
- import jax.numpy as jnp
13
- from jax.tree_util import tree_leaves, tree_map
14
- from jaxtyping import Array, Float, PyTree, Int, Key
15
- import equinox as eqx
16
- import numpy as onp
17
-
18
- from jinns.utils._pinn import PINN, _MLP
19
- from jinns.parameters._params import Params
20
-
21
-
22
- def _get_param_nb(
23
- params: Params,
24
- ) -> tuple[Int[onp.ndarray, "1"], Int[onp.ndarray, "n_layers"]]:
25
- """Returns the number of parameters in a Params object and also
26
- the cumulative sum when parsing the object.
27
-
28
-
29
- Parameters
30
- ----------
31
- params :
32
- A Params object.
33
- """
34
- dim_prod_all_arrays = [
35
- prod(a.shape)
36
- for a in tree_leaves(params, is_leaf=lambda x: isinstance(x, jnp.ndarray))
37
- ]
38
- return onp.asarray(sum(dim_prod_all_arrays)), onp.cumsum(dim_prod_all_arrays)
39
-
40
-
41
class HYPERPINN(PINN):
    """
    A HYPERPINN object compatible with the rest of jinns.
    Composed of a PINN and a HYPER network. The HYPERPINN is typically
    instantiated using `create_HYPERPINN`. However, a user could directly
    create their HYPERPINN using this
    class by passing an eqx.Module for argument `mlp` (resp. for argument
    `hyper_mlp`) that plays the role of the NN (resp. hyper NN) and that is
    already instantiated.

    Parameters
    ----------
    hyperparams: list = eqx.field(static=True)
        A list of keys from Params.eq_params that will be considered as
        hyperparameters for metamodeling.
    hypernet_input_size: int
        An integer. The input size of the MLP used for the hypernetwork. Must
        be equal to the flattened concatenations of the arrays of parameters
        designated by the `hyperparams` argument.
    slice_solution : slice
        A jnp.s\_ object which indicates which axis of the PINN output is
        dedicated to the actual equation solution. Default None
        means that slice_solution = the whole PINN output. This argument is useful
        when the PINN is also used to output equation parameters for example.
        Note that it must be a slice and not an integer (a preprocessing of the
        user provided argument takes care of it).
    eq_type : str
        A string with three possibilities.
        "ODE": the HYPERPINN is called with one input `t`.
        "statio_PDE": the HYPERPINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the HYPERPINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
    input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
        A function that will be called before entering the PINN. Its output(s)
        must match the PINN inputs (except for the parameters).
        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
        and the parameters. Default is no operation.
    output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
        A function with arguments being the same input as the PINN, the PINN
        output and the parameters. This function will be called after exiting the PINN.
        Default is no operation.
    output_slice : slice, default=None
        A jnp.s\_[] to determine the different dimension for the HYPERPINN.
        See `shared_pinn_outputs` argument of `create_HYPERPINN`.
    mlp : eqx.Module
        The actual neural network instantiated as an eqx.Module.
    hyper_mlp : eqx.Module
        The actual hyper neural network instantiated as an eqx.Module.
    """

    # Keys of Params.eq_params treated as hyperparameters (hypernet inputs).
    hyperparams: list[str] = eqx.field(static=True, kw_only=True)
    # Flattened size of the hypernetwork input.
    hypernet_input_size: int = eqx.field(kw_only=True)

    # InitVar fields: consumed by __post_init__ and not stored as-is.
    hyper_mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)
    mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)

    # Trainable / static partitions of the hypernetwork (set in __post_init__).
    params_hyper: PyTree = eqx.field(init=False)
    static_hyper: PyTree = eqx.field(init=False, static=True)
    # Total parameter count of the inner PINN and per-leaf cumulative sums;
    # static on purpose: they are only used as split indices, never trained.
    pinn_params_sum: Int[onp.ndarray, "1"] = eqx.field(init=False, static=True)
    pinn_params_cumsum: Int[onp.ndarray, "n_layers"] = eqx.field(
        init=False, static=True
    )

    def __post_init__(self, mlp, hyper_mlp):
        # Parent PINN post-init handles `mlp` (self.params / self.static are
        # available afterwards and used below and in eval_nn).
        super().__post_init__(
            mlp,
        )
        # Split the hypernetwork into trainable arrays and static structure.
        self.params_hyper, self.static_hyper = eqx.partition(
            hyper_mlp, eqx.is_inexact_array
        )
        # Sizes needed to map the flat hypernet output back onto PINN leaves.
        self.pinn_params_sum, self.pinn_params_cumsum = _get_param_nb(self.params)

    def init_params(self) -> Params:
        """
        Returns an initial set of parameters (those of the hypernetwork:
        the inner PINN's parameters are generated, not trained directly).
        """
        return self.params_hyper

    def _hyper_to_pinn(self, hyper_output: Float[Array, "output_dim"]) -> PyTree:
        """
        From the output of the hypernetwork we set the well formed
        parameters of the pinn (`self.params`)
        """
        # Split the flat hypernet output at the cumulative leaf boundaries
        # and drop the chunks into the PINN parameter pytree structure.
        pinn_params_flat = eqx.tree_at(
            lambda p: tree_leaves(p, is_leaf=eqx.is_array),
            self.params,
            jnp.split(hyper_output, self.pinn_params_cumsum[:-1]),
        )

        # Reshape each flat chunk to the shape of the matching original leaf.
        return tree_map(
            lambda a, b: a.reshape(b.shape),
            pinn_params_flat,
            self.params,
            is_leaf=lambda x: isinstance(x, jnp.ndarray),
        )

    def eval_nn(
        self,
        inputs: Float[Array, "input_dim"],
        params: Params | PyTree,
    ) -> Float[Array, "output_dim"]:
        """
        Evaluate the HYPERPINN on some inputs with some params.
        """
        # Accept either a full Params object (use params.nn_params) or a raw
        # parameter pytree — fall back on the latter for flexibility.
        try:
            hyper = eqx.combine(params.nn_params, self.static_hyper)
        except (KeyError, AttributeError, TypeError) as e:  # give more flexibility
            hyper = eqx.combine(params, self.static_hyper)

        # Concatenate the designated equation parameters into the hypernet input.
        eq_params_batch = jnp.concatenate(
            [params.eq_params[k].flatten() for k in self.hyperparams], axis=0
        )

        # The hypernetwork produces the (flat) weights of the inner PINN.
        hyper_output = hyper(eq_params_batch)

        pinn_params = self._hyper_to_pinn(hyper_output)

        # Rebuild a callable PINN from generated params + static structure.
        pinn = eqx.combine(pinn_params, self.static)
        res = self.output_transform(
            inputs, pinn(self.input_transform(inputs, params)).squeeze(), params
        )

        if self.output_slice is not None:
            res = res[self.output_slice]

        # Force (1,) output for non vectorial solution (consistency).
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res
174
-
175
-
176
def create_HYPERPINN(
    key: Key,
    eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
    hyperparams: list[str],
    hypernet_input_size: int,
    dim_x: int = 0,
    input_transform: Callable[
        [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
    ] = None,
    output_transform: Callable[
        [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
        Float[Array, "output_dim"],
    ] = None,
    slice_solution: slice = None,
    shared_pinn_outputs: slice = None,
    eqx_list_hyper: tuple[tuple[Callable, int, int] | Callable, ...] = None,
) -> HYPERPINN | list[HYPERPINN]:
    r"""
    Utility function to create a HYPERPINN neural network with the equinox
    library.

    Parameters
    ----------
    key
        A JAX random key that will be used to initialize the network
        parameters.
    eqx_list
        A tuple of tuples of successive equinox modules and activation functions to
        describe the PINN architecture. The inner tuples must have the eqx module or
        activation function as first item, other items represent arguments
        that could be required (eg. the size of the layer).
        The `key` argument need not be given.
        A typical example is `eqx_list=
        ((eqx.nn.Linear, 2, 20),
        jax.nn.tanh,
        (eqx.nn.Linear, 20, 20),
        jax.nn.tanh,
        (eqx.nn.Linear, 20, 20),
        jax.nn.tanh,
        (eqx.nn.Linear, 20, 1)
        )`.
    eq_type
        A string with three possibilities.
        "ODE": the HYPERPINN is called with one input `t`.
        "statio_PDE": the HYPERPINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the HYPERPINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
    hyperparams
        A list of keys from Params.eq_params that will be considered as
        hyperparameters for metamodeling.
    hypernet_input_size
        An integer. The input size of the MLP used for the hypernetwork. Must
        be equal to the flattened concatenations of the arrays of parameters
        designated by the `hyperparams` argument.
    dim_x
        An integer. The dimension of `x`. Default `0`.
    input_transform
        A function that will be called before entering the PINN. Its output(s)
        must match the PINN inputs (except for the parameters).
        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
        and the parameters. Default is no operation.
    output_transform
        A function with arguments being the same input as the PINN, the PINN
        output and the parameters. This function will be called after exiting
        the PINN. Default is no operation.
    slice_solution
        A jnp.s\_ object which indicates which axis of the PINN output is
        dedicated to the actual equation solution. Default None
        means that slice_solution = the whole PINN output. This argument is useful
        when the PINN is also used to output equation parameters for example.
        Note that it must be a slice and not an integer (a preprocessing of the
        user provided argument takes care of it).
    shared_pinn_outputs
        Default is None, for a standard PINN.
        A tuple of jnp.s\_[] (slices) to determine the different output for each
        network. In this case we return a list of PINNs, one for each output in
        shared_pinn_outputs. This is useful to create PINNs that share the
        same network and same parameters; **the user must then use the same
        parameter set in their manipulation**.
        See the notebook 2D Navier Stokes in pipeflow with metamodel for an
        example using this option.
    eqx_list_hyper
        Same as eqx_list but for the hypernetwork. Default is None, i.e., we
        use the same architecture as the PINN, up to the number of inputs and
        outputs. Note that the number of inputs of the hypernetwork must
        be equal to the flattened concatenations of the arrays of parameters
        designated by the `hyperparams` argument;
        and the number of outputs must be equal to the number
        of parameters in the pinn network.

    Returns
    -------
    hyperpinn
        A HYPERPINN instance or, when `shared_pinn_outputs` is not None,
        a list of HYPERPINN instances with the same structure is returned,
        only differing by their final slicing of the network output.


    Raises
    ------
    RuntimeError
        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
        "nonstatio_PDE"]`
    RuntimeError
        If we have a `dim_x > 0` and `eq_type == "ODE"`
        or if we have a `dim_x = 0` and `eq_type != "ODE"`
    """
    if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
        raise RuntimeError("Wrong parameter value for eq_type")

    if eq_type == "ODE" and dim_x != 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    if eq_type != "ODE" and dim_x == 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    if eqx_list_hyper is None:
        # Default: mirror the PINN architecture (its in/out sizes are
        # overridden below).
        eqx_list_hyper = copy.deepcopy(eqx_list)

    # Number of declared PINN outputs: 3rd item of the last layer spec, or of
    # the one before it when the tuple ends with an activation entry.
    try:
        nb_outputs_declared = eqx_list[-1][2]
    except (IndexError, TypeError):
        # IndexError: last entry is a 1-tuple activation.
        # TypeError: last entry is a bare callable (also allowed by the
        # declared eqx_list type) — previously unhandled.
        nb_outputs_declared = eqx_list[-2][2]

    if slice_solution is None:
        slice_solution = jnp.s_[0:nb_outputs_declared]
    if isinstance(slice_solution, int):
        # rewrite it as a slice to ensure that axis does not disappear when
        # indexing
        slice_solution = jnp.s_[slice_solution : slice_solution + 1]

    if input_transform is None:

        def input_transform(_in, _params):
            return _in

    if output_transform is None:

        def output_transform(_in_pinn, _out_pinn, _params):
            return _out_pinn

    key, subkey = jax.random.split(key, 2)
    mlp = _MLP(key=subkey, eqx_list=eqx_list)
    # Quick partitioning to count the PINN parameters: that count is the
    # number of neurons of the hypernetwork's last layer.
    params_mlp, _ = eqx.partition(mlp, eqx.is_inexact_array)
    pinn_params_sum, _ = _get_param_nb(params_mlp)

    # --- Override the output size of the hypernetwork's last linear layer.
    last_entry = eqx_list_hyper[-1]
    if isinstance(last_entry, tuple) and len(last_entry) > 1:
        # Last entry is a layer spec: replace its output size.
        eqx_list_hyper = eqx_list_hyper[:-1] + (
            (last_entry[:2] + (pinn_params_sum,)),
        )
    else:
        # Last entry is an activation: replace the layer just before it.
        # Bug fix: slice with [-1:] (not [-1]) so the activation entry keeps
        # the exact wrapping the user passed (the old code concatenated the
        # entry itself, silently unwrapping 1-tuples).
        eqx_list_hyper = (
            eqx_list_hyper[:-2]
            + ((eqx_list_hyper[-2][:2] + (pinn_params_sum,)),)
            + eqx_list_hyper[-1:]
        )

    # --- Override the input size of the hypernetwork's first linear layer.
    first_entry = eqx_list_hyper[0]
    if isinstance(first_entry, tuple) and len(first_entry) > 1:
        # First entry is a layer spec: replace its input size.
        eqx_list_hyper = (
            ((first_entry[0],) + (hypernet_input_size,) + (first_entry[2],)),
        ) + eqx_list_hyper[1:]
    else:
        # First entry is an activation: replace the layer just after it.
        # Same wrapping-preserving fix as above ([:1] instead of [0] + ...).
        eqx_list_hyper = (
            eqx_list_hyper[:1]
            + (
                (
                    (eqx_list_hyper[1][0],)
                    + (hypernet_input_size,)
                    + (eqx_list_hyper[1][2],)
                ),
            )
            + eqx_list_hyper[2:]
        )

    key, subkey = jax.random.split(key, 2)

    with warnings.catch_warnings():
        # TODO check why this warning is raised here and not in the PINN
        # context ?
        warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
        hyper_mlp = _MLP(key=subkey, eqx_list=eqx_list_hyper)

    def _make_hyperpinn(output_slice):
        # Build one HYPERPINN sharing `mlp` / `hyper_mlp`. The equinox
        # warning about a JAX array being set as static is silenced on
        # purpose: the parameter counts stored as static fields are only
        # used as split indices and are never modified.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message="A JAX array is being set as static!"
            )
            return HYPERPINN(
                mlp=mlp,
                hyper_mlp=hyper_mlp,
                slice_solution=slice_solution,
                eq_type=eq_type,
                input_transform=input_transform,
                output_transform=output_transform,
                hyperparams=hyperparams,
                hypernet_input_size=hypernet_input_size,
                output_slice=output_slice,
            )

    if shared_pinn_outputs is not None:
        # One HYPERPINN per requested output slice, all sharing the same
        # underlying networks and parameters.
        return [_make_hyperpinn(output_slice) for output_slice in shared_pinn_outputs]
    return _make_hyperpinn(None)