jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +116 -189
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +176 -513
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +22 -21
  19. jinns/loss/_loss_utils.py +98 -173
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -76
  22. jinns/nn/__init__.py +22 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +434 -0
  25. jinns/nn/_mlp.py +217 -0
  26. jinns/nn/_pinn.py +204 -0
  27. jinns/nn/_ppinn.py +239 -0
  28. jinns/{utils → nn}/_save_load.py +39 -53
  29. jinns/nn/_spinn.py +123 -0
  30. jinns/nn/_spinn_mlp.py +202 -0
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +38 -37
  37. jinns/solver/_rar.py +82 -65
  38. jinns/solver/_solve.py +111 -71
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -5
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns/utils/_hyperpinn.py +0 -420
  51. jinns/utils/_pinn.py +0 -324
  52. jinns/utils/_ppinn.py +0 -227
  53. jinns/utils/_spinn.py +0 -249
  54. jinns-1.2.0.dist-info/RECORD +0 -41
  55. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  56. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  57. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/utils/_hyperpinn.py DELETED
@@ -1,420 +0,0 @@
1
- """
2
- Implements utility function to create HYPERPINNs
3
- https://arxiv.org/pdf/2111.01008.pdf
4
- """
5
-
6
- import warnings
7
- from dataclasses import InitVar
8
- from typing import Callable, Literal
9
- import copy
10
- from math import prod
11
- import jax
12
- import jax.numpy as jnp
13
- from jax.tree_util import tree_leaves, tree_map
14
- from jaxtyping import Array, Float, PyTree, Int, Key
15
- import equinox as eqx
16
- import numpy as onp
17
-
18
- from jinns.utils._pinn import PINN, _MLP
19
- from jinns.parameters._params import Params, ParamsDict
20
-
21
-
22
- def _get_param_nb(
23
- params: Params,
24
- ) -> tuple[Int[onp.ndarray, "1"], Int[onp.ndarray, "n_layers"]]:
25
- """Returns the number of parameters in a Params object and also
26
- the cumulative sum when parsing the object.
27
-
28
-
29
- Parameters
30
- ----------
31
- params :
32
- A Params object.
33
- """
34
- dim_prod_all_arrays = [
35
- prod(a.shape)
36
- for a in tree_leaves(params, is_leaf=lambda x: isinstance(x, jnp.ndarray))
37
- ]
38
- return onp.asarray(sum(dim_prod_all_arrays)), onp.cumsum(dim_prod_all_arrays)
39
-
40
-
41
class HYPERPINN(PINN):
    r"""
    A HYPERPINN object compatible with the rest of jinns.
    Composed of a PINN and a HYPER network. The HYPERPINN is typically
    instantiated with `create_HYPERPINN`. However, a user could directly
    create their HYPERPINN using this class by passing an eqx.Module for
    argument `mlp` (resp. for argument `hyper_mlp`) that plays the role of
    the NN (resp. hyper NN) and that is already instantiated.

    Parameters
    ----------
    hyperparams: list = eqx.field(static=True)
        A list of keys from Params.eq_params that will be considered as
        hyperparameters for metamodeling.
    hypernet_input_size: int
        An integer. The input size of the MLP used for the hypernetwork. Must
        be equal to the flattened concatenations for the array of parameters
        designated by the `hyperparams` argument.
    slice_solution : slice
        A jnp.s\_ object which indicates which axis of the PINN output is
        dedicated to the actual equation solution. Default None
        means that slice_solution = the whole PINN output. This argument is useful
        when the PINN is also used to output equation parameters for example
        Note that it must be a slice and not an integer (a preprocessing of the
        user provided argument takes care of it).
    eq_type : str
        A string with three possibilities.
        "ODE": the HYPERPINN is called with one input `t`.
        "statio_PDE": the HYPERPINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the HYPERPINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function
    input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
        A function that will be called before entering the PINN. Its output(s)
        must match the PINN inputs (except for the parameters).
        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
        and the parameters. Default is no operation.
    output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
        A function whose arguments are the same input as the PINN, the PINN
        output and the parameters. This function will be called after exiting
        the PINN. Default is no operation.
    output_slice : slice, default=None
        A jnp.s\_[] to determine the different dimension for the HYPERPINN.
        See `shared_pinn_outputs` argument of `create_HYPERPINN`.
    mlp : eqx.Module
        The actual neural network instantiated as an eqx.Module.
    hyper_mlp : eqx.Module
        The actual hyper neural network instantiated as an eqx.Module.
    """

    hyperparams: list[str] = eqx.field(static=True, kw_only=True)
    hypernet_input_size: int = eqx.field(kw_only=True)

    # Init-only modules, consumed by __post_init__ and not stored as fields.
    hyper_mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)
    mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)

    # Trainable (inexact array) leaves of the hypernetwork and its static rest.
    params_hyper: PyTree = eqx.field(init=False)
    static_hyper: PyTree = eqx.field(init=False, static=True)
    # Total number of PINN parameters and their running sum per leaf; used
    # as split indices in _hyper_to_pinn, hence static.
    pinn_params_sum: Int[onp.ndarray, "1"] = eqx.field(init=False, static=True)
    pinn_params_cumsum: Int[onp.ndarray, "n_layers"] = eqx.field(
        init=False, static=True
    )

    def __post_init__(self, mlp, hyper_mlp):
        """Split both networks into trainable and static parts.

        The parent `PINN.__post_init__` is called with `mlp`; presumably it
        sets `self.params` / `self.static` (both are read below) — TODO
        confirm against PINN.
        """
        super().__post_init__(
            mlp,
        )
        self.params_hyper, self.static_hyper = eqx.partition(
            hyper_mlp, eqx.is_inexact_array
        )
        self.pinn_params_sum, self.pinn_params_cumsum = _get_param_nb(self.params)

    @property
    def init_params(self) -> Params:
        """
        Returns the initial set of trainable parameters, i.e. those of the
        hypernetwork (`self.params_hyper`).
        """
        return self.params_hyper

    def _hyper_to_pinn(self, hyper_output: Float[Array, "output_dim"]) -> PyTree:
        """
        From the flat output of the hypernetwork, build well-formed PINN
        parameters with the same pytree structure and leaf shapes as
        `self.params`.
        """
        # Cut the flat vector at the cumulative leaf sizes (one chunk per leaf)
        # and substitute the chunks for the leaves of self.params.
        pinn_params_flat = eqx.tree_at(
            lambda p: tree_leaves(p, is_leaf=eqx.is_array),
            self.params,
            jnp.split(hyper_output, self.pinn_params_cumsum[:-1]),
        )

        # Restore each chunk to the shape of the leaf it replaces.
        return tree_map(
            lambda a, b: a.reshape(b.shape),
            pinn_params_flat,
            self.params,
            is_leaf=lambda x: isinstance(x, jnp.ndarray),
        )

    def __call__(
        self,
        inputs: Float[Array, "input_dim"],
        params: Params | ParamsDict | PyTree,
    ) -> Float[Array, "output_dim"]:
        """
        Evaluate the HyperPINN on some inputs with some params.

        The hypernetwork is fed the flattened concatenation of
        `params.eq_params[k]` for every `k` in `self.hyperparams`. Its output
        is reshaped into PINN parameters, and the PINN is then evaluated on
        `inputs`. Returns an array with at least one dimension.
        """
        if len(inputs.shape) == 0:
            # This can happen often when the user directly provides some
            # collocation points (eg for plotting, without using
            # DataGenerators)
            inputs = inputs[None]

        try:
            hyper = eqx.combine(params.nn_params, self.static_hyper)
        except (KeyError, AttributeError, TypeError):  # give more flexibility
            # Fall back to treating `params` itself as the hypernetwork params.
            hyper = eqx.combine(params, self.static_hyper)

        # NOTE(review): this still reads `params.eq_params`, even on the
        # fallback path above.
        eq_params_batch = jnp.concatenate(
            [params.eq_params[k].flatten() for k in self.hyperparams], axis=0
        )

        hyper_output = hyper(eq_params_batch)

        pinn_params = self._hyper_to_pinn(hyper_output)

        pinn = eqx.combine(pinn_params, self.static)
        res = self.output_transform(
            inputs, pinn(self.input_transform(inputs, params)).squeeze(), params
        )

        if self.output_slice is not None:
            res = res[self.output_slice]

        ## force (1,) output for non vectorial solution (consistency)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res
181
-
182
-
183
def _set_hypernet_io_sizes(
    eqx_list_hyper: tuple[tuple[Callable, int, int] | Callable, ...],
    hypernet_input_size: int,
    nb_outputs: int,
) -> tuple[tuple[Callable, int, int] | Callable, ...]:
    """
    Return a copy of `eqx_list_hyper` in which the first parametrized layer
    takes `hypernet_input_size` inputs and the last parametrized layer has
    `nb_outputs` outputs.

    Entries are tuples: parametrized layers are `(module, in, out, *extra)`
    and activations are 1-tuples `(activation,)`. A single activation tuple
    is tolerated before the first layer and after the last layer; such
    activation tuples are kept as tuples, and any `*extra` layer arguments
    are preserved. The argument is never mutated.
    """
    layers = list(eqx_list_hyper)
    # The output size is set first and the input size second. A network made
    # of a single layer therefore gets both sizes rewritten.
    last = -1 if len(layers[-1]) > 1 else -2
    layers[last] = layers[last][:2] + (nb_outputs,) + layers[last][3:]
    first = 0 if len(layers[0]) > 1 else 1
    layers[first] = layers[first][:1] + (hypernet_input_size,) + layers[first][2:]
    return tuple(layers)


def create_HYPERPINN(
    key: Key,
    eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
    hyperparams: list[str],
    hypernet_input_size: int,
    dim_x: int = 0,
    input_transform: Callable[
        [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
    ] = None,
    output_transform: Callable[
        [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
        Float[Array, "output_dim"],
    ] = None,
    slice_solution: slice = None,
    shared_pinn_outputs: slice = None,
    eqx_list_hyper: tuple[tuple[Callable, int, int] | Callable, ...] = None,
) -> tuple[HYPERPINN | list[HYPERPINN], PyTree | list[PyTree]]:
    r"""
    Utility function to create a HYPERPINN (a PINN whose parameters are
    produced by a hypernetwork) with the equinox library.

    Parameters
    ----------
    key
        A JAX random key that will be used to initialize the network
        parameters.
    eqx_list
        A tuple of tuples of successive equinox modules and activation functions to
        describe the PINN architecture. The inner tuples must have the eqx module or
        activation function as first item, other items represent arguments
        that could be required (eg. the size of the layer).
        The `key` argument need not be given.
        Thus typical example is `eqx_list=
        ((eqx.nn.Linear, 2, 20),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 20),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 20),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 1)
        )`.
    eq_type
        A string with three possibilities.
        "ODE": the HYPERPINN is called with one input `t`.
        "statio_PDE": the HYPERPINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the HYPERPINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function
    hyperparams
        A list of keys from Params.eq_params that will be considered as
        hyperparameters for metamodeling.
    hypernet_input_size
        An integer. The input size of the MLP used for the hypernetwork. Must
        be equal to the flattened concatenations for the array of parameters
        designated by the `hyperparams` argument.
    dim_x
        An integer. The dimension of `x`. Default `0`.
    input_transform
        A function that will be called before entering the PINN. Its output(s)
        must match the PINN inputs (except for the parameters).
        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
        and the parameters. Default is no operation.
    output_transform
        A function whose arguments are the same input as the PINN, the PINN
        output and the parameters. This function will be called after exiting
        the PINN. Default is no operation.
    slice_solution
        A jnp.s\_ object which indicates which axis of the PINN output is
        dedicated to the actual equation solution. Default None
        means that slice_solution = the whole PINN output. This argument is useful
        when the PINN is also used to output equation parameters for example
        Note that it must be a slice and not an integer (a preprocessing of the
        user provided argument takes care of it).
    shared_pinn_outputs
        Default is None, for a standard PINN.
        A tuple of jnp.s\_[] (slices) to determine the different output for each
        network. In this case we return a list of PINNs, one for each output in
        shared_pinn_outputs. This is useful to create PINNs that share the
        same network and same parameters; **the user must then use the same
        parameter set in their manipulation**.
        See the notebook 2D Navier Stokes in pipeflow with metamodel for an
        example using this option.
    eqx_list_hyper
        Same as eqx_list but for the hypernetwork. Default is None, i.e., we
        use the same architecture as the PINN, up to the number of inputs and
        outputs. The input size of the first layer is overwritten with
        `hypernet_input_size` (the flattened concatenation of the arrays
        designated by the `hyperparams` argument) and the output size of the
        last layer is overwritten with the number of parameters of the PINN
        network.

    Returns
    -------
    hyperpinn
        A HYPERPINN instance or, when `shared_pinn_outputs` is not None,
        a list of HYPERPINN instances with the same structure is returned,
        only differing by their final slicing of the network output.
    hyperpinn.init_params
        The initial set of parameters for the HyperPINN or a list of the latter
        when `shared_pinn_outputs` is not None.


    Raises
    ------
    RuntimeError
        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
        "nonstatio_PDE"]`
    RuntimeError
        If we have a `dim_x > 0` and `eq_type == "ODE"`
        or if we have a `dim_x = 0` and `eq_type != "ODE"`
    """
    if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
        raise RuntimeError("Wrong parameter value for eq_type")

    if eq_type == "ODE" and dim_x != 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    if eq_type != "ODE" and dim_x == 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    if eqx_list_hyper is None:
        # No copy is needed: _set_hypernet_io_sizes builds a new tuple and
        # never mutates its argument.
        eqx_list_hyper = eqx_list

    try:
        # normally the output size is the 3rd element of the last layer...
        nb_outputs_declared = eqx_list[-1][2]
    except IndexError:
        # ...unless the last entry is an activation 1-tuple
        nb_outputs_declared = eqx_list[-2][2]

    if slice_solution is None:
        slice_solution = jnp.s_[0:nb_outputs_declared]
    if isinstance(slice_solution, int):
        # rewrite it as a slice to ensure that axis does not disappear when
        # indexing
        slice_solution = jnp.s_[slice_solution : slice_solution + 1]

    if input_transform is None:

        def input_transform(_in, _params):
            return _in

    if output_transform is None:

        def output_transform(_in_pinn, _out_pinn, _params):
            return _out_pinn

    key, subkey = jax.random.split(key, 2)
    mlp = _MLP(key=subkey, eqx_list=eqx_list)
    # Quick partitioning to count the PINN parameters. That count is the
    # number of outputs of the hypernetwork.
    params_mlp, _ = eqx.partition(mlp, eqx.is_inexact_array)
    pinn_params_sum, _ = _get_param_nb(params_mlp)
    # A plain int (not a 0-d numpy array) is used as the layer size, so that
    # no array ends up in the hypernetwork's layer configuration.
    eqx_list_hyper = _set_hypernet_io_sizes(
        eqx_list_hyper, hypernet_input_size, int(pinn_params_sum)
    )
    key, subkey = jax.random.split(key, 2)

    with warnings.catch_warnings():
        # TODO check why this warning is raised here and not in the PINN
        # context ?
        warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
        hyper_mlp = _MLP(key=subkey, eqx_list=eqx_list_hyper)

    def _build(output_slice):
        # Build one HYPERPINN that differs from its siblings only by
        # `output_slice`.
        with warnings.catch_warnings():
            # Catch the equinox warning because we put the number of
            # parameters as static while being an array. This time
            # it is correct to do so, because they are used as indices
            # and will never be modified
            warnings.filterwarnings(
                "ignore", message="A JAX array is being set as static!"
            )
            return HYPERPINN(
                mlp=mlp,
                hyper_mlp=hyper_mlp,
                slice_solution=slice_solution,
                eq_type=eq_type,
                input_transform=input_transform,
                output_transform=output_transform,
                hyperparams=hyperparams,
                hypernet_input_size=hypernet_input_size,
                output_slice=output_slice,
            )

    if shared_pinn_outputs is not None:
        hyperpinns = [_build(output_slice) for output_slice in shared_pinn_outputs]
        return hyperpinns, [h.init_params for h in hyperpinns]
    hyperpinn = _build(None)
    return hyperpinn, hyperpinn.init_params