jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +953 -1182
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +321 -168
  9. jinns/loss/_LossODE.py +290 -307
  10. jinns/loss/_LossPDE.py +628 -1040
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +95 -96
  13. jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +94 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +193 -45
  22. jinns/solver/_solve.py +199 -144
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -43
  25. jinns/utils/_hyperpinn.py +226 -127
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +117 -84
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +52 -144
  32. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
  33. jinns-1.0.0.dist-info/RECORD +38 -0
  34. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
  35. jinns/experimental/_sinuspinn.py +0 -135
  36. jinns/experimental/_spectralpinn.py +0 -87
  37. jinns/solver/_seq2seq.py +0 -157
  38. jinns/utils/_optim.py +0 -147
  39. jinns/utils/_utils_uspinn.py +0 -727
  40. jinns-0.8.10.dist-info/RECORD +0 -36
  41. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
  42. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
jinns/utils/_containers.py CHANGED
@@ -1,57 +1,51 @@
 """
-NamedTuples definition
+equinox Modules used as containers
 """
 
-from typing import Union, NamedTuple
-from jaxtyping import PyTree
-from jax.typing import ArrayLike
-import optax
-import jax.numpy as jnp
-from jinns.loss._LossODE import LossODE, SystemLossODE
-from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
-from jinns.data._DataGenerators import (
-    DataGeneratorODE,
-    CubicMeshPDEStatio,
-    CubicMeshPDENonStatio,
-    DataGeneratorParameter,
-    DataGeneratorObservations,
-    DataGeneratorObservationsMultiPINNs,
-)
-
-
-class DataGeneratorContainer(NamedTuple):
-    data: Union[DataGeneratorODE, CubicMeshPDEStatio, CubicMeshPDENonStatio]
-    param_data: Union[DataGeneratorParameter, None] = None
-    obs_data: Union[
-        DataGeneratorObservations, DataGeneratorObservationsMultiPINNs, None
-    ] = None
-
-
-class ValidationContainer(NamedTuple):
-    loss: Union[
-        LossODE, SystemLossODE, LossPDEStatio, LossPDENonStatio, SystemLossPDE, None
-    ]
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant
+
+from typing import TYPE_CHECKING, Dict
+from jaxtyping import PyTree, Array, Float, Bool
+from optax import OptState
+import equinox as eqx
+
+if TYPE_CHECKING:
+    from jinns.utils._types import *
+
+
+class DataGeneratorContainer(eqx.Module):
+    data: AnyDataGenerator
+    param_data: DataGeneratorParameter | None = None
+    obs_data: DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None = (
+        None
+    )
+
+
+class ValidationContainer(eqx.Module):
+    loss: AnyLoss | None
     data: DataGeneratorContainer
     hyperparams: PyTree = None
-    loss_values: Union[ArrayLike, None] = None
+    loss_values: Float[Array, "n_iter"] | None = None
 
 
-class OptimizationContainer(NamedTuple):
-    params: dict
-    last_non_nan_params: dict
-    opt_state: optax.OptState
+class OptimizationContainer(eqx.Module):
+    params: Params
+    last_non_nan_params: Params
+    opt_state: OptState
 
 
-class OptimizationExtraContainer(NamedTuple):
+class OptimizationExtraContainer(eqx.Module):
     curr_seq: int
-    seq2seq: Union[dict, None]
-    early_stopping: bool = False
+    best_val_params: Params
+    early_stopping: Bool = False
 
 
-class LossContainer(NamedTuple):
-    stored_loss_terms: dict
-    train_loss_values: ArrayLike
+class LossContainer(eqx.Module):
+    stored_loss_terms: Dict[str, Float[Array, "n_iter"]]
+    train_loss_values: Float[Array, "n_iter"]
 
 
-class StoredObjectContainer(NamedTuple):
-    stored_params: Union[list, None]
+class StoredObjectContainer(eqx.Module):
+    stored_params: list | None
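The containers above move from NamedTuples to eqx.Module. A minimal sketch of the new semantics (field values are illustrative, not jinns code): eqx.Module subclasses are frozen dataclasses registered as JAX pytrees, so in-place mutation gives way to functional updates with eqx.tree_at, while jit/grad compatibility is preserved.

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    class LossContainer(eqx.Module):  # mirrors the class in the diff above
        stored_loss_terms: dict
        train_loss_values: jax.Array

    c = LossContainer(
        stored_loss_terms={"dyn_loss": jnp.zeros((100,))},
        train_loss_values=jnp.zeros((100,)),
    )
    # Frozen dataclass: no attribute assignment; update functionally instead
    c = eqx.tree_at(lambda m: m.train_loss_values, c, jnp.ones((100,)))
    # Still a pytree, like the old NamedTuple, so it crosses jit boundaries
    assert len(jax.tree_util.tree_leaves(c)) == 2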
jinns/utils/_hyperpinn.py CHANGED
@@ -3,90 +3,132 @@ Implements utility function to create HYPERPINNs
 https://arxiv.org/pdf/2111.01008.pdf
 """
 
+import warnings
+from dataclasses import InitVar
+from typing import Callable, Literal
 import copy
 from math import prod
-import numpy as onp
 import jax
 import jax.numpy as jnp
 from jax.tree_util import tree_leaves, tree_map
-from jax.typing import ArrayLike
+from jaxtyping import Array, Float, PyTree, Int, Key
 import equinox as eqx
+import numpy as onp
 
 from jinns.utils._pinn import PINN, _MLP
+from jinns.parameters._params import Params
 
 
-def _get_param_nb(params):
-    """
-    Returns the number of parameters in a equinox module whose parameters
-    are stored in the pytree of parameters params but also the cumulative
-    sum when parsing the pytree
-    In reality, multiply the dimensions of the Arrays in this pytree and
-    sum everything, using pytree utility functions
+def _get_param_nb(
+    params: Params,
+) -> tuple[Int[onp.ndarray, "1"], Int[onp.ndarray, "n_layers"]]:
+    """Returns the number of parameters in a Params object and also
+    the cumulative sum when parsing the object.
+
+
+    Parameters
+    ----------
+    params :
+        A Params object.
     """
     dim_prod_all_arrays = [
         prod(a.shape)
         for a in tree_leaves(params, is_leaf=lambda x: isinstance(x, jnp.ndarray))
     ]
-    return sum(dim_prod_all_arrays), onp.cumsum(dim_prod_all_arrays)
+    return onp.asarray(sum(dim_prod_all_arrays)), onp.cumsum(dim_prod_all_arrays)
 
 
 class HYPERPINN(PINN):
     """
-    Composed of a PINN and an hypernetwork
-    """
+    A HYPERPINN object compatible with the rest of jinns.
+    Composed of a PINN and a hypernetwork. A HYPERPINN is typically
+    instantiated using `create_HYPERPINN`. However, a user can directly
+    create their HYPERPINN with this class by passing an already
+    instantiated eqx.Module for the argument `mlp` (resp. for the argument
+    `hyper_mlp`) that plays the role of the NN (resp. the hyper NN).
 
-    params_hyper: eqx.Module
-    static_hyper: eqx.Module
+    Parameters
+    ----------
     hyperparams: list = eqx.field(static=True)
+        A list of keys from Params.eq_params that will be considered as
+        hyperparameters for metamodeling.
     hypernet_input_size: int
-    pinn_params_sum: ArrayLike = eqx.field(static=True)
-    pinn_params_cumsum: ArrayLike = eqx.field(static=True)
+        An integer. The input size of the MLP used for the hypernetwork. Must
+        be equal to the size of the flattened concatenation of the parameter
+        arrays designated by the `hyperparams` argument.
+    slice_solution : slice
+        A jnp.s\_ object which indicates which axis of the PINN output is
+        dedicated to the actual equation solution. Default None
+        means that slice_solution = the whole PINN output. This argument is useful
+        when the PINN is also used to output equation parameters for example.
+        Note that it must be a slice and not an integer (a preprocessing of the
+        user provided argument takes care of it).
+    eq_type : str
+        A string with three possibilities.
+        "ODE": the HYPERPINN is called with one input `t`.
+        "statio_PDE": the HYPERPINN is called with one input `x`, `x`
+        can be high dimensional.
+        "nonstatio_PDE": the HYPERPINN is called with two inputs `t` and `x`, `x`
+        can be high dimensional.
+        **Note**: the input dimension as given in eqx_list has to match the sum
+        of the dimension of `t` + the dimension of `x` or the output dimension
+        after the `input_transform` function.
+    input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
+        A function that will be called before entering the PINN. Its output(s)
+        must match the PINN inputs (except for the parameters).
+        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
+        and the parameters. Default is no operation.
+    output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
+        A function whose arguments are the same inputs as the PINN, the PINN
+        output and the parameters. This function will be called after exiting
+        the PINN. Default is no operation.
+    output_slice : slice, default=None
+        A jnp.s\_[] to determine the different dimension for the HYPERPINN.
+        See `shared_pinn_outputs` argument of `create_HYPERPINN`.
+    mlp : eqx.Module
+        The actual neural network instantiated as an eqx.Module.
+    hyper_mlp : eqx.Module
+        The actual hyper neural network instantiated as an eqx.Module.
+    """
 
-    def __init__(
-        self,
-        mlp,
-        hyper_mlp,
-        slice_solution,
-        eq_type,
-        input_transform,
-        output_transform,
-        hyperparams,
-        hypernet_input_size,
-        output_slice,
-    ):
-        super().__init__(
+    hyperparams: list[str] = eqx.field(static=True, kw_only=True)
+    hypernet_input_size: int = eqx.field(kw_only=True)
+
+    hyper_mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)
+    mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)
+
+    params_hyper: PyTree = eqx.field(init=False)
+    static_hyper: PyTree = eqx.field(init=False, static=True)
+    pinn_params_sum: Int[onp.ndarray, "1"] = eqx.field(init=False, static=True)
+    pinn_params_cumsum: Int[onp.ndarray, "n_layers"] = eqx.field(
+        init=False, static=True
+    )
+
+    def __post_init__(self, mlp, hyper_mlp):
+        super().__post_init__(
             mlp,
-            slice_solution,
-            eq_type,
-            input_transform,
-            output_transform,
-            output_slice,
         )
         self.params_hyper, self.static_hyper = eqx.partition(
             hyper_mlp, eqx.is_inexact_array
         )
-        self.hyperparams = hyperparams
-        self.hypernet_input_size = hypernet_input_size
        self.pinn_params_sum, self.pinn_params_cumsum = _get_param_nb(self.params)
 
-    def init_params(self):
+    def init_params(self) -> Params:
+        """
+        Returns an initial set of parameters
+        """
         return self.params_hyper
 
-    def hyper_to_pinn(self, hyper_output):
+    def _hyper_to_pinn(self, hyper_output: Float[Array, "output_dim"]) -> PyTree:
         """
         From the output of the hypernetwork we set the well formed
         parameters of the pinn (`self.params`)
         """
         pinn_params_flat = eqx.tree_at(
-            lambda p: tree_leaves(p, is_leaf=lambda x: isinstance(x, jnp.ndarray)),
+            lambda p: tree_leaves(p, is_leaf=eqx.is_array),
             self.params,
-            [hyper_output[0 : self.pinn_params_cumsum[0]]]
-            + [
-                hyper_output[
-                    self.pinn_params_cumsum[i] : self.pinn_params_cumsum[i + 1]
-                ]
-                for i in range(len(self.pinn_params_cumsum) - 1)
-            ],
+            jnp.split(hyper_output, self.pinn_params_cumsum[:-1]),
         )
 
         return tree_map(
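Two changes above work together: `_get_param_nb` returns the total parameter count and the cumulative sum of per-leaf counts, and `_hyper_to_pinn` now carves the flat hypernetwork output with `jnp.split` instead of a hand-built list of slices. A small round-trip sketch of that logic (toy network, not the jinns API):

    from math import prod
    import equinox as eqx
    import jax
    import jax.numpy as jnp
    import numpy as onp
    from jax.tree_util import tree_leaves

    mlp = eqx.nn.MLP(in_size=2, out_size=1, width_size=4, depth=1,
                     key=jax.random.PRNGKey(0))
    params, _ = eqx.partition(mlp, eqx.is_inexact_array)

    dims = [prod(a.shape) for a in tree_leaves(params)]
    total, cumsum = sum(dims), onp.cumsum(dims)  # weights and biases: 8+4+4+1 = 17

    flat = jnp.arange(float(total))        # stand-in for the hypernetwork output
    pieces = jnp.split(flat, cumsum[:-1])  # one flat chunk per parameter leaf
    assert [p.shape[0] for p in pieces] == dims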
@@ -96,26 +138,31 @@ class HYPERPINN(PINN):
             is_leaf=lambda x: isinstance(x, jnp.ndarray),
         )
 
-    def _eval_nn(self, inputs, params, input_transform, output_transform):
+    def eval_nn(
+        self,
+        inputs: Float[Array, "input_dim"],
+        params: Params | PyTree,
+    ) -> Float[Array, "output_dim"]:
         """
-        inner function to factorize code. apply_fn (which takes varying forms)
-        call _eval_nn which always have the same content.
+        Evaluate the HYPERPINN on some inputs with some params.
         """
         try:
-            hyper = eqx.combine(params["nn_params"], self.static_hyper)
-        except (KeyError, TypeError) as e:  # give more flexibility
+            hyper = eqx.combine(params.nn_params, self.static_hyper)
+        except (KeyError, AttributeError, TypeError) as e:  # give more flexibility
             hyper = eqx.combine(params, self.static_hyper)
 
         eq_params_batch = jnp.concatenate(
-            [params["eq_params"][k].flatten() for k in self.hyperparams], axis=0
+            [params.eq_params[k].flatten() for k in self.hyperparams], axis=0
         )
 
         hyper_output = hyper(eq_params_batch)
 
-        pinn_params = self.hyper_to_pinn(hyper_output)
+        pinn_params = self._hyper_to_pinn(hyper_output)
 
         pinn = eqx.combine(pinn_params, self.static)
-        res = output_transform(inputs, pinn(input_transform(inputs, params)).squeeze())
+        res = self.output_transform(
+            inputs, pinn(self.input_transform(inputs, params)).squeeze(), params
+        )
 
         if self.output_slice is not None:
             res = res[self.output_slice]
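`eval_nn` above now reads `params.nn_params` and `params.eq_params[k]` as attributes of the new Params container instead of dict entries. A hedged sketch of the new access pattern (the Params keyword names are inferred from the fields used in this diff, and the top-level `jinns.parameters` import path is assumed from file 16 in the list above):

    import equinox as eqx
    import jax
    import jax.numpy as jnp
    from jinns.parameters import Params  # assumed re-export of _params.Params

    mlp = eqx.nn.MLP(2, 1, 16, 2, key=jax.random.PRNGKey(0))
    nn_params, _ = eqx.partition(mlp, eqx.is_inexact_array)

    params = Params(nn_params=nn_params, eq_params={"theta": jnp.array([1.0])})
    theta = params.eq_params["theta"]  # 0.8.10 wrote params["eq_params"]["theta"]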
@@ -127,18 +174,23 @@
 
 
 def create_HYPERPINN(
-    key,
-    eqx_list,
-    eq_type,
-    hyperparams,
-    hypernet_input_size,
-    dim_x=0,
-    input_transform=None,
-    output_transform=None,
-    slice_solution=None,
-    shared_pinn_outputs=None,
-    eqx_list_hyper=None,
-):
+    key: Key,
+    eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
+    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+    hyperparams: list[str],
+    hypernet_input_size: int,
+    dim_x: int = 0,
+    input_transform: Callable[
+        [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
+    ] = None,
+    output_transform: Callable[
+        [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
+        Float[Array, "output_dim"],
+    ] = None,
+    slice_solution: slice = None,
+    shared_pinn_outputs: slice = None,
+    eqx_list_hyper: tuple[tuple[Callable, int, int] | Callable, ...] = None,
+) -> HYPERPINN | list[HYPERPINN]:
     r"""
     Utility function to create a standard PINN neural network with the equinox
     library.
@@ -146,59 +198,61 @@ def create_HYPERPINN(
     Parameters
     ----------
     key
-        A jax random key that will be used to initialize the network parameters
+        A JAX random key that will be used to initialize the network
+        parameters.
     eqx_list
-        A list of list of successive equinox modules and activation functions to
-        describe the PINN architecture. The inner lists have the eqx module or
-        axtivation function as first item, other items represents arguments
+        A tuple of tuples of successive equinox modules and activation functions to
+        describe the PINN architecture. The inner tuples must have the eqx module or
+        activation function as first item, other items represent arguments
         that could be required (e.g. the size of the layer).
-        __Note:__ the `key` argument need not be given.
+        The `key` argument need not be given.
         Thus a typical example is `eqx_list=
-        [[eqx.nn.Linear, 2, 20],
-        [jax.nn.tanh],
-        [eqx.nn.Linear, 20, 20],
-        [jax.nn.tanh],
-        [eqx.nn.Linear, 20, 20],
-        [jax.nn.tanh],
-        [eqx.nn.Linear, 20, 1]
-        ]`
+        ((eqx.nn.Linear, 2, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 1)
+        )`.
     eq_type
         A string with three possibilities.
-        "ODE": the PINN is called with one input `t`.
-        "statio_PDE": the PINN is called with one input `x`, `x`
+        "ODE": the HYPERPINN is called with one input `t`.
+        "statio_PDE": the HYPERPINN is called with one input `x`, `x`
         can be high dimensional.
-        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
+        "nonstatio_PDE": the HYPERPINN is called with two inputs `t` and `x`, `x`
         can be high dimensional.
         **Note**: the input dimension as given in eqx_list has to match the sum
         of the dimension of `t` + the dimension of `x` or the output dimension
         after the `input_transform` function.
     hyperparams
-        A list of keys from params["eq_params"] that will be considered as
-        hyperparameters for metamodeling
+        A list of keys from Params.eq_params that will be considered as
+        hyperparameters for metamodeling.
     hypernet_input_size
         An integer. The input size of the MLP used for the hypernetwork. Must
         be equal to the size of the flattened concatenation of the parameter arrays
-        designated by the `hyperparams` argument
+        designated by the `hyperparams` argument.
     dim_x
-        An integer. The dimension of `x`. Default `0`
+        An integer. The dimension of `x`. Default `0`.
     input_transform
         A function that will be called before entering the PINN. Its output(s)
-        must match the PINN inputs. Its inputs are the PINN inputs (`t` and/or
-        `x` concatenated together and the parameters). Default is the No operation
+        must match the PINN inputs (except for the parameters).
+        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
+        and the parameters. Default is no operation.
     output_transform
-        A function with arguments the same input(s) as the PINN AND the PINN
-        output that will be called after exiting the PINN. Default is the No
-        operation
+        A function whose arguments are the same inputs as the PINN, the PINN
+        output and the parameters. This function will be called after exiting
+        the PINN. Default is no operation.
     slice_solution
         A jnp.s\_ object which indicates which axis of the PINN output is
         dedicated to the actual equation solution. Default None
         means that slice_solution = the whole PINN output. This argument is useful
         when the PINN is also used to output equation parameters for example.
         Note that it must be a slice and not an integer (a preprocessing of the
-        user provided argument takes care of it)
+        user provided argument takes care of it).
     shared_pinn_outputs
         Default is None, for a standard PINN.
-        A tuple of jnp.s_[] (slices) to determine the different output for each
+        A tuple of jnp.s\_[] (slices) to determine the different output for each
         network. In this case we return a list of PINNs, one for each output in
         shared_pinn_outputs. This is useful to create PINNs that share the
         same network and same parameters; **the user must then use the same
@@ -216,7 +270,11 @@ def create_HYPERPINN(
 
     Returns
     -------
-    `u`, a :class:`.HyperPINN` object which inherits from `eqx.Module` (hence callable).
+    hyperpinn
+        A HYPERPINN instance or, when `shared_pinn_outputs` is not None,
+        a list of HYPERPINN instances with the same structure, only
+        differing by their final slicing of the network output.
+
 
     Raises
     ------
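Putting the new signature and docstring together, a minimal usage sketch (layer sizes and the "theta" hyperparameter key are illustrative, the import uses the module path from this diff, and the default `eqx_list_hyper=None` is assumed to reuse `eqx_list`, as the `import copy` above suggests):

    import equinox as eqx
    import jax
    from jinns.utils._hyperpinn import create_HYPERPINN

    def out_t(pinn_in, pinn_out, params):  # the new 3-argument output_transform
        return pinn_out

    u = create_HYPERPINN(
        key=jax.random.PRNGKey(0),
        eqx_list=(
            (eqx.nn.Linear, 2, 20),
            jax.nn.tanh,
            (eqx.nn.Linear, 20, 1),
        ),
        eq_type="nonstatio_PDE",  # inputs t and x, so the first in_size is 1 + dim_x
        hyperparams=["theta"],    # a key of Params.eq_params (illustrative)
        hypernet_input_size=1,    # flattened size of the "theta" array
        dim_x=1,
        output_transform=out_t,
    )
    init_hyper_params = u.init_params()  # parameters of the hypernetwork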
@@ -259,53 +317,94 @@ def create_HYPERPINN(
 
     if output_transform is None:
 
-        def output_transform(_in_pinn, _out_pinn):
+        def output_transform(_in_pinn, _out_pinn, _params):
             return _out_pinn
 
     key, subkey = jax.random.split(key, 2)
-    mlp = _MLP(subkey, eqx_list)
+    mlp = _MLP(key=subkey, eqx_list=eqx_list)
     # quick partitioning to get the params to get the correct number of neurons
     # for the last layer of hyper network
     params_mlp, _ = eqx.partition(mlp, eqx.is_inexact_array)
     pinn_params_sum, _ = _get_param_nb(params_mlp)
     # the number of parameters for the pinn will be the number of outputs
     # for the hyper network
-    try:
-        eqx_list_hyper[-1][2] = pinn_params_sum
-    except IndexError:
-        eqx_list_hyper[-2][2] = pinn_params_sum
-    try:
-        eqx_list_hyper[0][1] = hypernet_input_size
-    except IndexError:
-        eqx_list_hyper[1][1] = hypernet_input_size
+    if len(eqx_list_hyper[-1]) > 1:
+        eqx_list_hyper = eqx_list_hyper[:-1] + (
+            (eqx_list_hyper[-1][:2] + (pinn_params_sum,)),
+        )
+    else:
+        eqx_list_hyper = (
+            eqx_list_hyper[:-2]
+            + ((eqx_list_hyper[-2][:2] + (pinn_params_sum,)),)
+            + eqx_list_hyper[-1]
+        )
+    if len(eqx_list_hyper[0]) > 1:
+        eqx_list_hyper = (
+            (
+                (eqx_list_hyper[0][0],)
+                + (hypernet_input_size,)
+                + (eqx_list_hyper[0][2],)
+            ),
+        ) + eqx_list_hyper[1:]
+    else:
+        eqx_list_hyper = (
+            eqx_list_hyper[0]
+            + (
+                (
+                    (eqx_list_hyper[1][0],)
+                    + (hypernet_input_size,)
+                    + (eqx_list_hyper[1][2],)
+                ),
+            )
+            + eqx_list_hyper[2:]
+        )
     key, subkey = jax.random.split(key, 2)
-    hyper_mlp = _MLP(subkey, eqx_list_hyper)
+
+    with warnings.catch_warnings():
+        # TODO check why this warning is raised here and not in the PINN
+        # context?
+        warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
+        hyper_mlp = _MLP(key=subkey, eqx_list=eqx_list_hyper)
 
     if shared_pinn_outputs is not None:
         hyperpinns = []
         for output_slice in shared_pinn_outputs:
-            hyperpinn = HYPERPINN(
-                mlp,
-                hyper_mlp,
-                slice_solution,
-                eq_type,
-                input_transform,
-                output_transform,
-                hyperparams,
-                hypernet_input_size,
-                output_slice,
-            )
+            with warnings.catch_warnings():
+                # Catch the equinox warning raised because we set the number of
+                # parameters as static while they are jnp.Arrays. This time
+                # it is correct to do so, because they are used as indices
+                # and will never be modified
+                warnings.filterwarnings(
+                    "ignore", message="A JAX array is being set as static!"
+                )
+                hyperpinn = HYPERPINN(
+                    mlp=mlp,
+                    hyper_mlp=hyper_mlp,
+                    slice_solution=slice_solution,
+                    eq_type=eq_type,
+                    input_transform=input_transform,
+                    output_transform=output_transform,
+                    hyperparams=hyperparams,
+                    hypernet_input_size=hypernet_input_size,
+                    output_slice=output_slice,
+                )
             hyperpinns.append(hyperpinn)
         return hyperpinns
-    hyperpinn = HYPERPINN(
-        mlp,
-        hyper_mlp,
-        slice_solution,
-        eq_type,
-        input_transform,
-        output_transform,
-        hyperparams,
-        hypernet_input_size,
-        None,
-    )
+    with warnings.catch_warnings():
+        # Catch the equinox warning raised because we set the number of
+        # parameters as static while they are jnp.Arrays. This time
+        # it is correct to do so, because they are used as indices
+        # and will never be modified
+        warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
+        hyperpinn = HYPERPINN(
+            mlp=mlp,
+            hyper_mlp=hyper_mlp,
+            slice_solution=slice_solution,
+            eq_type=eq_type,
+            input_transform=input_transform,
+            output_transform=output_transform,
+            hyperparams=hyperparams,
+            hypernet_input_size=hypernet_input_size,
+            output_slice=None,
+        )
     return hyperpinn
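The tuple arithmetic above replaces the 0.8.10 in-place mutation (`eqx_list_hyper[-1][2] = pinn_params_sum`): the entries of `eqx_list_hyper` are now immutable tuples, so the layer whose output size must equal the PINN's parameter count is rebuilt rather than edited. A toy check of the last-layer rewrite (sizes illustrative):

    import equinox as eqx
    import jax

    eqx_list_hyper = ((eqx.nn.Linear, 1, 32), jax.nn.tanh, (eqx.nn.Linear, 32, 0))
    pinn_params_sum = 105  # the hypernetwork must emit one value per PINN parameter

    if len(eqx_list_hyper[-1]) > 1:  # last entry is a (module, in_size, out_size) tuple
        eqx_list_hyper = eqx_list_hyper[:-1] + (
            eqx_list_hyper[-1][:2] + (pinn_params_sum,),
        )

    assert eqx_list_hyper[-1] == (eqx.nn.Linear, 32, 105)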