jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +322 -167
  9. jinns/loss/_LossODE.py +324 -322
  10. jinns/loss/_LossPDE.py +652 -1027
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +521 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. jinns-1.1.0.dist-info/AUTHORS +2 -0
  33. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
  34. jinns-1.1.0.dist-info/RECORD +39 -0
  35. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
  36. jinns/experimental/_sinuspinn.py +0 -135
  37. jinns/experimental/_spectralpinn.py +0 -87
  38. jinns/solver/_seq2seq.py +0 -157
  39. jinns/utils/_optim.py +0 -147
  40. jinns/utils/_utils_uspinn.py +0 -727
  41. jinns-0.9.0.dist-info/RECORD +0 -36
  42. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
  43. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/utils/_hyperpinn.py CHANGED
@@ -3,76 +3,124 @@ Implements utility function to create HYPERPINNs
3
3
  https://arxiv.org/pdf/2111.01008.pdf
4
4
  """
5
5
 
6
+ import warnings
7
+ from dataclasses import InitVar
8
+ from typing import Callable, Literal
6
9
  import copy
7
10
  from math import prod
8
- import numpy as onp
9
11
  import jax
10
12
  import jax.numpy as jnp
11
13
  from jax.tree_util import tree_leaves, tree_map
12
- from jax.typing import ArrayLike
14
+ from jaxtyping import Array, Float, PyTree, Int, Key
13
15
  import equinox as eqx
16
+ import numpy as onp
14
17
 
15
18
  from jinns.utils._pinn import PINN, _MLP
19
+ from jinns.parameters._params import Params
16
20
 
17
21
 
18
- def _get_param_nb(params):
19
- """
20
- Returns the number of parameters in a equinox module whose parameters
21
- are stored in the pytree of parameters params but also the cumulative
22
- sum when parsing the pytree
23
- In reality, multiply the dimensions of the Arrays in this pytree and
24
- sum everything, using pytree utility functions
22
+ def _get_param_nb(
23
+ params: Params,
24
+ ) -> tuple[Int[onp.ndarray, "1"], Int[onp.ndarray, "n_layers"]]:
25
+ """Returns the number of parameters in a Params object and also
26
+ the cumulative sum when parsing the object.
27
+
28
+
29
+ Parameters
30
+ ----------
31
+ params :
32
+ A Params object.
25
33
  """
26
34
  dim_prod_all_arrays = [
27
35
  prod(a.shape)
28
36
  for a in tree_leaves(params, is_leaf=lambda x: isinstance(x, jnp.ndarray))
29
37
  ]
30
- return sum(dim_prod_all_arrays), onp.cumsum(dim_prod_all_arrays)
38
+ return onp.asarray(sum(dim_prod_all_arrays)), onp.cumsum(dim_prod_all_arrays)
31
39
 
32
40
 
33
41
  class HYPERPINN(PINN):
34
42
  """
35
- Composed of a PINN and an hypernetwork
36
- """
43
+ A HYPERPINN object compatible with the rest of jinns.
44
+ Composed of a PINN and an HYPER network. The HYPERPINN is typically
45
+ instanciated using with `create_HYPERPINN`. However, a user could directly
46
+ creates their HYPERPINN using this
47
+ class by passing an eqx.Module for argument `mlp` (resp. for argument
48
+ `hyper_mlp`) that plays the role of the NN (resp. hyper NN) and that is
49
+ already instanciated.
37
50
 
38
- params_hyper: eqx.Module
39
- static_hyper: eqx.Module
51
+ Parameters
52
+ ----------
40
53
  hyperparams: list = eqx.field(static=True)
54
+ A list of keys from Params.eq_params that will be considered as
55
+ hyperparameters for metamodeling.
41
56
  hypernet_input_size: int
42
- pinn_params_sum: ArrayLike = eqx.field(static=True)
43
- pinn_params_cumsum: ArrayLike = eqx.field(static=True)
57
+ An integer. The input size of the MLP used for the hypernetwork. Must
58
+ be equal to the flattened concatenations for the array of parameters
59
+ designated by the `hyperparams` argument.
60
+ slice_solution : slice
61
+ A jnp.s\_ object which indicates which axis of the PINN output is
62
+ dedicated to the actual equation solution. Default None
63
+ means that slice_solution = the whole PINN output. This argument is useful
64
+ when the PINN is also used to output equation parameters for example
65
+ Note that it must be a slice and not an integer (a preprocessing of the
66
+ user provided argument takes care of it).
67
+ eq_type : str
68
+ A string with three possibilities.
69
+ "ODE": the HYPERPINN is called with one input `t`.
70
+ "statio_PDE": the HYPERPINN is called with one input `x`, `x`
71
+ can be high dimensional.
72
+ "nonstatio_PDE": the HYPERPINN is called with two inputs `t` and `x`, `x`
73
+ can be high dimensional.
74
+ **Note**: the input dimension as given in eqx_list has to match the sum
75
+ of the dimension of `t` + the dimension of `x` or the output dimension
76
+ after the `input_transform` function
77
+ input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
78
+ A function that will be called before entering the PINN. Its output(s)
79
+ must match the PINN inputs (except for the parameters).
80
+ Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
81
+ and the parameters. Default is no operation.
82
+ output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
83
+ A function with arguments begin the same input as the PINN, the PINN
84
+ output and the parameter. This function will be called after exiting the PINN.
85
+ Default is no operation.
86
+ output_slice : slice, default=None
87
+ A jnp.s\_[] to determine the different dimension for the HYPERPINN.
88
+ See `shared_pinn_outputs` argument of `create_HYPERPINN`.
89
+ mlp : eqx.Module
90
+ The actual neural network instanciated as an eqx.Module.
91
+ hyper_mlp : eqx.Module
92
+ The actual hyper neural network instanciated as an eqx.Module.
93
+ """
44
94
 
45
- def __init__(
46
- self,
47
- mlp,
48
- hyper_mlp,
49
- slice_solution,
50
- eq_type,
51
- input_transform,
52
- output_transform,
53
- hyperparams,
54
- hypernet_input_size,
55
- output_slice,
56
- ):
57
- super().__init__(
95
+ hyperparams: list[str] = eqx.field(static=True, kw_only=True)
96
+ hypernet_input_size: int = eqx.field(kw_only=True)
97
+
98
+ hyper_mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)
99
+ mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)
100
+
101
+ params_hyper: PyTree = eqx.field(init=False)
102
+ static_hyper: PyTree = eqx.field(init=False, static=True)
103
+ pinn_params_sum: Int[onp.ndarray, "1"] = eqx.field(init=False, static=True)
104
+ pinn_params_cumsum: Int[onp.ndarray, "n_layers"] = eqx.field(
105
+ init=False, static=True
106
+ )
107
+
108
+ def __post_init__(self, mlp, hyper_mlp):
109
+ super().__post_init__(
58
110
  mlp,
59
- slice_solution,
60
- eq_type,
61
- input_transform,
62
- output_transform,
63
- output_slice,
64
111
  )
65
112
  self.params_hyper, self.static_hyper = eqx.partition(
66
113
  hyper_mlp, eqx.is_inexact_array
67
114
  )
68
- self.hyperparams = hyperparams
69
- self.hypernet_input_size = hypernet_input_size
70
115
  self.pinn_params_sum, self.pinn_params_cumsum = _get_param_nb(self.params)
71
116
 
72
- def init_params(self):
117
+ def init_params(self) -> Params:
118
+ """
119
+ Returns an initial set of parameters
120
+ """
73
121
  return self.params_hyper
74
122
 
75
- def hyper_to_pinn(self, hyper_output):
123
+ def _hyper_to_pinn(self, hyper_output: Float[Array, "output_dim"]) -> PyTree:
76
124
  """
77
125
  From the output of the hypernetwork we set the well formed
78
126
  parameters of the pinn (`self.params`)
@@ -90,26 +138,31 @@ class HYPERPINN(PINN):
90
138
  is_leaf=lambda x: isinstance(x, jnp.ndarray),
91
139
  )
92
140
 
93
- def _eval_nn(self, inputs, params, input_transform, output_transform):
141
+ def eval_nn(
142
+ self,
143
+ inputs: Float[Array, "input_dim"],
144
+ params: Params | PyTree,
145
+ ) -> Float[Array, "output_dim"]:
94
146
  """
95
- inner function to factorize code. apply_fn (which takes varying forms)
96
- call _eval_nn which always have the same content.
147
+ Evaluate the HYPERPINN on some inputs with some params.
97
148
  """
98
149
  try:
99
- hyper = eqx.combine(params["nn_params"], self.static_hyper)
100
- except (KeyError, TypeError) as e: # give more flexibility
150
+ hyper = eqx.combine(params.nn_params, self.static_hyper)
151
+ except (KeyError, AttributeError, TypeError) as e: # give more flexibility
101
152
  hyper = eqx.combine(params, self.static_hyper)
102
153
 
103
154
  eq_params_batch = jnp.concatenate(
104
- [params["eq_params"][k].flatten() for k in self.hyperparams], axis=0
155
+ [params.eq_params[k].flatten() for k in self.hyperparams], axis=0
105
156
  )
106
157
 
107
158
  hyper_output = hyper(eq_params_batch)
108
159
 
109
- pinn_params = self.hyper_to_pinn(hyper_output)
160
+ pinn_params = self._hyper_to_pinn(hyper_output)
110
161
 
111
162
  pinn = eqx.combine(pinn_params, self.static)
112
- res = output_transform(inputs, pinn(input_transform(inputs, params)).squeeze())
163
+ res = self.output_transform(
164
+ inputs, pinn(self.input_transform(inputs, params)).squeeze(), params
165
+ )
113
166
 
114
167
  if self.output_slice is not None:
115
168
  res = res[self.output_slice]
@@ -121,18 +174,23 @@ class HYPERPINN(PINN):
121
174
 
122
175
 
123
176
  def create_HYPERPINN(
124
- key,
125
- eqx_list,
126
- eq_type,
127
- hyperparams,
128
- hypernet_input_size,
129
- dim_x=0,
130
- input_transform=None,
131
- output_transform=None,
132
- slice_solution=None,
133
- shared_pinn_outputs=None,
134
- eqx_list_hyper=None,
135
- ):
177
+ key: Key,
178
+ eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
179
+ eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
180
+ hyperparams: list[str],
181
+ hypernet_input_size: int,
182
+ dim_x: int = 0,
183
+ input_transform: Callable[
184
+ [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
185
+ ] = None,
186
+ output_transform: Callable[
187
+ [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
188
+ Float[Array, "output_dim"],
189
+ ] = None,
190
+ slice_solution: slice = None,
191
+ shared_pinn_outputs: slice = None,
192
+ eqx_list_hyper: tuple[tuple[Callable, int, int] | Callable, ...] = None,
193
+ ) -> HYPERPINN | list[HYPERPINN]:
136
194
  r"""
137
195
  Utility function to create a standard PINN neural network with the equinox
138
196
  library.
@@ -140,59 +198,61 @@ def create_HYPERPINN(
140
198
  Parameters
141
199
  ----------
142
200
  key
143
- A jax random key that will be used to initialize the network parameters
201
+ A JAX random key that will be used to initialize the network
202
+ parameters.
144
203
  eqx_list
145
- A list of list of successive equinox modules and activation functions to
146
- describe the PINN architecture. The inner lists have the eqx module or
147
- axtivation function as first item, other items represents arguments
204
+ A tuple of tuples of successive equinox modules and activation functions to
205
+ describe the PINN architecture. The inner tuples must have the eqx module or
206
+ activation function as first item, other items represent arguments
148
207
  that could be required (eg. the size of the layer).
149
- __Note:__ the `key` argument need not be given.
208
+ The `key` argument need not be given.
150
209
  Thus typical example is `eqx_list=
151
- [[eqx.nn.Linear, 2, 20],
152
- [jax.nn.tanh],
153
- [eqx.nn.Linear, 20, 20],
154
- [jax.nn.tanh],
155
- [eqx.nn.Linear, 20, 20],
156
- [jax.nn.tanh],
157
- [eqx.nn.Linear, 20, 1]
158
- ]`
210
+ ((eqx.nn.Linear, 2, 20),
211
+ jax.nn.tanh,
212
+ (eqx.nn.Linear, 20, 20),
213
+ jax.nn.tanh,
214
+ (eqx.nn.Linear, 20, 20),
215
+ jax.nn.tanh,
216
+ (eqx.nn.Linear, 20, 1)
217
+ )`.
159
218
  eq_type
160
219
  A string with three possibilities.
161
- "ODE": the PINN is called with one input `t`.
162
- "statio_PDE": the PINN is called with one input `x`, `x`
220
+ "ODE": the HYPERPINN is called with one input `t`.
221
+ "statio_PDE": the HYPERPINN is called with one input `x`, `x`
163
222
  can be high dimensional.
164
- "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
223
+ "nonstatio_PDE": the HYPERPINN is called with two inputs `t` and `x`, `x`
165
224
  can be high dimensional.
166
225
  **Note**: the input dimension as given in eqx_list has to match the sum
167
226
  of the dimension of `t` + the dimension of `x` or the output dimension
168
227
  after the `input_transform` function
169
228
  hyperparams
170
- A list of keys from params["eq_params"] that will be considered as
171
- hyperparameters for metamodeling
229
+ A list of keys from Params.eq_params that will be considered as
230
+ hyperparameters for metamodeling.
172
231
  hypernet_input_size
173
232
  An integer. The input size of the MLP used for the hypernetwork. Must
174
233
  be equal to the flattened concatenations for the array of parameters
175
- designated by the `hyperparams` argument
234
+ designated by the `hyperparams` argument.
176
235
  dim_x
177
- An integer. The dimension of `x`. Default `0`
236
+ An integer. The dimension of `x`. Default `0`.
178
237
  input_transform
179
238
  A function that will be called before entering the PINN. Its output(s)
180
- must match the PINN inputs. Its inputs are the PINN inputs (`t` and/or
181
- `x` concatenated together and the parameters). Default is the No operation
239
+ must match the PINN inputs (except for the parameters).
240
+ Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
241
+ and the parameters. Default is no operation.
182
242
  output_transform
183
- A function with arguments the same input(s) as the PINN AND the PINN
184
- output that will be called after exiting the PINN. Default is the No
185
- operation
243
+ A function with arguments begin the same input as the PINN, the PINN
244
+ output and the parameter. This function will be called after exiting the PINN.
245
+ Default is no operation.
186
246
  slice_solution
187
247
  A jnp.s\_ object which indicates which axis of the PINN output is
188
248
  dedicated to the actual equation solution. Default None
189
249
  means that slice_solution = the whole PINN output. This argument is useful
190
250
  when the PINN is also used to output equation parameters for example
191
251
  Note that it must be a slice and not an integer (a preprocessing of the
192
- user provided argument takes care of it)
252
+ user provided argument takes care of it).
193
253
  shared_pinn_outputs
194
254
  Default is None, for a stantard PINN.
195
- A tuple of jnp.s_[] (slices) to determine the different output for each
255
+ A tuple of jnp.s\_[] (slices) to determine the different output for each
196
256
  network. In this case we return a list of PINNs, one for each output in
197
257
  shared_pinn_outputs. This is useful to create PINNs that share the
198
258
  same network and same parameters; **the user must then use the same
@@ -210,7 +270,11 @@ def create_HYPERPINN(
210
270
 
211
271
  Returns
212
272
  -------
213
- `u`, a :class:`.HyperPINN` object which inherits from `eqx.Module` (hence callable).
273
+ hyperpinn
274
+ A HYPERPINN instance or, when `shared_pinn_ouput` is not None,
275
+ a list of HYPERPINN instances with the same structure is returned,
276
+ only differing by there final slicing of the network output.
277
+
214
278
 
215
279
  Raises
216
280
  ------
@@ -253,53 +317,94 @@ def create_HYPERPINN(
253
317
 
254
318
  if output_transform is None:
255
319
 
256
- def output_transform(_in_pinn, _out_pinn):
320
+ def output_transform(_in_pinn, _out_pinn, _params):
257
321
  return _out_pinn
258
322
 
259
323
  key, subkey = jax.random.split(key, 2)
260
- mlp = _MLP(subkey, eqx_list)
324
+ mlp = _MLP(key=subkey, eqx_list=eqx_list)
261
325
  # quick partitioning to get the params to get the correct number of neurons
262
326
  # for the last layer of hyper network
263
327
  params_mlp, _ = eqx.partition(mlp, eqx.is_inexact_array)
264
328
  pinn_params_sum, _ = _get_param_nb(params_mlp)
265
329
  # the number of parameters for the pinn will be the number of ouputs
266
330
  # for the hyper network
267
- try:
268
- eqx_list_hyper[-1][2] = pinn_params_sum
269
- except IndexError:
270
- eqx_list_hyper[-2][2] = pinn_params_sum
271
- try:
272
- eqx_list_hyper[0][1] = hypernet_input_size
273
- except IndexError:
274
- eqx_list_hyper[1][1] = hypernet_input_size
331
+ if len(eqx_list_hyper[-1]) > 1:
332
+ eqx_list_hyper = eqx_list_hyper[:-1] + (
333
+ (eqx_list_hyper[-1][:2] + (pinn_params_sum,)),
334
+ )
335
+ else:
336
+ eqx_list_hyper = (
337
+ eqx_list_hyper[:-2]
338
+ + ((eqx_list_hyper[-2][:2] + (pinn_params_sum,)),)
339
+ + eqx_list_hyper[-1]
340
+ )
341
+ if len(eqx_list_hyper[0]) > 1:
342
+ eqx_list_hyper = (
343
+ (
344
+ (eqx_list_hyper[0][0],)
345
+ + (hypernet_input_size,)
346
+ + (eqx_list_hyper[0][2],)
347
+ ),
348
+ ) + eqx_list_hyper[1:]
349
+ else:
350
+ eqx_list_hyper = (
351
+ eqx_list_hyper[0]
352
+ + (
353
+ (
354
+ (eqx_list_hyper[1][0],)
355
+ + (hypernet_input_size,)
356
+ + (eqx_list_hyper[1][2],)
357
+ ),
358
+ )
359
+ + eqx_list_hyper[2:]
360
+ )
275
361
  key, subkey = jax.random.split(key, 2)
276
- hyper_mlp = _MLP(subkey, eqx_list_hyper)
362
+
363
+ with warnings.catch_warnings():
364
+ # TODO check why this warning is raised here and not in the PINN
365
+ # context ?
366
+ warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
367
+ hyper_mlp = _MLP(key=subkey, eqx_list=eqx_list_hyper)
277
368
 
278
369
  if shared_pinn_outputs is not None:
279
370
  hyperpinns = []
280
371
  for output_slice in shared_pinn_outputs:
281
- hyperpinn = HYPERPINN(
282
- mlp,
283
- hyper_mlp,
284
- slice_solution,
285
- eq_type,
286
- input_transform,
287
- output_transform,
288
- hyperparams,
289
- hypernet_input_size,
290
- output_slice,
291
- )
372
+ with warnings.catch_warnings():
373
+ # Catch the equinox warning because we put the number of
374
+ # parameters as static while being jnp.Array. This this time
375
+ # this is correct to do so, because they are used as indices
376
+ # and will never be modified
377
+ warnings.filterwarnings(
378
+ "ignore", message="A JAX array is being set as static!"
379
+ )
380
+ hyperpinn = HYPERPINN(
381
+ mlp=mlp,
382
+ hyper_mlp=hyper_mlp,
383
+ slice_solution=slice_solution,
384
+ eq_type=eq_type,
385
+ input_transform=input_transform,
386
+ output_transform=output_transform,
387
+ hyperparams=hyperparams,
388
+ hypernet_input_size=hypernet_input_size,
389
+ output_slice=output_slice,
390
+ )
292
391
  hyperpinns.append(hyperpinn)
293
392
  return hyperpinns
294
- hyperpinn = HYPERPINN(
295
- mlp,
296
- hyper_mlp,
297
- slice_solution,
298
- eq_type,
299
- input_transform,
300
- output_transform,
301
- hyperparams,
302
- hypernet_input_size,
303
- None,
304
- )
393
+ with warnings.catch_warnings():
394
+ # Catch the equinox warning because we put the number of
395
+ # parameters as static while being jnp.Array. This this time
396
+ # this is correct to do so, because they are used as indices
397
+ # and will never be modified
398
+ warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
399
+ hyperpinn = HYPERPINN(
400
+ mlp=mlp,
401
+ hyper_mlp=hyper_mlp,
402
+ slice_solution=slice_solution,
403
+ eq_type=eq_type,
404
+ input_transform=input_transform,
405
+ output_transform=output_transform,
406
+ hyperparams=hyperparams,
407
+ hypernet_input_size=hypernet_input_size,
408
+ output_slice=None,
409
+ )
305
410
  return hyperpinn