jinns 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +146 -520
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_utils.py +78 -159
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -74
  22. jinns/nn/__init__.py +15 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +94 -57
  25. jinns/nn/_mlp.py +50 -25
  26. jinns/nn/_pinn.py +33 -19
  27. jinns/nn/_ppinn.py +70 -34
  28. jinns/nn/_save_load.py +21 -51
  29. jinns/nn/_spinn.py +33 -16
  30. jinns/nn/_spinn_mlp.py +28 -22
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +35 -34
  37. jinns/solver/_rar.py +80 -63
  38. jinns/solver/_solve.py +89 -63
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -0
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns-1.3.0.dist-info/RECORD +0 -44
  51. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  52. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  53. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/nn/_mlp.py CHANGED
@@ -2,16 +2,30 @@
2
2
  Implements utility function to create PINNs
3
3
  """
4
4
 
5
- from typing import Callable, Literal, Self, Union, Any
5
+ from __future__ import annotations
6
+
7
+ from typing import Callable, Literal, Self, Union, Any, TYPE_CHECKING, cast
6
8
  from dataclasses import InitVar
7
9
  import jax
8
10
  import equinox as eqx
9
-
11
+ from typing import Protocol
10
12
  from jaxtyping import Array, Key, PyTree, Float
11
13
 
12
14
  from jinns.parameters._params import Params
13
15
  from jinns.nn._pinn import PINN
14
16
 
17
+ if TYPE_CHECKING:
18
+
19
+ class CallableMLPModule(Protocol):
20
+ """
21
+ Basically just a way to add a __call__ to an eqx.Module.
22
+ https://github.com/patrick-kidger/equinox/issues/1002
23
+ We chose the strutural subtyping of protocols instead of subclassing an
24
+ eqx.Module just to add a __call__ here
25
+ """
26
+
27
+ def __call__(self, *_, **__) -> Array: ...
28
+
15
29
 
16
30
  class MLP(eqx.Module):
17
31
  """
@@ -21,7 +35,7 @@ class MLP(eqx.Module):
21
35
  ----------
22
36
  key : InitVar[Key]
23
37
  A jax random key for the layer initializations.
24
- eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
38
+ eqx_list : InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]]
25
39
  A tuple of tuples of successive equinox modules and activation functions to
26
40
  describe the PINN architecture. The inner tuples must have the eqx module or
27
41
  activation function as first item, other items represents arguments
@@ -29,23 +43,23 @@ class MLP(eqx.Module):
29
43
  The `key` argument need not be given.
30
44
  Thus typical example is `eqx_list=
31
45
  ((eqx.nn.Linear, 2, 20),
32
- jax.nn.tanh,
46
+ (jax.nn.tanh,),
33
47
  (eqx.nn.Linear, 20, 20),
34
- jax.nn.tanh,
48
+ (jax.nn.tanh,),
35
49
  (eqx.nn.Linear, 20, 20),
36
- jax.nn.tanh,
50
+ (jax.nn.tanh,),
37
51
  (eqx.nn.Linear, 20, 1)
38
52
  )`.
39
53
  """
40
54
 
41
55
  key: InitVar[Key] = eqx.field(kw_only=True)
42
- eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] = eqx.field(
43
- kw_only=True
56
+ eqx_list: InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] = (
57
+ eqx.field(kw_only=True)
44
58
  )
45
59
 
46
60
  # NOTE that the following should NOT be declared as static otherwise the
47
61
  # eqx.partition that we use in the PINN module will misbehave
48
- layers: list[eqx.Module] = eqx.field(init=False)
62
+ layers: list[CallableMLPModule | Callable[[Array], Array]] = eqx.field(init=False)
49
63
 
50
64
  def __post_init__(self, key, eqx_list):
51
65
  self.layers = []
@@ -63,7 +77,7 @@ class MLP(eqx.Module):
63
77
  self.layers.append(l[0](*l[1:], key=subkey))
64
78
  k += 1
65
79
 
66
- def __call__(self, t: Float[Array, "input_dim"]) -> Float[Array, "output_dim"]:
80
+ def __call__(self, t: Float[Array, " input_dim"]) -> Float[Array, " output_dim"]:
67
81
  for layer in self.layers:
68
82
  t = layer(t)
69
83
  return t
@@ -81,19 +95,30 @@ class PINN_MLP(PINN):
81
95
  def create(
82
96
  cls,
83
97
  eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
84
- eqx_network: eqx.nn.MLP = None,
98
+ eqx_network: eqx.nn.MLP | MLP | None = None,
85
99
  key: Key = None,
86
- eqx_list: tuple[tuple[Callable, int, int] | Callable, ...] = None,
87
- input_transform: Callable[
88
- [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
89
- ] = None,
90
- output_transform: Callable[
91
- [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
92
- Float[Array, "output_dim"],
93
- ] = None,
94
- slice_solution: slice = None,
100
+ eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
101
+ input_transform: (
102
+ Callable[
103
+ [Float[Array, " input_dim"], Params[Array]],
104
+ Float[Array, " output_dim"],
105
+ ]
106
+ | None
107
+ ) = None,
108
+ output_transform: (
109
+ Callable[
110
+ [
111
+ Float[Array, " input_dim"],
112
+ Float[Array, " output_dim"],
113
+ Params[Array],
114
+ ],
115
+ Float[Array, " output_dim"],
116
+ ]
117
+ | None
118
+ ) = None,
119
+ slice_solution: slice | None = None,
95
120
  filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
96
- ) -> tuple[Self, PyTree]:
121
+ ) -> tuple[Self, PINN]:
97
122
  r"""
98
123
  Instanciate standard PINN MLP object. The actual NN is either passed as
99
124
  a eqx.nn.MLP (`eqx_network` argument) or constructed as a custom
@@ -179,14 +204,14 @@ class PINN_MLP(PINN):
179
204
  raise ValueError(
180
205
  "If eqx_network is None, then key and eqx_list must be provided"
181
206
  )
182
- eqx_network = MLP(key=key, eqx_list=eqx_list)
207
+ eqx_network = cast(MLP, MLP(key=key, eqx_list=eqx_list))
183
208
 
184
209
  mlp = cls(
185
210
  eqx_network=eqx_network,
186
- slice_solution=slice_solution,
211
+ slice_solution=slice_solution, # type: ignore
187
212
  eq_type=eq_type,
188
- input_transform=input_transform,
189
- output_transform=output_transform,
213
+ input_transform=input_transform, # type: ignore
214
+ output_transform=output_transform, # type: ignore
190
215
  filter_spec=filter_spec,
191
216
  )
192
217
  return mlp, mlp.init_params
jinns/nn/_pinn.py CHANGED
@@ -2,15 +2,19 @@
2
2
  Implement abstract class for PINN architectures
3
3
  """
4
4
 
5
- from typing import Literal, Callable, Union, Any
5
+ from __future__ import annotations
6
+
7
+ from typing import Callable, Union, Any, Literal, overload
6
8
  from dataclasses import InitVar
7
9
  import equinox as eqx
8
10
  from jaxtyping import Float, Array, PyTree
9
11
  import jax.numpy as jnp
10
- from jinns.parameters._params import Params, ParamsDict
12
+ from jinns.parameters._params import Params
13
+ from jinns.nn._abstract_pinn import AbstractPINN
14
+ from jinns.nn._utils import _PyTree_to_Params
11
15
 
12
16
 
13
- class PINN(eqx.Module):
17
+ class PINN(AbstractPINN):
14
18
  r"""
15
19
  Base class for PINN objects. It can be seen as a wrapper on
16
20
  an `eqx.Module` which actually implement the NN architectures, with extra
@@ -57,12 +61,12 @@ class PINN(eqx.Module):
57
61
  **Note**: the input dimension as given in eqx_list has to match the sum
58
62
  of the dimension of `t` + the dimension of `x` or the output dimension
59
63
  after the `input_transform` function.
60
- input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
64
+ input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
61
65
  A function that will be called before entering the PINN. Its output(s)
62
66
  must match the PINN inputs (except for the parameters).
63
67
  Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
64
68
  and the parameters. Default is no operation.
65
- output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
69
+ output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
66
70
  A function with arguments begin the same input as the PINN, the PINN
67
71
  output and the parameter. This function will be called after exiting the PINN.
68
72
  Default is no operation.
@@ -84,16 +88,16 @@ class PINN(eqx.Module):
84
88
  "nonstatio_PDE"]`
85
89
  """
86
90
 
87
- slice_solution: slice = eqx.field(static=True, kw_only=True, default=None)
88
91
  eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
89
92
  static=True, kw_only=True
90
93
  )
94
+ slice_solution: slice = eqx.field(static=True, kw_only=True, default=None)
91
95
  input_transform: Callable[
92
- [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
96
+ [Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]
93
97
  ] = eqx.field(static=True, kw_only=True, default=None)
94
98
  output_transform: Callable[
95
- [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
96
- Float[Array, "output_dim"],
99
+ [Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]],
100
+ Float[Array, " output_dim"],
97
101
  ] = eqx.field(static=True, kw_only=True, default=None)
98
102
 
99
103
  eqx_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
@@ -101,11 +105,10 @@ class PINN(eqx.Module):
101
105
  static=True, kw_only=True, default=eqx.is_inexact_array
102
106
  )
103
107
 
104
- init_params: PyTree = eqx.field(init=False)
105
- static: PyTree = eqx.field(init=False, static=True)
108
+ init_params: PINN = eqx.field(init=False)
109
+ static: PINN = eqx.field(init=False, static=True)
106
110
 
107
111
  def __post_init__(self, eqx_network):
108
-
109
112
  if self.eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
110
113
  raise RuntimeError("Wrong parameter value for eq_type")
111
114
  # saving the static part of the model and initial parameters
@@ -154,18 +157,32 @@ class PINN(eqx.Module):
154
157
 
155
158
  return network(inputs)
156
159
 
160
+ @overload
161
+ @_PyTree_to_Params
162
+ def __call__(
163
+ self,
164
+ inputs: Float[Array, " input_dim"],
165
+ params: PyTree,
166
+ *args,
167
+ **kwargs,
168
+ ) -> Float[Array, " output_dim"]: ...
169
+
170
+ @_PyTree_to_Params
157
171
  def __call__(
158
172
  self,
159
- inputs: Float[Array, "input_dim"],
160
- params: Params | ParamsDict | PyTree,
173
+ inputs: Float[Array, " input_dim"],
174
+ params: Params[Array],
161
175
  *args,
162
176
  **kwargs,
163
- ) -> Float[Array, "output_dim"]:
177
+ ) -> Float[Array, " output_dim"]:
164
178
  """
165
179
  A proper __call__ implementation performs an eqx.combine here with
166
180
  `params` and `self.static` to recreate the callable eqx.Module
167
181
  architecture. The rest of the content of this function is dependent on
168
182
  the network.
183
+
184
+ Note that that thanks to the decorator, params can also directly be the
185
+ PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
169
186
  """
170
187
 
171
188
  if len(inputs.shape) == 0:
@@ -174,10 +191,7 @@ class PINN(eqx.Module):
174
191
  # DataGenerators)
175
192
  inputs = inputs[None]
176
193
 
177
- try:
178
- model = eqx.combine(params.nn_params, self.static)
179
- except (KeyError, AttributeError, TypeError) as e: # give more flexibility
180
- model = eqx.combine(params, self.static)
194
+ model = eqx.combine(params.nn_params, self.static)
181
195
 
182
196
  # evaluate the model
183
197
  res = self.eval(model, self.input_transform(inputs, params), *args, **kwargs)
jinns/nn/_ppinn.py CHANGED
@@ -2,17 +2,20 @@
2
2
  Implements utility function to create PINNs
3
3
  """
4
4
 
5
- from typing import Callable, Literal, Self
5
+ from __future__ import annotations
6
+
7
+ from typing import Callable, Literal, Self, cast, overload
6
8
  from dataclasses import InitVar
7
9
  import jax
8
10
  import jax.numpy as jnp
9
11
  import equinox as eqx
10
12
 
11
- from jaxtyping import Array, Key, PyTree, Float
13
+ from jaxtyping import Array, Key, Float, PyTree
12
14
 
13
- from jinns.parameters._params import Params, ParamsDict
15
+ from jinns.parameters._params import Params
14
16
  from jinns.nn._pinn import PINN
15
17
  from jinns.nn._mlp import MLP
18
+ from jinns.nn._utils import _PyTree_to_Params
16
19
 
17
20
 
18
21
  class PPINN_MLP(PINN):
@@ -39,12 +42,12 @@ class PPINN_MLP(PINN):
39
42
  **Note**: the input dimension as given in eqx_list has to match the sum
40
43
  of the dimension of `t` + the dimension of `x` or the output dimension
41
44
  after the `input_transform` function.
42
- input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
45
+ input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
43
46
  A function that will be called before entering the PPINN. Its output(s)
44
47
  must match the PPINN inputs (except for the parameters).
45
48
  Its inputs are the PPINN inputs (`t` and/or `x` concatenated together)
46
49
  and the parameters. Default is no operation.
47
- output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
50
+ output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
48
51
  A function with arguments begin the same input as the PPINN, the PPINN
49
52
  output and the parameter. This function will be called after exiting
50
53
  the PPINN.
@@ -63,25 +66,46 @@ class PPINN_MLP(PINN):
63
66
  """
64
67
 
65
68
  eqx_network_list: InitVar[list[eqx.Module]] = eqx.field(kw_only=True)
69
+ init_params: tuple[PINN, ...] = eqx.field(
70
+ init=False
71
+ ) # overriding parent attribute type
72
+ static: tuple[PINN, ...] = eqx.field(
73
+ init=False, static=True
74
+ ) # overriding parent attribute type
66
75
 
67
76
  def __post_init__(self, eqx_network, eqx_network_list):
68
77
  super().__post_init__(
69
78
  eqx_network=eqx_network_list[0], # this is not used since it is
70
79
  # overwritten just below
71
80
  )
72
- self.init_params, self.static = (), ()
73
- for eqx_network_ in eqx_network_list:
81
+ params, static = eqx.partition(eqx_network_list[0], self.filter_spec)
82
+ self.init_params, self.static = (params,), (static,)
83
+ for eqx_network_ in eqx_network_list[1:]:
74
84
  params, static = eqx.partition(eqx_network_, self.filter_spec)
75
85
  self.init_params = self.init_params + (params,)
76
86
  self.static = self.static + (static,)
77
87
 
88
+ @overload
89
+ @_PyTree_to_Params
78
90
  def __call__(
79
91
  self,
80
- inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
92
+ inputs: Float[Array, " input_dim"],
81
93
  params: PyTree,
82
- ) -> Float[Array, "output_dim"]:
94
+ *args,
95
+ **kwargs,
96
+ ) -> Float[Array, " output_dim"]: ...
97
+
98
+ @_PyTree_to_Params
99
+ def __call__(
100
+ self,
101
+ inputs: Float[Array, " 1"] | Float[Array, " dim"] | Float[Array, " 1+dim"],
102
+ params: Params[Array],
103
+ ) -> Float[Array, " output_dim"]:
83
104
  """
84
105
  Evaluate the PPINN on some inputs with some params.
106
+
107
+ Note that that thanks to the decorator, params can also directly be the
108
+ PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
85
109
  """
86
110
  if len(inputs.shape) == 0:
87
111
  # This can happen often when the user directly provides some
@@ -92,14 +116,14 @@ class PPINN_MLP(PINN):
92
116
 
93
117
  outs = []
94
118
 
95
- try:
96
- for params_, static in zip(params.nn_params, self.static):
97
- model = eqx.combine(params_, static)
98
- outs += [model(transformed_inputs)]
99
- except (KeyError, AttributeError, TypeError) as e:
100
- for params_, static in zip(params, self.static):
101
- model = eqx.combine(params_, static)
102
- outs += [model(transformed_inputs)]
119
+ # try:
120
+ for params_, static in zip(params.nn_params, self.static):
121
+ model = eqx.combine(params_, static)
122
+ outs += [model(transformed_inputs)] # type: ignore
123
+ # except (KeyError, AttributeError, TypeError) as e:
124
+ # for params_, static in zip(params, self.static):
125
+ # model = eqx.combine(params_, static)
126
+ # outs += [model(transformed_inputs)]
103
127
  # Note that below is then a global output transform
104
128
  res = self.output_transform(inputs, jnp.concatenate(outs, axis=0), params)
105
129
 
@@ -112,18 +136,31 @@ class PPINN_MLP(PINN):
112
136
  def create(
113
137
  cls,
114
138
  eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
115
- eqx_network_list: list[eqx.nn.MLP] = None,
139
+ eqx_network_list: list[eqx.nn.MLP | MLP] | None = None,
116
140
  key: Key = None,
117
- eqx_list_list: list[tuple[tuple[Callable, int, int] | Callable, ...]] = None,
118
- input_transform: Callable[
119
- [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
120
- ] = None,
121
- output_transform: Callable[
122
- [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
123
- Float[Array, "output_dim"],
124
- ] = None,
125
- slice_solution: slice = None,
126
- ) -> tuple[Self, PyTree]:
141
+ eqx_list_list: (
142
+ list[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] | None
143
+ ) = None,
144
+ input_transform: (
145
+ Callable[
146
+ [Float[Array, " input_dim"], Params[Array]],
147
+ Float[Array, " output_dim"],
148
+ ]
149
+ | None
150
+ ) = None,
151
+ output_transform: (
152
+ Callable[
153
+ [
154
+ Float[Array, " input_dim"],
155
+ Float[Array, " output_dim"],
156
+ Params[Array],
157
+ ],
158
+ Float[Array, " output_dim"],
159
+ ]
160
+ | None
161
+ ) = None,
162
+ slice_solution: slice | None = None,
163
+ ) -> tuple[Self, tuple[PINN, ...]]:
127
164
  r"""
128
165
  Utility function to create a Parrallel PINN neural network for Jinns.
129
166
 
@@ -189,15 +226,14 @@ class PPINN_MLP(PINN):
189
226
  eqx_network_list = []
190
227
  for eqx_list in eqx_list_list:
191
228
  key, subkey = jax.random.split(key, 2)
192
- print(subkey)
193
229
  eqx_network_list.append(MLP(key=subkey, eqx_list=eqx_list))
194
230
 
195
231
  ppinn = cls(
196
- eqx_network=None,
197
- eqx_network_list=eqx_network_list,
198
- slice_solution=slice_solution,
232
+ eqx_network=None, # type: ignore
233
+ eqx_network_list=cast(list[eqx.Module], eqx_network_list),
234
+ slice_solution=slice_solution, # type: ignore
199
235
  eq_type=eq_type,
200
- input_transform=input_transform,
201
- output_transform=output_transform,
236
+ input_transform=input_transform, # type: ignore
237
+ output_transform=output_transform, # type: ignore
202
238
  )
203
239
  return ppinn, ppinn.init_params
jinns/nn/_save_load.py CHANGED
@@ -12,7 +12,7 @@ from jinns.nn._spinn import SPINN
12
12
  from jinns.nn._mlp import PINN_MLP
13
13
  from jinns.nn._spinn_mlp import SPINN_MLP
14
14
  from jinns.nn._hyperpinn import HyperPINN
15
- from jinns.parameters._params import Params, ParamsDict
15
+ from jinns.parameters._params import Params
16
16
 
17
17
 
18
18
  def function_to_string(
@@ -87,7 +87,7 @@ def string_to_function(
87
87
  def save_pinn(
88
88
  filename: str,
89
89
  u: PINN | HyperPINN | SPINN,
90
- params: Params | ParamsDict,
90
+ params: Params,
91
91
  kwargs_creation,
92
92
  ):
93
93
  """
@@ -105,15 +105,12 @@ def save_pinn(
105
105
  tree_serialise_leaves`).
106
106
 
107
107
  Equation parameters are saved apart because the initial type of attribute
108
- `params` in PINN / HyperPINN / SPINN is not `Params` nor `ParamsDict`
108
+ `params` in PINN / HyperPINN / SPINN is not `Params`
109
109
  but `PyTree` as inherited from `eqx.partition`.
110
110
  Therefore, if we want to ensure a proper serialization/deserialization:
111
111
  - we cannot save a `Params` object at this
112
112
  attribute field ; the `Params` object must be split into `Params.nn_params`
113
113
  (type `PyTree`) and `Params.eq_params` (type `dict`).
114
- - in the case of a `ParamsDict` we cannot save `ParamsDict.nn_params` at
115
- the attribute field `params` because it is not a `PyTree` (as expected in
116
- the PINN / HyperPINN / SPINN signature) but it is still a dictionary.
117
114
 
118
115
  Parameters
119
116
  ----------
@@ -122,28 +119,16 @@ def save_pinn(
122
119
  u
123
120
  The PINN
124
121
  params
125
- Params or ParamsDict to be save
122
+ Params to be saved
126
123
  kwargs_creation
127
124
  The dictionary of arguments that were used to create the PINN, e.g.
128
125
  the layers list, O/PDE type, etc.
129
126
  """
130
- if isinstance(params, Params):
131
- if isinstance(u, HyperPINN):
132
- u = eqx.tree_at(lambda m: m.init_params_hyper, u, params)
133
- elif isinstance(u, (PINN, SPINN)):
134
- u = eqx.tree_at(lambda m: m.init_params, u, params)
135
- eqx.tree_serialise_leaves(filename + "-module.eqx", u)
136
-
137
- elif isinstance(params, ParamsDict):
138
- for key, params_ in params.nn_params.items():
139
- if isinstance(u, HyperPINN):
140
- u = eqx.tree_at(lambda m: m.init_params_hyper, u, params_)
141
- elif isinstance(u, (PINN, SPINN)):
142
- u = eqx.tree_at(lambda m: m.init_params, u, params_)
143
- eqx.tree_serialise_leaves(filename + f"-module_{key}.eqx", u)
144
-
145
- else:
146
- raise ValueError("The parameters to be saved must be a Params or a ParamsDict")
127
+ if isinstance(u, HyperPINN):
128
+ u = eqx.tree_at(lambda m: m.init_params_hyper, u, params)
129
+ elif isinstance(u, (PINN, SPINN)):
130
+ u = eqx.tree_at(lambda m: m.init_params, u, params)
131
+ eqx.tree_serialise_leaves(filename + "-module.eqx", u)
147
132
 
148
133
  with open(filename + "-eq_params.pkl", "wb") as f:
149
134
  pickle.dump(params.eq_params, f)
@@ -170,8 +155,7 @@ def save_pinn(
170
155
  def load_pinn(
171
156
  filename: str,
172
157
  type_: Literal["pinn_mlp", "hyperpinn", "spinn_mlp"],
173
- key_list_for_paramsdict: list[str] = None,
174
- ) -> tuple[eqx.Module, Params | ParamsDict]:
158
+ ) -> tuple[eqx.Module, Params]:
175
159
  """
176
160
  Load a PINN model. This function needs to access 3 files :
177
161
  `{filename}-module.eqx`, `{filename}-parameters.pkl` and
@@ -190,8 +174,6 @@ def load_pinn(
190
174
  Filename (prefix) without extension.
191
175
  type_
192
176
  Type of model to load. Must be in ["pinn_mlp", "hyperpinn", "spinn"].
193
- key_list_for_paramsdict
194
- Pass the name of the keys of the dictionnary `ParamsDict.nn_params`. Default is None. In this case, we expect to retrieve a ParamsDict.
195
177
 
196
178
  Returns
197
179
  -------
@@ -228,29 +210,17 @@ def load_pinn(
228
210
  )
229
211
  else:
230
212
  raise ValueError(f"{type_} is not valid")
231
- if key_list_for_paramsdict is None:
232
- # now the empty structure is populated with the actual saved array values
233
- # stored in the eqx file
234
- u_reloaded = eqx.tree_deserialise_leaves(
235
- filename + "-module.eqx", u_reloaded_shallow
213
+ # now the empty structure is populated with the actual saved array values
214
+ # stored in the eqx file
215
+ u_reloaded = eqx.tree_deserialise_leaves(
216
+ filename + "-module.eqx", u_reloaded_shallow
217
+ )
218
+ if isinstance(u_reloaded, HyperPINN):
219
+ params = Params(
220
+ nn_params=u_reloaded.init_params_hyper, eq_params=eq_params_reloaded
236
221
  )
237
- if isinstance(u_reloaded, HyperPINN):
238
- params = Params(
239
- nn_params=u_reloaded.init_params_hyper, eq_params=eq_params_reloaded
240
- )
241
- elif isinstance(u_reloaded, (PINN, SPINN)):
242
- params = Params(
243
- nn_params=u_reloaded.init_params, eq_params=eq_params_reloaded
244
- )
222
+ elif isinstance(u_reloaded, (PINN, SPINN)):
223
+ params = Params(nn_params=u_reloaded.init_params, eq_params=eq_params_reloaded)
245
224
  else:
246
- nn_params_dict = {}
247
- for key in key_list_for_paramsdict:
248
- u_reloaded = eqx.tree_deserialise_leaves(
249
- filename + f"-module_{key}.eqx", u_reloaded_shallow
250
- )
251
- if isinstance(u_reloaded, HyperPINN):
252
- nn_params_dict[key] = u_reloaded.init_params_hyper
253
- elif isinstance(u_reloaded, (PINN, SPINN)):
254
- nn_params_dict[key] = u_reloaded.init_params
255
- params = ParamsDict(nn_params=nn_params_dict, eq_params=eq_params_reloaded)
225
+ raise ValueError("Wrong type for u_reloaded")
256
226
  return u_reloaded, params
jinns/nn/_spinn.py CHANGED
@@ -1,14 +1,17 @@
1
- from typing import Union, Callable, Any
1
+ from __future__ import annotations
2
+ from typing import Union, Callable, Any, Literal, overload
2
3
  from dataclasses import InitVar
3
4
  from jaxtyping import PyTree, Float, Array
4
5
  import jax
5
6
  import jax.numpy as jnp
6
7
  import equinox as eqx
7
8
 
8
- from jinns.parameters._params import Params, ParamsDict
9
+ from jinns.parameters._params import Params
10
+ from jinns.nn._abstract_pinn import AbstractPINN
11
+ from jinns.nn._utils import _PyTree_to_Params
9
12
 
10
13
 
11
- class SPINN(eqx.Module):
14
+ class SPINN(AbstractPINN):
12
15
  """
13
16
  A Separable PINN object compatible with the rest of jinns.
14
17
 
@@ -47,21 +50,21 @@ class SPINN(eqx.Module):
47
50
 
48
51
  """
49
52
 
53
+ eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
54
+ static=True, kw_only=True
55
+ )
50
56
  d: int = eqx.field(static=True, kw_only=True)
51
57
  r: int = eqx.field(static=True, kw_only=True)
52
- eq_type: str = eqx.field(static=True, kw_only=True)
53
58
  m: int = eqx.field(static=True, kw_only=True, default=1)
54
-
55
59
  filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = eqx.field(
56
60
  static=True, kw_only=True, default=None
57
61
  )
58
62
  eqx_spinn_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
59
63
 
60
- init_params: PyTree = eqx.field(init=False)
61
- static: PyTree = eqx.field(init=False, static=True)
64
+ init_params: SPINN = eqx.field(init=False)
65
+ static: SPINN = eqx.field(init=False, static=True)
62
66
 
63
67
  def __post_init__(self, eqx_spinn_network):
64
-
65
68
  if self.filter_spec is None:
66
69
  self.filter_spec = eqx.is_inexact_array
67
70
 
@@ -69,20 +72,34 @@ class SPINN(eqx.Module):
69
72
  eqx_spinn_network, self.filter_spec
70
73
  )
71
74
 
75
+ @overload
76
+ @_PyTree_to_Params
72
77
  def __call__(
73
78
  self,
74
- t_x: Float[Array, "batch_size 1+dim"],
75
- params: Params | ParamsDict | PyTree,
76
- ) -> Float[Array, "output_dim"]:
79
+ inputs: Float[Array, " input_dim"],
80
+ params: PyTree,
81
+ *args,
82
+ **kwargs,
83
+ ) -> Float[Array, " output_dim"]: ...
84
+
85
+ @_PyTree_to_Params
86
+ def __call__(
87
+ self,
88
+ t_x: Float[Array, " batch_size 1+dim"],
89
+ params: Params[Array],
90
+ ) -> Float[Array, " output_dim"]:
77
91
  """
78
92
  Evaluate the SPINN on some inputs with some params.
93
+
94
+ Note that that thanks to the decorator, params can also directly be the
95
+ PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
79
96
  """
80
- try:
81
- spinn = eqx.combine(params.nn_params, self.static)
82
- except (KeyError, AttributeError, TypeError) as e:
83
- spinn = eqx.combine(params, self.static)
97
+ # try:
98
+ spinn = eqx.combine(params.nn_params, self.static)
99
+ # except (KeyError, AttributeError, TypeError) as e:
100
+ # spinn = eqx.combine(params, self.static)
84
101
  v_model = jax.vmap(spinn)
85
- res = v_model(t_x)
102
+ res = v_model(t_x) # type: ignore
86
103
 
87
104
  a = ", ".join([f"{chr(97 + d)}z" for d in range(res.shape[1])])
88
105
  b = "".join([f"{chr(97 + d)}" for d in range(res.shape[1])])