jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +322 -167
  9. jinns/loss/_LossODE.py +324 -322
  10. jinns/loss/_LossPDE.py +652 -1027
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +521 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. jinns-1.1.0.dist-info/AUTHORS +2 -0
  33. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
  34. jinns-1.1.0.dist-info/RECORD +39 -0
  35. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
  36. jinns/experimental/_sinuspinn.py +0 -135
  37. jinns/experimental/_spectralpinn.py +0 -87
  38. jinns/solver/_seq2seq.py +0 -157
  39. jinns/utils/_optim.py +0 -147
  40. jinns/utils/_utils_uspinn.py +0 -727
  41. jinns-0.9.0.dist-info/RECORD +0 -36
  42. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
  43. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/utils/_pinn.py CHANGED
@@ -2,44 +2,53 @@
2
2
  Implements utility function to create PINNs
3
3
  """
4
4
 
5
- from typing import Callable
5
+ from typing import Callable, Literal
6
+ from dataclasses import InitVar
6
7
  import jax
7
8
  import jax.numpy as jnp
8
- from jax.typing import ArrayLike
9
9
  import equinox as eqx
10
10
 
11
+ from jaxtyping import Array, Key, PyTree, Float
12
+
13
+ from jinns.parameters._params import Params
14
+
11
15
 
12
16
  class _MLP(eqx.Module):
13
17
  """
14
18
  Class to construct an equinox module from a key and a eqx_list. To be used
15
- in pair with the function `create_PINN`
19
+ in pair with the function `create_PINN`.
20
+
21
+ Parameters
22
+ ----------
23
+ key : InitVar[Key]
24
+ A jax random key for the layer initializations.
25
+ eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
26
+ A tuple of tuples of successive equinox modules and activation functions to
27
+ describe the PINN architecture. The inner tuples must have the eqx module or
28
+ activation function as first item, other items represents arguments
29
+ that could be required (eg. the size of the layer).
30
+ The `key` argument need not be given.
31
+ Thus typical example is `eqx_list=
32
+ ((eqx.nn.Linear, 2, 20),
33
+ jax.nn.tanh,
34
+ (eqx.nn.Linear, 20, 20),
35
+ jax.nn.tanh,
36
+ (eqx.nn.Linear, 20, 20),
37
+ jax.nn.tanh,
38
+ (eqx.nn.Linear, 20, 1)
39
+ )`.
16
40
  """
17
41
 
18
- layers: list
42
+ key: InitVar[Key] = eqx.field(kw_only=True)
43
+ eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] = eqx.field(
44
+ kw_only=True
45
+ )
19
46
 
20
- def __init__(self, key, eqx_list):
21
- """
22
- Parameters
23
- ----------
24
- key
25
- A jax random key
26
- eqx_list
27
- A list of list of successive equinox modules and activation functions to
28
- describe the PINN architecture. The inner lists have the eqx module or
29
- axtivation function as first item, other items represents arguments
30
- that could be required (eg. the size of the layer).
31
- __Note:__ the `key` argument need not be given.
32
- Thus typical example is `eqx_list=
33
- [[eqx.nn.Linear, 2, 20],
34
- [jax.nn.tanh],
35
- [eqx.nn.Linear, 20, 20],
36
- [jax.nn.tanh],
37
- [eqx.nn.Linear, 20, 20],
38
- [jax.nn.tanh],
39
- [eqx.nn.Linear, 20, 1]
40
- ]`
41
- """
47
+ # NOTE that the following should NOT be declared as static otherwise the
48
+ # eqx.partition that we use in the PINN module will misbehave
49
+ layers: list[eqx.Module] = eqx.field(init=False)
42
50
 
51
+ def __post_init__(self, key, eqx_list):
43
52
  self.layers = []
44
53
  for l in eqx_list:
45
54
  if len(l) == 1:
@@ -48,75 +57,118 @@ class _MLP(eqx.Module):
48
57
  key, subkey = jax.random.split(key, 2)
49
58
  self.layers.append(l[0](*l[1:], key=subkey))
50
59
 
51
- def __call__(self, t):
60
+ def __call__(self, t: Float[Array, "input_dim"]) -> Float[Array, "output_dim"]:
52
61
  for layer in self.layers:
53
62
  t = layer(t)
54
63
  return t
55
64
 
56
65
 
57
66
  class PINN(eqx.Module):
58
- """
59
- Basically a wrapper around the `__call__` function to be able to give a type to
60
- our former `self.u`
61
- The function create_PINN has the role to population the `__call__` function
62
- """
67
+ r"""
68
+ A PINN object, i.e., a neural network compatible with the rest of jinns.
69
+ This is typically created with `create_PINN` which creates iternally a
70
+ `_MLP` object. However, a user could directly creates their PINN using this
71
+ class by passing a eqx.Module (for argument `mlp`)
72
+ that plays the role of the NN and that is
73
+ already instanciated.
63
74
 
64
- slice_solution: ArrayLike
65
- eq_type: str = eqx.field(static=True)
66
- input_transform: Callable = eqx.field(static=True)
67
- output_transform: Callable = eqx.field(static=True)
68
- output_slice: ArrayLike
69
- params: eqx.Module
70
- static: eqx.Module
75
+ Parameters
76
+ ----------
77
+ slice_solution : slice
78
+ A jnp.s\_ object which indicates which axis of the PINN output is
79
+ dedicated to the actual equation solution. Default None
80
+ means that slice_solution = the whole PINN output. This argument is useful
81
+ when the PINN is also used to output equation parameters for example
82
+ Note that it must be a slice and not an integer (a preprocessing of the
83
+ user provided argument takes care of it).
84
+ eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
85
+ A string with three possibilities.
86
+ "ODE": the PINN is called with one input `t`.
87
+ "statio_PDE": the PINN is called with one input `x`, `x`
88
+ can be high dimensional.
89
+ "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
90
+ can be high dimensional.
91
+ **Note**: the input dimension as given in eqx_list has to match the sum
92
+ of the dimension of `t` + the dimension of `x` or the output dimension
93
+ after the `input_transform` function.
94
+ input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
95
+ A function that will be called before entering the PINN. Its output(s)
96
+ must match the PINN inputs (except for the parameters).
97
+ Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
98
+ and the parameters. Default is no operation.
99
+ output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
100
+ A function with arguments begin the same input as the PINN, the PINN
101
+ output and the parameter. This function will be called after exiting the PINN.
102
+ Default is no operation.
103
+ output_slice : slice, default=None
104
+ A jnp.s\_[] to determine the different dimension for the PINN.
105
+ See `shared_pinn_outputs` argument of `create_PINN`.
106
+ mlp : eqx.Module
107
+ The actual neural network instanciated as an eqx.Module.
108
+ """
71
109
 
72
- def __init__(
73
- self,
74
- mlp,
75
- slice_solution,
76
- eq_type,
77
- input_transform,
78
- output_transform,
79
- output_slice,
80
- ):
110
+ slice_solution: slice = eqx.field(static=True, kw_only=True)
111
+ eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
112
+ static=True, kw_only=True
113
+ )
114
+ input_transform: Callable[
115
+ [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
116
+ ] = eqx.field(static=True, kw_only=True)
117
+ output_transform: Callable[
118
+ [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
119
+ Float[Array, "output_dim"],
120
+ ] = eqx.field(static=True, kw_only=True)
121
+ output_slice: slice = eqx.field(static=True, kw_only=True, default=None)
122
+
123
+ mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)
124
+
125
+ params: PyTree = eqx.field(init=False)
126
+ static: PyTree = eqx.field(init=False, static=True)
127
+
128
+ def __post_init__(self, mlp):
81
129
  self.params, self.static = eqx.partition(mlp, eqx.is_inexact_array)
82
- self.slice_solution = slice_solution
83
- self.eq_type = eq_type
84
- self.input_transform = input_transform
85
- self.output_transform = output_transform
86
- self.output_slice = output_slice
87
130
 
88
- def init_params(self):
131
+ def init_params(self) -> PyTree:
132
+ """
133
+ Returns an initial set of parameters
134
+ """
89
135
  return self.params
90
136
 
91
- def __call__(self, *args):
137
+ def __call__(self, *args) -> Float[Array, "output_dim"]:
138
+ """
139
+ Calls `eval_nn` with rearranged arguments
140
+ """
92
141
  if self.eq_type == "ODE":
93
142
  (t, params) = args
94
143
  if len(t.shape) == 0:
95
144
  t = t[..., None] # Add mandatory dimension which can be lacking
96
145
  # (eg. for the ODE batches) but this dimension can already
97
146
  # exists (eg. for user provided observation times)
98
- return self._eval_nn(t, params, self.input_transform, self.output_transform)
147
+ return self.eval_nn(t, params)
99
148
  if self.eq_type == "statio_PDE":
100
149
  (x, params) = args
101
- return self._eval_nn(x, params, self.input_transform, self.output_transform)
150
+ return self.eval_nn(x, params)
102
151
  if self.eq_type == "nonstatio_PDE":
103
152
  (t, x, params) = args
104
153
  t_x = jnp.concatenate([t, x], axis=-1)
105
- return self._eval_nn(
106
- t_x, params, self.input_transform, self.output_transform
107
- )
154
+ return self.eval_nn(t_x, params)
108
155
  raise ValueError("Wrong value for self.eq_type")
109
156
 
110
- def _eval_nn(self, inputs, params, input_transform, output_transform):
157
+ def eval_nn(
158
+ self,
159
+ inputs: Float[Array, "input_dim"],
160
+ params: Params | PyTree,
161
+ ) -> Float[Array, "output_dim"]:
111
162
  """
112
- inner function to factorize code. apply_fn (which takes varying forms)
113
- call _eval_nn which always have the same content.
163
+ Evaluate the PINN on some inputs with some params.
114
164
  """
115
165
  try:
116
- model = eqx.combine(params["nn_params"], self.static)
117
- except (KeyError, TypeError) as e: # give more flexibility
166
+ model = eqx.combine(params.nn_params, self.static)
167
+ except (KeyError, AttributeError, TypeError) as e: # give more flexibility
118
168
  model = eqx.combine(params, self.static)
119
- res = output_transform(inputs, model(input_transform(inputs, params)).squeeze())
169
+ res = self.output_transform(
170
+ inputs, model(self.input_transform(inputs, params)).squeeze(), params
171
+ )
120
172
 
121
173
  if self.output_slice is not None:
122
174
  res = res[self.output_slice]
@@ -128,15 +180,20 @@ class PINN(eqx.Module):
128
180
 
129
181
 
130
182
  def create_PINN(
131
- key,
132
- eqx_list,
133
- eq_type,
134
- dim_x=0,
135
- input_transform=None,
136
- output_transform=None,
137
- shared_pinn_outputs=None,
138
- slice_solution=None,
139
- ):
183
+ key: Key,
184
+ eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
185
+ eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
186
+ dim_x: int = 0,
187
+ input_transform: Callable[
188
+ [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
189
+ ] = None,
190
+ output_transform: Callable[
191
+ [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
192
+ Float[Array, "output_dim"],
193
+ ] = None,
194
+ shared_pinn_outputs: tuple[slice] = None,
195
+ slice_solution: slice = None,
196
+ ) -> PINN | list[PINN]:
140
197
  r"""
141
198
  Utility function to create a standard PINN neural network with the equinox
142
199
  library.
@@ -144,22 +201,25 @@ def create_PINN(
144
201
  Parameters
145
202
  ----------
146
203
  key
147
- A jax random key that will be used to initialize the network parameters
204
+ A JAX random key that will be used to initialize the network
205
+ parameters.
148
206
  eqx_list
149
- A list of list of successive equinox modules and activation functions to
150
- describe the PINN architecture. The inner lists have the eqx module or
151
- axtivation function as first item, other items represents arguments
152
- that could be required (eg. the size of the layer).
153
- __Note:__ the `key` argument need not be given.
154
- Thus typical example is `eqx_list=
155
- [[eqx.nn.Linear, 2, 20],
156
- [jax.nn.tanh],
157
- [eqx.nn.Linear, 20, 20],
158
- [jax.nn.tanh],
159
- [eqx.nn.Linear, 20, 20],
160
- [jax.nn.tanh],
161
- [eqx.nn.Linear, 20, 1]
162
- ]`
207
+ A tuple of tuples of successive equinox modules and activation
208
+ functions to describe the PINN architecture. The inner tuples must have
209
+ the eqx module or activation function as first item, other items
210
+ represent arguments that could be required (eg. the size of the layer).
211
+
212
+ The `key` argument do not need to be given.
213
+
214
+ A typical example is `eqx_list = (
215
+ (eqx.nn.Linear, input_dim, 20),
216
+ (jax.nn.tanh,),
217
+ (eqx.nn.Linear, 20, 20),
218
+ (jax.nn.tanh,),
219
+ (eqx.nn.Linear, 20, 20),
220
+ (jax.nn.tanh,),
221
+ (eqx.nn.Linear, 20, output_dim)
222
+ )`.
163
223
  eq_type
164
224
  A string with three possibilities.
165
225
  "ODE": the PINN is called with one input `t`.
@@ -169,17 +229,19 @@ def create_PINN(
169
229
  can be high dimensional.
170
230
  **Note**: the input dimension as given in eqx_list has to match the sum
171
231
  of the dimension of `t` + the dimension of `x` or the output dimension
172
- after the `input_transform` function
232
+ after the `input_transform` function.
173
233
  dim_x
174
- An integer. The dimension of `x`. Default `0`
234
+ An integer. The dimension of `x`. Default `0`.
175
235
  input_transform
176
236
  A function that will be called before entering the PINN. Its output(s)
177
- must match the PINN inputs. Its inputs are the PINN inputs (`t` and/or
178
- `x` concatenated together and the parameters). Default is the No operation
237
+ must match the PINN inputs (except for the parameters).
238
+ Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
239
+ and the parameters. Default is no operation.
179
240
  output_transform
180
- A function with arguments the same input(s) as the PINN AND the PINN
181
- output that will be called after exiting the PINN. Default is the No
182
- operation
241
+ A function with arguments begin the same input as the PINN, the PINN
242
+ output and the parameter. This function will be called after exiting
243
+ the PINN.
244
+ Default is no operation.
183
245
  shared_pinn_outputs
184
246
  Default is None, for a stantard PINN.
185
247
  A tuple of jnp.s\_[] (slices) to determine the different output for each
@@ -192,15 +254,18 @@ def create_PINN(
192
254
  slice_solution
193
255
  A jnp.s\_ object which indicates which axis of the PINN output is
194
256
  dedicated to the actual equation solution. Default None
195
- means that slice_solution = the whole PINN output. This argument is useful
196
- when the PINN is also used to output equation parameters for example
197
- Note that it must be a slice and not an integer (a preprocessing of the
198
- user provided argument takes care of it)
257
+ means that slice_solution = the whole PINN output. This argument is
258
+ useful when the PINN is also used to output equation parameters for
259
+ example Note that it must be a slice and not an integer (a
260
+ preprocessing of the user provided argument takes care of it).
199
261
 
200
262
 
201
263
  Returns
202
264
  -------
203
- `u`, a :class:`.PINN` object which inherits from `eqx.Module` (hence callable). This comes with a bound method :func:`u.init_params() <PINN.init_params>`. When `shared_pinn_ouput` is not None, a list of :class:`.PINN` with the same structure is returned, only differing by there final slicing of the network output.
265
+ pinn
266
+ A PINN instance or, when `shared_pinn_ouput` is not None,
267
+ a list of PINN instances with the same structure is returned,
268
+ only differing by there final slicing of the network output.
204
269
 
205
270
  Raises
206
271
  ------
@@ -240,23 +305,30 @@ def create_PINN(
240
305
 
241
306
  if output_transform is None:
242
307
 
243
- def output_transform(_in_pinn, _out_pinn):
308
+ def output_transform(_in_pinn, _out_pinn, _params):
244
309
  return _out_pinn
245
310
 
246
- mlp = _MLP(key, eqx_list)
311
+ mlp = _MLP(key=key, eqx_list=eqx_list)
247
312
 
248
313
  if shared_pinn_outputs is not None:
249
314
  pinns = []
250
315
  for output_slice in shared_pinn_outputs:
251
316
  pinn = PINN(
252
- mlp,
253
- slice_solution,
254
- eq_type,
255
- input_transform,
256
- output_transform,
257
- output_slice,
317
+ mlp=mlp,
318
+ slice_solution=slice_solution,
319
+ eq_type=eq_type,
320
+ input_transform=input_transform,
321
+ output_transform=output_transform,
322
+ output_slice=output_slice,
258
323
  )
259
324
  pinns.append(pinn)
260
325
  return pinns
261
- pinn = PINN(mlp, slice_solution, eq_type, input_transform, output_transform, None)
326
+ pinn = PINN(
327
+ mlp=mlp,
328
+ slice_solution=slice_solution,
329
+ eq_type=eq_type,
330
+ input_transform=input_transform,
331
+ output_transform=output_transform,
332
+ output_slice=None,
333
+ )
262
334
  return pinn
jinns/utils/_save_load.py CHANGED
@@ -2,58 +2,64 @@
2
2
  Implements save and load functions
3
3
  """
4
4
 
5
+ from typing import Callable, Literal
5
6
  import pickle
6
7
  import jax
7
8
  import equinox as eqx
8
9
 
9
- from jinns.utils._pinn import create_PINN
10
- from jinns.utils._spinn import create_SPINN
11
- from jinns.utils._hyperpinn import create_HYPERPINN
10
+ from jinns.utils._pinn import create_PINN, PINN
11
+ from jinns.utils._spinn import create_SPINN, SPINN
12
+ from jinns.utils._hyperpinn import create_HYPERPINN, HYPERPINN
13
+ from jinns.parameters._params import Params, ParamsDict
12
14
 
13
15
 
14
- def function_to_string(eqx_list):
16
+ def function_to_string(
17
+ eqx_list: tuple[tuple[Callable, int, int] | Callable, ...]
18
+ ) -> tuple[tuple[str, int, int] | str, ...]:
15
19
  """
16
20
  We need this transformation for eqx_list to be pickled
17
21
 
18
- From `[[eqx.nn.Linear, 2, 20],
19
- [jax.nn.tanh],
20
- [eqx.nn.Linear, 20, 20],
21
- [jax.nn.tanh],
22
- [eqx.nn.Linear, 20, 20],
23
- [jax.nn.tanh],
24
- [eqx.nn.Linear, 20, 1]` to
25
- `[["Linear", 2, 20],
26
- ["tanh"],
27
- ["Linear", 20, 20],
28
- ["tanh"],
29
- ["Linear", 20, 20],
30
- ["tanh"],
31
- ["Linear", 20, 1]`
22
+ From `((eqx.nn.Linear, 2, 20),
23
+ (jax.nn.tanh),
24
+ (eqx.nn.Linear, 20, 20),
25
+ (jax.nn.tanh),
26
+ (eqx.nn.Linear, 20, 20),
27
+ (jax.nn.tanh),
28
+ (eqx.nn.Linear, 20, 1))` to
29
+ `(("Linear", 2, 20),
30
+ ("tanh"),
31
+ ("Linear", 20, 20),
32
+ ("tanh"),
33
+ ("Linear", 20, 20),
34
+ ("tanh"),
35
+ ("Linear", 20, 1))`
32
36
  """
33
37
  return jax.tree_util.tree_map(
34
38
  lambda x: x.__name__ if hasattr(x, "__call__") else x, eqx_list
35
39
  )
36
40
 
37
41
 
38
- def string_to_function(eqx_list_with_string):
42
+ def string_to_function(
43
+ eqx_list_with_string: tuple[tuple[str, int, int] | str, ...]
44
+ ) -> tuple[tuple[Callable, int, int] | Callable, ...]:
39
45
  """
40
46
  We need this transformation for eqx_list at the loading ("unpickling")
41
47
  operation.
42
48
 
43
- From `[["Linear", 2, 20],
44
- ["tanh"],
45
- ["Linear", 20, 20],
46
- ["tanh"],
47
- ["Linear", 20, 20],
48
- ["tanh"],
49
- ["Linear", 20, 1]`
50
- to `[[eqx.nn.Linear, 2, 20],
51
- [jax.nn.tanh],
52
- [eqx.nn.Linear, 20, 20],
53
- [jax.nn.tanh],
54
- [eqx.nn.Linear, 20, 20],
55
- [jax.nn.tanh],
56
- [eqx.nn.Linear, 20, 1]` to
49
+ From `(("Linear", 2, 20),
50
+ ("tanh"),
51
+ ("Linear", 20, 20),
52
+ ("tanh"),
53
+ ("Linear", 20, 20),
54
+ ("tanh"),
55
+ ("Linear", 20, 1))`
56
+ to `((eqx.nn.Linear, 2, 20),
57
+ (jax.nn.tanh),
58
+ (eqx.nn.Linear, 20, 20),
59
+ (jax.nn.tanh),
60
+ (eqx.nn.Linear, 20, 20),
61
+ (jax.nn.tanh),
62
+ (eqx.nn.Linear, 20, 1))`
57
63
  """
58
64
 
59
65
  def _str_to_fun(l):
@@ -76,16 +82,36 @@ def string_to_function(eqx_list_with_string):
76
82
  )
77
83
 
78
84
 
79
- def save_pinn(filename, u, params, kwargs_creation):
85
+ def save_pinn(
86
+ filename: str,
87
+ u: PINN | HYPERPINN | SPINN,
88
+ params: Params | ParamsDict,
89
+ kwargs_creation,
90
+ ):
80
91
  """
81
92
  Save a PINN / HyperPINN / SPINN model
82
93
  This function creates 3 files, beggining by `filename`
83
94
 
84
95
  1. an eqx file to save the eqx.Module (the PINN, HyperPINN, ...)
85
- 2. a pickle file for the parameters
86
- 3. a pickle file for the arguments that have been used at PINN
87
-
88
- creation and that we need to reconstruct the eqx.module later on.
96
+ 2. a pickle file for the parameters of the equation
97
+ 3. a pickle file for the arguments that have been used at PINN creation
98
+ and that we need to reconstruct the eqx.module later on.
99
+
100
+ Note that the equation parameters `Params.eq_params` go in the
101
+ pickle file while the neural network parameters `Params.nn_params` go in
102
+ the `"*-module.eqx"` file (normal behaviour with `eqx.
103
+ tree_serialise_leaves`).
104
+
105
+ Equation parameters are saved apart because the initial type of attribute
106
+ `params` in PINN / HYPERPINN / SPINN is not `Params` nor `ParamsDict`
107
+ but `PyTree` as inherited from `eqx.partition`.
108
+ Therefore, if we want to ensure a proper serialization/deserialization:
109
+ - we cannot save a `Params` object at this
110
+ attribute field ; the `Params` object must be split into `Params.nn_params`
111
+ (type `PyTree`) and `Params.eq_params` (type `dict`).
112
+ - in the case of a `ParamsDict` we cannot save `ParamsDict.nn_params` at
113
+ the attribute field `params` because it is not a `PyTree` (as expected in
114
+ the PINN / HYPERPINN / SPINN signature) but it is still a dictionary.
89
115
 
90
116
  Parameters
91
117
  ----------
@@ -94,17 +120,32 @@ def save_pinn(filename, u, params, kwargs_creation):
94
120
  u
95
121
  The PINN
96
122
  params
97
- The dictionary of parameters of the model.
98
- Typically, it is a dictionary of
99
- dictionaries: `eq_params` and `nn_params`, respectively the
100
- differential equation parameters and the neural network parameter
123
+ Params or ParamsDict to be save
101
124
  kwargs_creation
102
125
  The dictionary of arguments that were used to create the PINN, e.g.
103
126
  the layers list, O/PDE type, etc.
104
127
  """
105
- eqx.tree_serialise_leaves(filename + "-module.eqx", u)
106
- with open(filename + "-parameters.pkl", "wb") as f:
107
- pickle.dump(params, f)
128
+ if isinstance(params, Params):
129
+ if isinstance(u, HYPERPINN):
130
+ u = eqx.tree_at(lambda m: m.params_hyper, u, params)
131
+ elif isinstance(u, (PINN, SPINN)):
132
+ u = eqx.tree_at(lambda m: m.params, u, params)
133
+ eqx.tree_serialise_leaves(filename + "-module.eqx", u)
134
+
135
+ elif isinstance(params, ParamsDict):
136
+ for key, params_ in params.nn_params.items():
137
+ if isinstance(u, HYPERPINN):
138
+ u = eqx.tree_at(lambda m: m.params_hyper, u, params_)
139
+ elif isinstance(u, (PINN, SPINN)):
140
+ u = eqx.tree_at(lambda m: m.params, u, params_)
141
+ eqx.tree_serialise_leaves(filename + f"-module_{key}.eqx", u)
142
+
143
+ else:
144
+ raise ValueError("The parameters to be saved must be a Params or a ParamsDict")
145
+
146
+ with open(filename + "-eq_params.pkl", "wb") as f:
147
+ pickle.dump(params.eq_params, f)
148
+
108
149
  kwargs_creation = kwargs_creation.copy() # avoid side-effect that would be
109
150
  # very probably harmless anyway
110
151
 
@@ -124,7 +165,11 @@ def save_pinn(filename, u, params, kwargs_creation):
124
165
  pickle.dump(kwargs_creation, f)
125
166
 
126
167
 
127
- def load_pinn(filename, type_):
168
+ def load_pinn(
169
+ filename: str,
170
+ type_: Literal["pinn", "hyperpinn", "spinn"],
171
+ key_list_for_paramsdict: list[str] = None,
172
+ ) -> tuple[eqx.Module, Params | ParamsDict]:
128
173
  """
129
174
  Load a PINN model. This function needs to access 3 files :
130
175
  `{filename}-module.eqx`, `{filename}-parameters.pkl` and
@@ -132,27 +177,35 @@ def load_pinn(filename, type_):
132
177
 
133
178
  These files are created by `jinns.utils.save_pinn`.
134
179
 
135
- Note that this requires equinox v0.11.3 (currently latest version) for the
180
+ Note that this requires equinox>v0.11.3 for the
136
181
  `eqx.filter_eval_shape` to work.
137
182
 
183
+ See note in `save_pinn` for more details about the saving process
184
+
138
185
  Parameters
139
186
  ----------
140
187
  filename
141
188
  Filename (prefix) without extension.
142
189
  type_
143
190
  Type of model to load. Must be in ["pinn", "hyperpinn", "spinn"].
191
+ key_list_for_paramsdict
192
+ Pass the name of the keys of the dictionnary `ParamsDict.nn_params`. Default is None. In this case, we expect to retrieve a ParamsDict.
144
193
 
145
194
  Returns
146
195
  -------
147
196
  u_reloaded
148
197
  The reloaded PINN
149
- params_reloaded
198
+ params
150
199
  The reloaded parameters
151
200
  """
152
201
  with open(filename + "-arguments.pkl", "rb") as f:
153
202
  kwargs_reloaded = pickle.load(f)
154
- with open(filename + "-parameters.pkl", "rb") as f:
155
- params_reloaded = pickle.load(f)
203
+ try:
204
+ with open(filename + "-eq_params.pkl", "rb") as f:
205
+ eq_params_reloaded = pickle.load(f)
206
+ except FileNotFoundError:
207
+ eq_params_reloaded = {}
208
+ print("No pickle file for equation parameters found!")
156
209
  kwargs_reloaded["eqx_list"] = string_to_function(kwargs_reloaded["eqx_list"])
157
210
  if type_ == "pinn":
158
211
  # next line creates a shallow model, the jax arrays are just shapes and
@@ -167,9 +220,21 @@ def load_pinn(filename, type_):
167
220
  u_reloaded_shallow = eqx.filter_eval_shape(create_HYPERPINN, **kwargs_reloaded)
168
221
  else:
169
222
  raise ValueError(f"{type_} is not valid")
170
- # now the empty structure is populated with the actual saved array values
171
- # stored in the eqx file
172
- u_reloaded = eqx.tree_deserialise_leaves(
173
- filename + "-module.eqx", u_reloaded_shallow
174
- )
175
- return u_reloaded, params_reloaded
223
+ if key_list_for_paramsdict is None:
224
+ # now the empty structure is populated with the actual saved array values
225
+ # stored in the eqx file
226
+ u_reloaded = eqx.tree_deserialise_leaves(
227
+ filename + "-module.eqx", u_reloaded_shallow
228
+ )
229
+ params = Params(
230
+ nn_params=u_reloaded.init_params(), eq_params=eq_params_reloaded
231
+ )
232
+ else:
233
+ nn_params_dict = {}
234
+ for key in key_list_for_paramsdict:
235
+ u_reloaded = eqx.tree_deserialise_leaves(
236
+ filename + f"-module_{key}.eqx", u_reloaded_shallow
237
+ )
238
+ nn_params_dict[key] = u_reloaded.init_params()
239
+ params = ParamsDict(nn_params=nn_params_dict, eq_params=eq_params_reloaded)
240
+ return u_reloaded, params