jinns 1.2.0-py3-none-any.whl → 1.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +116 -189
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +176 -513
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +22 -21
  19. jinns/loss/_loss_utils.py +98 -173
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -76
  22. jinns/nn/__init__.py +22 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +434 -0
  25. jinns/nn/_mlp.py +217 -0
  26. jinns/nn/_pinn.py +204 -0
  27. jinns/nn/_ppinn.py +239 -0
  28. jinns/{utils → nn}/_save_load.py +39 -53
  29. jinns/nn/_spinn.py +123 -0
  30. jinns/nn/_spinn_mlp.py +202 -0
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +38 -37
  37. jinns/solver/_rar.py +82 -65
  38. jinns/solver/_solve.py +111 -71
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -5
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns/utils/_hyperpinn.py +0 -420
  51. jinns/utils/_pinn.py +0 -324
  52. jinns/utils/_ppinn.py +0 -227
  53. jinns/utils/_spinn.py +0 -249
  54. jinns-1.2.0.dist-info/RECORD +0 -41
  55. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  56. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  57. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
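Items 22-31 and 50-53 in the listing show the neural-network code moving out of `jinns.utils` and into the new `jinns.nn` subpackage, with the `create_*` factory functions replaced by classes that expose `.create()` classmethods. A minimal sketch of the corresponding import changes is shown below; the private module paths are the ones visible in the `_save_load.py` hunk further down, and whether `jinns.nn` also re-exports these names at package level is not visible in this diff.

# jinns 1.2.0: factory functions under jinns.utils (modules now deleted)
# from jinns.utils._pinn import create_PINN, PINN
# from jinns.utils._spinn import create_SPINN, SPINN
# from jinns.utils._hyperpinn import create_HYPERPINN, HYPERPINN

# jinns 1.4.0: classes (with .create() classmethods) under jinns.nn
from jinns.nn._pinn import PINN
from jinns.nn._mlp import PINN_MLP
from jinns.nn._spinn_mlp import SPINN_MLP
from jinns.nn._hyperpinn import HyperPINN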
jinns/nn/_pinn.py ADDED
@@ -0,0 +1,204 @@
+ """
+ Implements the abstract class for PINN architectures
+ """
+
+ from __future__ import annotations
+
+ from typing import Callable, Union, Any, Literal, overload
+ from dataclasses import InitVar
+ import equinox as eqx
+ from jaxtyping import Float, Array, PyTree
+ import jax.numpy as jnp
+ from jinns.parameters._params import Params
+ from jinns.nn._abstract_pinn import AbstractPINN
+ from jinns.nn._utils import _PyTree_to_Params
+
+
+ class PINN(AbstractPINN):
+     r"""
+     Base class for PINN objects. It can be seen as a wrapper around an
+     `eqx.Module` which actually implements the NN architecture, with extra
+     arguments handling the "physics-informed" aspect.
+
+     !!! Note
+         We use the `eqx.partition` and `eqx.combine` strategy of Equinox: a
+         `filter_spec` is applied on the PyTree and splits it into two PyTrees
+         with the same structure: a static one (invisible to JAX transforms
+         such as JIT, grad, etc.) and a dynamic one. By convention, anything
+         not static is considered a parameter in Jinns.
+
+     For compatibility with jinns, we require that a `PINN` architecture:
+
+     1) has an eqx.Module (`eqx_network`) InitVar passed to __post_init__
+        representing the network architecture.
+     2) calls `eqx.partition` in __post_init__ in order to store the
+        static part of the model and the initial parameters.
+     3) has an `eq_type` argument, used for handling internal operations in
+        jinns.
+     4) has a `slice_solution` argument. It is a `jnp.s\_` object which
+        indicates which axis of the PINN output is dedicated to the actual
+        equation solution. Default None means that slice_solution = the whole
+        PINN output. For example, this argument is useful when the PINN is
+        also used to output equation parameters. Note that it must be a slice
+        and not an integer (a preprocessing of the user provided argument
+        takes care of it).
+
+     Parameters
+     ----------
+     slice_solution : slice
+         Default is jnp.s\_[...]. A jnp.s\_ object which indicates which axis
+         of the PINN output is dedicated to the actual equation solution.
+         Default None means that slice_solution = the whole PINN output. This
+         argument is useful when the PINN is also used to output equation
+         parameters, for example. Note that it must be a slice and not an
+         integer (a preprocessing of the user provided argument takes care of
+         it).
+     eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
+         A string with three possibilities.
+         "ODE": the PINN is called with one input `t`.
+         "statio_PDE": the PINN is called with one input `x`, `x`
+         can be high dimensional.
+         "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
+         can be high dimensional.
+         **Note**: the input dimension as given in eqx_list has to match the
+         sum of the dimension of `t` + the dimension of `x` or the output
+         dimension after the `input_transform` function.
+     input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
+         A function that will be called before entering the PINN. Its
+         output(s) must match the PINN inputs (except for the parameters).
+         Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
+         and the parameters. Default is no operation.
+     output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
+         A function whose arguments are the same inputs as the PINN, the PINN
+         output and the parameters. This function will be called after exiting
+         the PINN. Default is no operation.
+     eqx_network : eqx.Module
+         The actual neural network instantiated as an eqx.Module.
+     filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
+         Default is `eqx.is_inexact_array`. This tells Jinns what to consider
+         as a trainable parameter. Quoting from the equinox documentation: a
+         PyTree whose structure should be a prefix of the structure of pytree.
+         Each of its leaves should either be 1) True, in which case the leaf
+         or subtree is kept; 2) False, in which case the leaf or subtree is
+         replaced with replace; 3) a callable Leaf -> bool, in which case this
+         is evaluated on the leaf or mapped over the subtree, and the leaf
+         kept or replaced as appropriate.
+
+
+     Raises
+     ------
+     RuntimeError
+         If the parameter value for eq_type is not in `["ODE", "statio_PDE",
+         "nonstatio_PDE"]`
+     """
+
+     eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
+         static=True, kw_only=True
+     )
+     slice_solution: slice = eqx.field(static=True, kw_only=True, default=None)
+     input_transform: Callable[
+         [Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]
+     ] = eqx.field(static=True, kw_only=True, default=None)
+     output_transform: Callable[
+         [Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]],
+         Float[Array, " output_dim"],
+     ] = eqx.field(static=True, kw_only=True, default=None)
+
+     eqx_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
+     filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = eqx.field(
+         static=True, kw_only=True, default=eqx.is_inexact_array
+     )
+
+     init_params: PINN = eqx.field(init=False)
+     static: PINN = eqx.field(init=False, static=True)
+
+     def __post_init__(self, eqx_network):
+         if self.eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
+             raise RuntimeError("Wrong parameter value for eq_type")
+         # saving the static part of the model and initial parameters
+
+         if self.filter_spec is None:
+             self.filter_spec = eqx.is_inexact_array
+
+         self.init_params, self.static = eqx.partition(eqx_network, self.filter_spec)
+
+         if self.input_transform is None:
+             self.input_transform = lambda _in, _params: _in
+
+         if self.output_transform is None:
+             self.output_transform = lambda _in_pinn, _out_pinn, _params: _out_pinn
+
+         if self.slice_solution is None:
+             self.slice_solution = jnp.s_[:]
+
+         if isinstance(self.slice_solution, int):
+             # rewrite it as a slice to ensure that axis does not disappear when
+             # indexing
+             self.slice_solution = jnp.s_[self.slice_solution : self.slice_solution + 1]
+
+     def eval(self, network, inputs, *args, **kwargs):
+         """How to call your Equinox module `network`. The purpose of this
+         method is to give more flexibility: users should re-implement `eval`
+         when inheriting from `PINN` if they desire more flexibility on how to
+         evaluate the network.
+
+         Defaults to using `network.__call__(inputs)` but it could be more
+         refined, *e.g.* `network.anymethod(inputs)`.
+
+         Parameters
+         ----------
+         network : eqx.Module
+             Your neural network with the parameters set, usually returned by
+             `eqx.combine(self.static, current_params)`.
+         inputs : Array
+             The inputs, possibly transformed by `self.input_transform` if
+             specified by the user.
+
+         Returns
+         -------
+         Array
+             The output
+         """
+
+         return network(inputs)
+
+     @overload
+     @_PyTree_to_Params
+     def __call__(
+         self,
+         inputs: Float[Array, " input_dim"],
+         params: PyTree,
+         *args,
+         **kwargs,
+     ) -> Float[Array, " output_dim"]: ...
+
+     @_PyTree_to_Params
+     def __call__(
+         self,
+         inputs: Float[Array, " input_dim"],
+         params: Params[Array],
+         *args,
+         **kwargs,
+     ) -> Float[Array, " output_dim"]:
+         """
+         A proper __call__ implementation performs an eqx.combine here with
+         `params` and `self.static` to recreate the callable eqx.Module
+         architecture. The rest of the content of this function is dependent
+         on the network.
+
+         Note that, thanks to the decorator, params can also directly be the
+         PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
+         """
+
+         if len(inputs.shape) == 0:
+             # This can happen often when the user directly provides some
+             # collocation points (e.g. for plotting, without using
+             # DataGenerators)
+             inputs = inputs[None]
+
+         model = eqx.combine(params.nn_params, self.static)
+
+         # evaluate the model
+         res = self.eval(model, self.input_transform(inputs, params), *args, **kwargs)
+
+         res = self.output_transform(inputs, res.squeeze(), params)
+
+         # force (1,) output for non vectorial solution (consistency)
+         if not res.shape:
+             return jnp.expand_dims(res, axis=-1)
+         return res
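The `PINN` class above is essentially a wrapper around a user-supplied `eqx.Module` split by `eqx.partition`. Below is a minimal usage sketch, not part of the diff: it assumes the keyword-only constructor implied by the field definitions above and builds the `Params` container the same way `_save_load.py` does further down.

import jax
import jax.numpy as jnp
import equinox as eqx

from jinns.nn._pinn import PINN
from jinns.parameters._params import Params

key = jax.random.PRNGKey(0)
# a 2D stationary problem: x in R^2, scalar output
mlp = eqx.nn.MLP(in_size=2, out_size=1, width_size=32, depth=3, key=key)

u = PINN(eqx_network=mlp, eq_type="statio_PDE")

# trainable leaves go to Params.nn_params, equation parameters to eq_params
params = Params(nn_params=u.init_params, eq_params={})

x = jnp.array([0.5, 0.5])
print(u(x, params).shape)  # (1,): scalar outputs are forced to shape (1,)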
jinns/nn/_ppinn.py ADDED
@@ -0,0 +1,239 @@
+ """
+ Implements utility functions to create PINNs
+ """
+
+ from __future__ import annotations
+
+ from typing import Callable, Literal, Self, cast, overload
+ from dataclasses import InitVar
+ import jax
+ import jax.numpy as jnp
+ import equinox as eqx
+
+ from jaxtyping import Array, Key, Float, PyTree
+
+ from jinns.parameters._params import Params
+ from jinns.nn._pinn import PINN
+ from jinns.nn._mlp import MLP
+ from jinns.nn._utils import _PyTree_to_Params
+
+
+ class PPINN_MLP(PINN):
+     r"""
+     A PPINN MLP (Parallel PINN with MLPs) object which mimics the PFNN
+     architecture from DeepXDE. This is in fact a PINN MLP that encompasses
+     several PINN MLPs internally.
+
+     Parameters
+     ----------
+     slice_solution : slice
+         A jnp.s\_ object which indicates which axis of the PPINN output is
+         dedicated to the actual equation solution. Default None means that
+         slice_solution = the whole PPINN output. This argument is useful when
+         the PINN is also used to output equation parameters, for example.
+         Note that it must be a slice and not an integer (a preprocessing of
+         the user provided argument takes care of it).
+     eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
+         A string with three possibilities.
+         "ODE": the PPINN is called with one input `t`.
+         "statio_PDE": the PPINN is called with one input `x`, `x`
+         can be high dimensional.
+         "nonstatio_PDE": the PPINN is called with two inputs `t` and `x`, `x`
+         can be high dimensional.
+         **Note**: the input dimension as given in eqx_list has to match the
+         sum of the dimension of `t` + the dimension of `x` or the output
+         dimension after the `input_transform` function.
+     input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
+         A function that will be called before entering the PPINN. Its
+         output(s) must match the PPINN inputs (except for the parameters).
+         Its inputs are the PPINN inputs (`t` and/or `x` concatenated
+         together) and the parameters. Default is no operation.
+     output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
+         A function whose arguments are the same inputs as the PPINN, the
+         PPINN output and the parameters. This function will be called after
+         exiting the PPINN.
+         Default is no operation.
+     filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
+         Default is `eqx.is_inexact_array`. This tells Jinns what to consider
+         as a trainable parameter. Quoting from the equinox documentation: a
+         PyTree whose structure should be a prefix of the structure of pytree.
+         Each of its leaves should either be 1) True, in which case the leaf
+         or subtree is kept; 2) False, in which case the leaf or subtree is
+         replaced with replace; 3) a callable Leaf -> bool, in which case this
+         is evaluated on the leaf or mapped over the subtree, and the leaf
+         kept or replaced as appropriate.
+     eqx_network_list
+         A list of eqx.nn.MLP objects with the same input dimensions. They
+         represent the parallel subnetworks of the PPINN MLP. Their respective
+         outputs are concatenated.
+     """
+
+     eqx_network_list: InitVar[list[eqx.Module]] = eqx.field(kw_only=True)
+     init_params: tuple[PINN, ...] = eqx.field(
+         init=False
+     )  # overriding parent attribute type
+     static: tuple[PINN, ...] = eqx.field(
+         init=False, static=True
+     )  # overriding parent attribute type
+
+     def __post_init__(self, eqx_network, eqx_network_list):
+         super().__post_init__(
+             eqx_network=eqx_network_list[0],  # this is not used since it is
+             # overwritten just below
+         )
+         params, static = eqx.partition(eqx_network_list[0], self.filter_spec)
+         self.init_params, self.static = (params,), (static,)
+         for eqx_network_ in eqx_network_list[1:]:
+             params, static = eqx.partition(eqx_network_, self.filter_spec)
+             self.init_params = self.init_params + (params,)
+             self.static = self.static + (static,)
+
+     @overload
+     @_PyTree_to_Params
+     def __call__(
+         self,
+         inputs: Float[Array, " input_dim"],
+         params: PyTree,
+         *args,
+         **kwargs,
+     ) -> Float[Array, " output_dim"]: ...
+
+     @_PyTree_to_Params
+     def __call__(
+         self,
+         inputs: Float[Array, " 1"] | Float[Array, " dim"] | Float[Array, " 1+dim"],
+         params: Params[Array],
+     ) -> Float[Array, " output_dim"]:
+         """
+         Evaluate the PPINN on some inputs with some params.
+
+         Note that, thanks to the decorator, params can also directly be the
+         PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
+         """
+         if len(inputs.shape) == 0:
+             # This can happen often when the user directly provides some
+             # collocation points (e.g. for plotting, without using
+             # DataGenerators)
+             inputs = inputs[None]
+         transformed_inputs = self.input_transform(inputs, params)
+
+         outs = []
+
+         # try:
+         for params_, static in zip(params.nn_params, self.static):
+             model = eqx.combine(params_, static)
+             outs += [model(transformed_inputs)]  # type: ignore
+         # except (KeyError, AttributeError, TypeError) as e:
+         #     for params_, static in zip(params, self.static):
+         #         model = eqx.combine(params_, static)
+         #         outs += [model(transformed_inputs)]
+         # Note that below is then a global output transform
+         res = self.output_transform(inputs, jnp.concatenate(outs, axis=0), params)
+
+         ## force (1,) output for non vectorial solution (consistency)
+         if not res.shape:
+             return jnp.expand_dims(res, axis=-1)
+         return res
+
+     @classmethod
+     def create(
+         cls,
+         eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+         eqx_network_list: list[eqx.nn.MLP | MLP] | None = None,
+         key: Key = None,
+         eqx_list_list: (
+             list[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] | None
+         ) = None,
+         input_transform: (
+             Callable[
+                 [Float[Array, " input_dim"], Params[Array]],
+                 Float[Array, " output_dim"],
+             ]
+             | None
+         ) = None,
+         output_transform: (
+             Callable[
+                 [
+                     Float[Array, " input_dim"],
+                     Float[Array, " output_dim"],
+                     Params[Array],
+                 ],
+                 Float[Array, " output_dim"],
+             ]
+             | None
+         ) = None,
+         slice_solution: slice | None = None,
+     ) -> tuple[Self, tuple[PINN, ...]]:
+         r"""
+         Utility function to create a Parallel PINN neural network for Jinns.
+
+         Parameters
+         ----------
+         eq_type
+             A string with three possibilities.
+             "ODE": the PPINN MLP is called with one input `t`.
+             "statio_PDE": the PPINN MLP is called with one input `x`, `x`
+             can be high dimensional.
+             "nonstatio_PDE": the PPINN MLP is called with two inputs `t` and
+             `x`, `x` can be high dimensional.
+             **Note**: the input dimension as given in eqx_list has to match
+             the sum of the dimension of `t` + the dimension of `x` or the
+             output dimension after the `input_transform` function.
+         eqx_network_list
+             Default is None. A list of eqx.nn.MLP objects with the same input
+             dimensions. They represent the parallel subnetworks of the PPINN
+             MLP. Their respective outputs are concatenated.
+         key
+             Default is None. Must be provided with `eqx_list_list` if
+             `eqx_network_list` is not provided. A JAX random key that will be
+             used to initialize the networks parameters.
+         eqx_list_list
+             Default is None. Must be provided if `eqx_network_list` is not
+             provided. A list of `eqx_list` (see `PINN_MLP.create()`). The
+             input dimension must be the same for each sub-`eqx_list`. Then
+             the parallel subnetworks can be different. Their respective
+             outputs are concatenated.
+         input_transform
+             A function that will be called before entering the PPINN MLP. Its
+             output(s) must match the PPINN MLP inputs (except for the
+             parameters). Its inputs are the PPINN MLP inputs (`t` and/or `x`
+             concatenated together) and the parameters. Default is no
+             operation.
+         output_transform
+             This function will be called after exiting the PPINN MLP, i.e.,
+             on the concatenated outputs of all parallel networks.
+             Default is no operation.
+         slice_solution
+             A jnp.s\_ object which indicates which axis of the PPINN MLP
+             output is dedicated to the actual equation solution. Default None
+             means that slice_solution = the whole PPINN MLP output. This
+             argument is useful when the PPINN MLP is also used to output
+             equation parameters, for example. Note that it must be a slice
+             and not an integer (a preprocessing of the user provided argument
+             takes care of it).
+
+
+         Returns
+         -------
+         ppinn
+             A PPINN MLP instance
+         ppinn.init_params
+             An initial set of parameters for the PPINN MLP
+
+         """
+
+         if eqx_network_list is None:
+             if eqx_list_list is None or key is None:
+                 raise ValueError(
+                     "If eqx_network_list is None, then key and eqx_list_list"
+                     " must be provided"
+                 )
+
+             eqx_network_list = []
+             for eqx_list in eqx_list_list:
+                 key, subkey = jax.random.split(key, 2)
+                 eqx_network_list.append(MLP(key=subkey, eqx_list=eqx_list))
+
+         ppinn = cls(
+             eqx_network=None,  # type: ignore
+             eqx_network_list=cast(list[eqx.Module], eqx_network_list),
+             slice_solution=slice_solution,  # type: ignore
+             eq_type=eq_type,
+             input_transform=input_transform,  # type: ignore
+             output_transform=output_transform,  # type: ignore
+         )
+         return ppinn, ppinn.init_params
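A usage sketch for `PPINN_MLP` (not part of the diff), using the `eqx_network_list` branch of `create` above: two parallel branches with the same input dimension whose outputs are concatenated.

import jax
import jax.numpy as jnp
import equinox as eqx

from jinns.nn._ppinn import PPINN_MLP
from jinns.parameters._params import Params

k1, k2 = jax.random.split(jax.random.PRNGKey(0))

# two subnetworks sharing the same input dimension (t concatenated with a 1D x)
branch1 = eqx.nn.MLP(in_size=2, out_size=1, width_size=16, depth=2, key=k1)
branch2 = eqx.nn.MLP(in_size=2, out_size=1, width_size=16, depth=2, key=k2)

ppinn, init_params = PPINN_MLP.create(
    eq_type="nonstatio_PDE",
    eqx_network_list=[branch1, branch2],
)

params = Params(nn_params=init_params, eq_params={})
t_x = jnp.array([0.1, 0.5])
print(ppinn(t_x, params).shape)  # (2,): one concatenated output per branch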
jinns/{utils → nn}/_save_load.py RENAMED
@@ -7,14 +7,16 @@ import pickle
  import jax
  import equinox as eqx

- from jinns.utils._pinn import create_PINN, PINN
- from jinns.utils._spinn import create_SPINN, SPINN
- from jinns.utils._hyperpinn import create_HYPERPINN, HYPERPINN
- from jinns.parameters._params import Params, ParamsDict
+ from jinns.nn._pinn import PINN
+ from jinns.nn._spinn import SPINN
+ from jinns.nn._mlp import PINN_MLP
+ from jinns.nn._spinn_mlp import SPINN_MLP
+ from jinns.nn._hyperpinn import HyperPINN
+ from jinns.parameters._params import Params


  def function_to_string(
-     eqx_list: tuple[tuple[Callable, int, int] | Callable, ...]
+     eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
  ) -> tuple[tuple[str, int, int] | str, ...]:
      """
      We need this transformation for eqx_list to be pickled
@@ -40,7 +42,7 @@ def function_to_string(


  def string_to_function(
-     eqx_list_with_string: tuple[tuple[str, int, int] | str, ...]
+     eqx_list_with_string: tuple[tuple[str, int, int] | str, ...],
  ) -> tuple[tuple[Callable, int, int] | Callable, ...]:
      """
      We need this transformation for eqx_list at the loading ("unpickling")
@@ -84,8 +86,8 @@ def string_to_function(

  def save_pinn(
      filename: str,
-     u: PINN | HYPERPINN | SPINN,
-     params: Params | ParamsDict,
+     u: PINN | HyperPINN | SPINN,
+     params: Params,
      kwargs_creation,
  ):
      """
@@ -103,15 +105,12 @@ def save_pinn(
      tree_serialise_leaves`).

      Equation parameters are saved apart because the initial type of attribute
-     `params` in PINN / HYPERPINN / SPINN is not `Params` nor `ParamsDict`
+     `params` in PINN / HyperPINN / SPINN is not `Params`
      but `PyTree` as inherited from `eqx.partition`.
      Therefore, if we want to ensure a proper serialization/deserialization:
      - we cannot save a `Params` object at this
      attribute field ; the `Params` object must be split into `Params.nn_params`
      (type `PyTree`) and `Params.eq_params` (type `dict`).
-     - in the case of a `ParamsDict` we cannot save `ParamsDict.nn_params` at
-     the attribute field `params` because it is not a `PyTree` (as expected in
-     the PINN / HYPERPINN / SPINN signature) but it is still a dictionary.

      Parameters
      ----------
@@ -120,28 +119,16 @@ def save_pinn(
      u
          The PINN
      params
-         Params or ParamsDict to be save
+         Params to be saved
      kwargs_creation
          The dictionary of arguments that were used to create the PINN, e.g.
          the layers list, O/PDE type, etc.
      """
-     if isinstance(params, Params):
-         if isinstance(u, HYPERPINN):
-             u = eqx.tree_at(lambda m: m.params_hyper, u, params)
-         elif isinstance(u, (PINN, SPINN)):
-             u = eqx.tree_at(lambda m: m.params, u, params)
-         eqx.tree_serialise_leaves(filename + "-module.eqx", u)
-
-     elif isinstance(params, ParamsDict):
-         for key, params_ in params.nn_params.items():
-             if isinstance(u, HYPERPINN):
-                 u = eqx.tree_at(lambda m: m.params_hyper, u, params_)
-             elif isinstance(u, (PINN, SPINN)):
-                 u = eqx.tree_at(lambda m: m.params, u, params_)
-             eqx.tree_serialise_leaves(filename + f"-module_{key}.eqx", u)
-
-     else:
-         raise ValueError("The parameters to be saved must be a Params or a ParamsDict")
+     if isinstance(u, HyperPINN):
+         u = eqx.tree_at(lambda m: m.init_params_hyper, u, params)
+     elif isinstance(u, (PINN, SPINN)):
+         u = eqx.tree_at(lambda m: m.init_params, u, params)
+     eqx.tree_serialise_leaves(filename + "-module.eqx", u)

      with open(filename + "-eq_params.pkl", "wb") as f:
          pickle.dump(params.eq_params, f)
@@ -167,9 +154,8 @@ def save_pinn(

  def load_pinn(
      filename: str,
-     type_: Literal["pinn", "hyperpinn", "spinn"],
-     key_list_for_paramsdict: list[str] = None,
- ) -> tuple[eqx.Module, Params | ParamsDict]:
+     type_: Literal["pinn_mlp", "hyperpinn", "spinn_mlp"],
+ ) -> tuple[eqx.Module, Params]:
      """
      Load a PINN model. This function needs to access 3 files :
      `{filename}-module.eqx`, `{filename}-parameters.pkl` and
@@ -187,9 +173,7 @@ def load_pinn(
      filename
          Filename (prefix) without extension.
      type_
-         Type of model to load. Must be in ["pinn", "hyperpinn", "spinn"].
-     key_list_for_paramsdict
-         Pass the name of the keys of the dictionnary `ParamsDict.nn_params`. Default is None. In this case, we expect to retrieve a ParamsDict.
+         Type of model to load. Must be in ["pinn_mlp", "hyperpinn", "spinn"].

      Returns
      -------
@@ -207,34 +191,36 @@ def load_pinn(
          eq_params_reloaded = {}
          print("No pickle file for equation parameters found!")
      kwargs_reloaded["eqx_list"] = string_to_function(kwargs_reloaded["eqx_list"])
-     if type_ == "pinn":
+     if type_ == "pinn_mlp":
          # next line creates a shallow model, the jax arrays are just shapes and
          # not populated, this just recreates the correct pytree structure
-         u_reloaded_shallow, _ = eqx.filter_eval_shape(create_PINN, **kwargs_reloaded)
-     elif type_ == "spinn":
-         u_reloaded_shallow, _ = eqx.filter_eval_shape(create_SPINN, **kwargs_reloaded)
+         u_reloaded_shallow, _ = eqx.filter_eval_shape(
+             PINN_MLP.create, **kwargs_reloaded
+         )
+     elif type_ == "spinn_mlp":
+         u_reloaded_shallow, _ = eqx.filter_eval_shape(
+             SPINN_MLP.create, **kwargs_reloaded
+         )
      elif type_ == "hyperpinn":
          kwargs_reloaded["eqx_list_hyper"] = string_to_function(
              kwargs_reloaded["eqx_list_hyper"]
          )
          u_reloaded_shallow, _ = eqx.filter_eval_shape(
-             create_HYPERPINN, **kwargs_reloaded
+             HyperPINN.create, **kwargs_reloaded
          )
      else:
          raise ValueError(f"{type_} is not valid")
-     if key_list_for_paramsdict is None:
-         # now the empty structure is populated with the actual saved array values
-         # stored in the eqx file
-         u_reloaded = eqx.tree_deserialise_leaves(
-             filename + "-module.eqx", u_reloaded_shallow
+     # now the empty structure is populated with the actual saved array values
+     # stored in the eqx file
+     u_reloaded = eqx.tree_deserialise_leaves(
+         filename + "-module.eqx", u_reloaded_shallow
+     )
+     if isinstance(u_reloaded, HyperPINN):
+         params = Params(
+             nn_params=u_reloaded.init_params_hyper, eq_params=eq_params_reloaded
          )
+     elif isinstance(u_reloaded, (PINN, SPINN)):
          params = Params(nn_params=u_reloaded.init_params, eq_params=eq_params_reloaded)
      else:
-         nn_params_dict = {}
-         for key in key_list_for_paramsdict:
-             u_reloaded = eqx.tree_deserialise_leaves(
-                 filename + f"-module_{key}.eqx", u_reloaded_shallow
-             )
-             nn_params_dict[key] = u_reloaded.init_params
-         params = ParamsDict(nn_params=nn_params_dict, eq_params=eq_params_reloaded)
+         raise ValueError("Wrong type for u_reloaded")
      return u_reloaded, params
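The reworked `save_pinn` / `load_pinn` pair drops the `ParamsDict` code path and renames the `type_` values to "pinn_mlp" / "spinn_mlp" / "hyperpinn". A hedged round-trip sketch follows; the exact signature of `PINN_MLP.create` is defined in `jinns/nn/_mlp.py`, which this diff does not reproduce, so the `key` / `eq_type` / `eqx_list` keywords and the tuple format of `eqx_list` below are assumptions modelled on `PPINN_MLP.create` above.

# Assumed: PINN_MLP.create(key=..., eq_type=..., eqx_list=...) with eqx_list as
# (layer, in_size, out_size) / (activation,) tuples; not shown in this diff.
import jax
import jax.numpy as jnp
import equinox as eqx

from jinns.nn._mlp import PINN_MLP
from jinns.nn._save_load import save_pinn, load_pinn
from jinns.parameters._params import Params

kwargs_creation = {
    "key": jax.random.PRNGKey(0),
    "eq_type": "ODE",
    "eqx_list": (
        (eqx.nn.Linear, 1, 32),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 32, 1),
    ),
}
u, init_params = PINN_MLP.create(**kwargs_creation)
params = Params(nn_params=init_params, eq_params={"theta": jnp.array(1.0)})

# save_pinn serializes the module, pickles the equation parameters and the
# creation kwargs (eqx_list functions are stringified via function_to_string)
save_pinn("my_ode_pinn", u, params, kwargs_creation)

# note the new type_ values; load_pinn returns the module and a Params object
u_reloaded, params_reloaded = load_pinn("my_ode_pinn", type_="pinn_mlp")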