jinns-1.3.0-py3-none-any.whl → jinns-1.5.0-py3-none-any.whl

Files changed (55)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +74 -69
  14. jinns/loss/_LossODE.py +132 -348
  15. jinns/loss/_LossPDE.py +262 -549
  16. jinns/loss/__init__.py +32 -6
  17. jinns/loss/_abstract_loss.py +128 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_components.py +43 -0
  20. jinns/loss/_loss_utils.py +85 -179
  21. jinns/loss/_loss_weight_updates.py +202 -0
  22. jinns/loss/_loss_weights.py +64 -40
  23. jinns/loss/_operators.py +84 -74
  24. jinns/nn/__init__.py +15 -0
  25. jinns/nn/_abstract_pinn.py +22 -0
  26. jinns/nn/_hyperpinn.py +94 -57
  27. jinns/nn/_mlp.py +50 -25
  28. jinns/nn/_pinn.py +33 -19
  29. jinns/nn/_ppinn.py +70 -34
  30. jinns/nn/_save_load.py +21 -51
  31. jinns/nn/_spinn.py +33 -16
  32. jinns/nn/_spinn_mlp.py +28 -22
  33. jinns/nn/_utils.py +38 -0
  34. jinns/parameters/__init__.py +8 -1
  35. jinns/parameters/_derivative_keys.py +116 -177
  36. jinns/parameters/_params.py +18 -46
  37. jinns/plot/__init__.py +2 -0
  38. jinns/plot/_plot.py +35 -34
  39. jinns/solver/_rar.py +80 -63
  40. jinns/solver/_solve.py +207 -92
  41. jinns/solver/_utils.py +4 -6
  42. jinns/utils/__init__.py +2 -0
  43. jinns/utils/_containers.py +16 -10
  44. jinns/utils/_types.py +20 -54
  45. jinns/utils/_utils.py +4 -11
  46. jinns/validation/__init__.py +2 -0
  47. jinns/validation/_validation.py +20 -19
  48. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
  49. jinns-1.5.0.dist-info/RECORD +55 -0
  50. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
  51. jinns/data/_DataGenerators.py +0 -1634
  52. jinns-1.3.0.dist-info/RECORD +0 -44
  53. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
  54. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
  55. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/nn/_ppinn.py CHANGED
@@ -2,17 +2,20 @@
  Implements utility function to create PINNs
  """
 
- from typing import Callable, Literal, Self
+ from __future__ import annotations
+
+ from typing import Callable, Literal, Self, cast, overload
  from dataclasses import InitVar
  import jax
  import jax.numpy as jnp
  import equinox as eqx
 
- from jaxtyping import Array, Key, PyTree, Float
+ from jaxtyping import Array, Key, Float, PyTree
 
- from jinns.parameters._params import Params, ParamsDict
+ from jinns.parameters._params import Params
  from jinns.nn._pinn import PINN
  from jinns.nn._mlp import MLP
+ from jinns.nn._utils import _PyTree_to_Params
 
 
  class PPINN_MLP(PINN):
@@ -39,12 +42,12 @@ class PPINN_MLP(PINN):
         **Note**: the input dimension as given in eqx_list has to match the sum
         of the dimension of `t` + the dimension of `x` or the output dimension
         after the `input_transform` function.
-     input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
+     input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
         A function that will be called before entering the PPINN. Its output(s)
         must match the PPINN inputs (except for the parameters).
         Its inputs are the PPINN inputs (`t` and/or `x` concatenated together)
         and the parameters. Default is no operation.
-     output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
+     output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
         A function with arguments being the same input as the PPINN, the PPINN
         output and the parameter. This function will be called after exiting
         the PPINN.
@@ -63,25 +66,46 @@ class PPINN_MLP(PINN):
     """
 
     eqx_network_list: InitVar[list[eqx.Module]] = eqx.field(kw_only=True)
+     init_params: tuple[PINN, ...] = eqx.field(
+         init=False
+     )  # overriding parent attribute type
+     static: tuple[PINN, ...] = eqx.field(
+         init=False, static=True
+     )  # overriding parent attribute type
 
     def __post_init__(self, eqx_network, eqx_network_list):
         super().__post_init__(
             eqx_network=eqx_network_list[0],  # this is not used since it is
             # overwritten just below
         )
-         self.init_params, self.static = (), ()
-         for eqx_network_ in eqx_network_list:
+         params, static = eqx.partition(eqx_network_list[0], self.filter_spec)
+         self.init_params, self.static = (params,), (static,)
+         for eqx_network_ in eqx_network_list[1:]:
             params, static = eqx.partition(eqx_network_, self.filter_spec)
             self.init_params = self.init_params + (params,)
             self.static = self.static + (static,)
 
+     @overload
+     @_PyTree_to_Params
     def __call__(
         self,
-         inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
+         inputs: Float[Array, " input_dim"],
         params: PyTree,
-     ) -> Float[Array, "output_dim"]:
+         *args,
+         **kwargs,
+     ) -> Float[Array, " output_dim"]: ...
+
+     @_PyTree_to_Params
+     def __call__(
+         self,
+         inputs: Float[Array, " 1"] | Float[Array, " dim"] | Float[Array, " 1+dim"],
+         params: Params[Array],
+     ) -> Float[Array, " output_dim"]:
         """
         Evaluate the PPINN on some inputs with some params.
+
+         Note that, thanks to the decorator, params can also directly be the
+         PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
         """
         if len(inputs.shape) == 0:
             # This can happen often when the user directly provides some
@@ -92,14 +116,14 @@ class PPINN_MLP(PINN):
 
         outs = []
 
-         try:
-             for params_, static in zip(params.nn_params, self.static):
-                 model = eqx.combine(params_, static)
-                 outs += [model(transformed_inputs)]
-         except (KeyError, AttributeError, TypeError) as e:
-             for params_, static in zip(params, self.static):
-                 model = eqx.combine(params_, static)
-                 outs += [model(transformed_inputs)]
+         # try:
+         for params_, static in zip(params.nn_params, self.static):
+             model = eqx.combine(params_, static)
+             outs += [model(transformed_inputs)]  # type: ignore
+         # except (KeyError, AttributeError, TypeError) as e:
+         #     for params_, static in zip(params, self.static):
+         #         model = eqx.combine(params_, static)
+         #         outs += [model(transformed_inputs)]
         # Note that below is then a global output transform
         res = self.output_transform(inputs, jnp.concatenate(outs, axis=0), params)
 
@@ -112,18 +136,31 @@ class PPINN_MLP(PINN):
     def create(
         cls,
         eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
-         eqx_network_list: list[eqx.nn.MLP] = None,
+         eqx_network_list: list[eqx.nn.MLP | MLP] | None = None,
         key: Key = None,
-         eqx_list_list: list[tuple[tuple[Callable, int, int] | Callable, ...]] = None,
-         input_transform: Callable[
-             [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
-         ] = None,
-         output_transform: Callable[
-             [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
-             Float[Array, "output_dim"],
-         ] = None,
-         slice_solution: slice = None,
-     ) -> tuple[Self, PyTree]:
+         eqx_list_list: (
+             list[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] | None
+         ) = None,
+         input_transform: (
+             Callable[
+                 [Float[Array, " input_dim"], Params[Array]],
+                 Float[Array, " output_dim"],
+             ]
+             | None
+         ) = None,
+         output_transform: (
+             Callable[
+                 [
+                     Float[Array, " input_dim"],
+                     Float[Array, " output_dim"],
+                     Params[Array],
+                 ],
+                 Float[Array, " output_dim"],
+             ]
+             | None
+         ) = None,
+         slice_solution: slice | None = None,
+     ) -> tuple[Self, tuple[PINN, ...]]:
         r"""
         Utility function to create a Parallel PINN neural network for Jinns.
 
@@ -189,15 +226,14 @@ class PPINN_MLP(PINN):
             eqx_network_list = []
             for eqx_list in eqx_list_list:
                 key, subkey = jax.random.split(key, 2)
-                 print(subkey)
                 eqx_network_list.append(MLP(key=subkey, eqx_list=eqx_list))
 
         ppinn = cls(
-             eqx_network=None,
-             eqx_network_list=eqx_network_list,
-             slice_solution=slice_solution,
+             eqx_network=None,  # type: ignore
+             eqx_network_list=cast(list[eqx.Module], eqx_network_list),
+             slice_solution=slice_solution,  # type: ignore
             eq_type=eq_type,
-             input_transform=input_transform,
-             output_transform=output_transform,
+             input_transform=input_transform,  # type: ignore
+             output_transform=output_transform,  # type: ignore
         )
         return ppinn, ppinn.init_params
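
Usage sketch for the reworked `create` signature above. The layer sizes, the ODE `eq_type`, and the two-network split are illustrative assumptions, as is importing from the private module path shown in this diff (the public re-export may differ):

    import jax
    import jax.numpy as jnp
    import equinox as eqx
    from jinns.nn._ppinn import PPINN_MLP  # module path as in this diff

    key = jax.random.PRNGKey(0)
    # two parallel sub-networks; note the (activation,) tuple convention of 1.5.0
    eqx_list = ((eqx.nn.Linear, 1, 16), (jax.nn.tanh,), (eqx.nn.Linear, 16, 1))
    u, init_params = PPINN_MLP.create(
        eq_type="ODE", key=key, eqx_list_list=[eqx_list, eqx_list]
    )
    # init_params is now a tuple of partitioned parameter PyTrees, one per
    # sub-network; _PyTree_to_Params lets the raw tuple stand in for a Params
    out = u(jnp.array([0.5]), init_params)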
jinns/nn/_save_load.py CHANGED
@@ -12,7 +12,7 @@ from jinns.nn._spinn import SPINN
  from jinns.nn._mlp import PINN_MLP
  from jinns.nn._spinn_mlp import SPINN_MLP
  from jinns.nn._hyperpinn import HyperPINN
- from jinns.parameters._params import Params, ParamsDict
+ from jinns.parameters._params import Params
 
 
  def function_to_string(
@@ -87,7 +87,7 @@
  def save_pinn(
     filename: str,
     u: PINN | HyperPINN | SPINN,
-     params: Params | ParamsDict,
+     params: Params,
     kwargs_creation,
  ):
     """
@@ -105,15 +105,12 @@
     tree_serialise_leaves`).
 
     Equation parameters are saved apart because the initial type of attribute
-     `params` in PINN / HyperPINN / SPINN is not `Params` nor `ParamsDict`
+     `params` in PINN / HyperPINN / SPINN is not `Params`
     but `PyTree` as inherited from `eqx.partition`.
     Therefore, if we want to ensure a proper serialization/deserialization:
     - we cannot save a `Params` object at this
     attribute field ; the `Params` object must be split into `Params.nn_params`
     (type `PyTree`) and `Params.eq_params` (type `dict`).
-     - in the case of a `ParamsDict` we cannot save `ParamsDict.nn_params` at
-     the attribute field `params` because it is not a `PyTree` (as expected in
-     the PINN / HyperPINN / SPINN signature) but it is still a dictionary.
 
     Parameters
     ----------
@@ -122,28 +119,16 @@
     u
         The PINN
     params
-         Params or ParamsDict to be save
+         Params to be saved
     kwargs_creation
         The dictionary of arguments that were used to create the PINN, e.g.
         the layers list, O/PDE type, etc.
     """
-     if isinstance(params, Params):
-         if isinstance(u, HyperPINN):
-             u = eqx.tree_at(lambda m: m.init_params_hyper, u, params)
-         elif isinstance(u, (PINN, SPINN)):
-             u = eqx.tree_at(lambda m: m.init_params, u, params)
-         eqx.tree_serialise_leaves(filename + "-module.eqx", u)
-
-     elif isinstance(params, ParamsDict):
-         for key, params_ in params.nn_params.items():
-             if isinstance(u, HyperPINN):
-                 u = eqx.tree_at(lambda m: m.init_params_hyper, u, params_)
-             elif isinstance(u, (PINN, SPINN)):
-                 u = eqx.tree_at(lambda m: m.init_params, u, params_)
-             eqx.tree_serialise_leaves(filename + f"-module_{key}.eqx", u)
-
-     else:
-         raise ValueError("The parameters to be saved must be a Params or a ParamsDict")
+     if isinstance(u, HyperPINN):
+         u = eqx.tree_at(lambda m: m.init_params_hyper, u, params)
+     elif isinstance(u, (PINN, SPINN)):
+         u = eqx.tree_at(lambda m: m.init_params, u, params)
+     eqx.tree_serialise_leaves(filename + "-module.eqx", u)
 
     with open(filename + "-eq_params.pkl", "wb") as f:
         pickle.dump(params.eq_params, f)
@@ -170,8 +155,7 @@
  def load_pinn(
     filename: str,
     type_: Literal["pinn_mlp", "hyperpinn", "spinn_mlp"],
-     key_list_for_paramsdict: list[str] = None,
- ) -> tuple[eqx.Module, Params | ParamsDict]:
+ ) -> tuple[eqx.Module, Params]:
     """
     Load a PINN model. This function needs to access 3 files :
     `{filename}-module.eqx`, `{filename}-parameters.pkl` and
@@ -190,8 +174,6 @@
         Filename (prefix) without extension.
     type_
         Type of model to load. Must be in ["pinn_mlp", "hyperpinn", "spinn"].
-     key_list_for_paramsdict
-         Pass the name of the keys of the dictionnary `ParamsDict.nn_params`. Default is None. In this case, we expect to retrieve a ParamsDict.
 
     Returns
     -------
@@ -228,29 +210,17 @@
         )
     else:
         raise ValueError(f"{type_} is not valid")
-     if key_list_for_paramsdict is None:
-         # now the empty structure is populated with the actual saved array values
-         # stored in the eqx file
-         u_reloaded = eqx.tree_deserialise_leaves(
-             filename + "-module.eqx", u_reloaded_shallow
+     # now the empty structure is populated with the actual saved array values
+     # stored in the eqx file
+     u_reloaded = eqx.tree_deserialise_leaves(
+         filename + "-module.eqx", u_reloaded_shallow
+     )
+     if isinstance(u_reloaded, HyperPINN):
+         params = Params(
+             nn_params=u_reloaded.init_params_hyper, eq_params=eq_params_reloaded
         )
-         if isinstance(u_reloaded, HyperPINN):
-             params = Params(
-                 nn_params=u_reloaded.init_params_hyper, eq_params=eq_params_reloaded
-             )
-         elif isinstance(u_reloaded, (PINN, SPINN)):
-             params = Params(
-                 nn_params=u_reloaded.init_params, eq_params=eq_params_reloaded
-             )
+     elif isinstance(u_reloaded, (PINN, SPINN)):
+         params = Params(nn_params=u_reloaded.init_params, eq_params=eq_params_reloaded)
     else:
-         nn_params_dict = {}
-         for key in key_list_for_paramsdict:
-             u_reloaded = eqx.tree_deserialise_leaves(
-                 filename + f"-module_{key}.eqx", u_reloaded_shallow
-             )
-             if isinstance(u_reloaded, HyperPINN):
-                 nn_params_dict[key] = u_reloaded.init_params_hyper
-             elif isinstance(u_reloaded, (PINN, SPINN)):
-                 nn_params_dict[key] = u_reloaded.init_params
-         params = ParamsDict(nn_params=nn_params_dict, eq_params=eq_params_reloaded)
+         raise ValueError("Wrong type for u_reloaded")
     return u_reloaded, params
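
With `ParamsDict` gone, saving and loading collapse to the single-module path. A round-trip sketch, assuming `u`, `params`, and `kwargs_creation` come from an earlier `create` call and that `save_pinn` also pickles the creation kwargs to `{filename}-parameters.pkl`, as the `load_pinn` docstring states:

    from jinns.nn._save_load import save_pinn, load_pinn

    # writes my_model-module.eqx and my_model-eq_params.pkl (plus the pickled
    # creation kwargs); params must now be a Params, not a ParamsDict
    save_pinn("my_model", u, params, kwargs_creation)

    # the key_list_for_paramsdict argument no longer exists in 1.5.0
    u_reloaded, params_reloaded = load_pinn("my_model", type_="pinn_mlp")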
jinns/nn/_spinn.py CHANGED
@@ -1,14 +1,17 @@
- from typing import Union, Callable, Any
+ from __future__ import annotations
+ from typing import Union, Callable, Any, Literal, overload
  from dataclasses import InitVar
  from jaxtyping import PyTree, Float, Array
  import jax
  import jax.numpy as jnp
  import equinox as eqx
 
- from jinns.parameters._params import Params, ParamsDict
+ from jinns.parameters._params import Params
+ from jinns.nn._abstract_pinn import AbstractPINN
+ from jinns.nn._utils import _PyTree_to_Params
 
 
- class SPINN(eqx.Module):
+ class SPINN(AbstractPINN):
     """
     A Separable PINN object compatible with the rest of jinns.
 
@@ -47,21 +50,21 @@ class SPINN(AbstractPINN):
 
     """
 
+     eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
+         static=True, kw_only=True
+     )
     d: int = eqx.field(static=True, kw_only=True)
     r: int = eqx.field(static=True, kw_only=True)
-     eq_type: str = eqx.field(static=True, kw_only=True)
     m: int = eqx.field(static=True, kw_only=True, default=1)
-
     filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = eqx.field(
         static=True, kw_only=True, default=None
     )
     eqx_spinn_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
 
-     init_params: PyTree = eqx.field(init=False)
-     static: PyTree = eqx.field(init=False, static=True)
+     init_params: SPINN = eqx.field(init=False)
+     static: SPINN = eqx.field(init=False, static=True)
 
     def __post_init__(self, eqx_spinn_network):
-
         if self.filter_spec is None:
             self.filter_spec = eqx.is_inexact_array
 
@@ -69,20 +72,34 @@ class SPINN(AbstractPINN):
             eqx_spinn_network, self.filter_spec
         )
 
+     @overload
+     @_PyTree_to_Params
     def __call__(
         self,
-         t_x: Float[Array, "batch_size 1+dim"],
-         params: Params | ParamsDict | PyTree,
-     ) -> Float[Array, "output_dim"]:
+         inputs: Float[Array, " input_dim"],
+         params: PyTree,
+         *args,
+         **kwargs,
+     ) -> Float[Array, " output_dim"]: ...
+
+     @_PyTree_to_Params
+     def __call__(
+         self,
+         t_x: Float[Array, " batch_size 1+dim"],
+         params: Params[Array],
+     ) -> Float[Array, " output_dim"]:
         """
         Evaluate the SPINN on some inputs with some params.
+
+         Note that, thanks to the decorator, params can also directly be the
+         PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
         """
-         try:
-             spinn = eqx.combine(params.nn_params, self.static)
-         except (KeyError, AttributeError, TypeError) as e:
-             spinn = eqx.combine(params, self.static)
+         # try:
+         spinn = eqx.combine(params.nn_params, self.static)
+         # except (KeyError, AttributeError, TypeError) as e:
+         #     spinn = eqx.combine(params, self.static)
         v_model = jax.vmap(spinn)
-         res = v_model(t_x)
+         res = v_model(t_x)  # type: ignore
 
         a = ", ".join([f"{chr(97 + d)}z" for d in range(res.shape[1])])
         b = "".join([f"{chr(97 + d)}" for d in range(res.shape[1])])
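
Because both overloads route through `_PyTree_to_Params`, a SPINN call site may pass either the full container or the bare parameter PyTree; `u` and `t_x` below are assumed to exist:

    res = u(t_x, params)            # params: Params[Array]
    res = u(t_x, params.nn_params)  # bare PyTree; the decorator wraps it in
                                    # Params(nn_params=..., eq_params={})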
jinns/nn/_spinn_mlp.py CHANGED
@@ -4,13 +4,12 @@ https://arxiv.org/abs/2211.08761
  """
 
  from dataclasses import InitVar
- from typing import Callable, Literal, Self, Union, Any
+ from typing import Callable, Literal, Self, Union, Any, TypeGuard
  import jax
  import jax.numpy as jnp
  import equinox as eqx
  from jaxtyping import Key, Array, Float, PyTree
 
- from jinns.parameters._params import Params, ParamsDict
  from jinns.nn._mlp import MLP
  from jinns.nn._spinn import SPINN
 
@@ -26,7 +25,7 @@ class SMLP(eqx.Module):
     d : int
         The number of dimensions to treat separately, including time `t` if
         used for non-stationnary equations.
-     eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
+     eqx_list : InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]]
         A tuple of tuples of successive equinox modules and activation functions to
         describe the PINN architecture. The inner tuples must have the eqx module or
         activation function as first item, other items represents arguments
@@ -34,18 +33,18 @@ class SMLP(eqx.Module):
         The `key` argument need not be given.
         Thus typical example is `eqx_list=
         ((eqx.nn.Linear, 1, 20),
-         jax.nn.tanh,
+         (jax.nn.tanh,),
         (eqx.nn.Linear, 20, 20),
-         jax.nn.tanh,
+         (jax.nn.tanh,),
         (eqx.nn.Linear, 20, 20),
-         jax.nn.tanh,
+         (jax.nn.tanh,),
         (eqx.nn.Linear, 20, r * m)
         )`.
     """
 
     key: InitVar[Key] = eqx.field(kw_only=True)
-     eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] = eqx.field(
-         kw_only=True
+     eqx_list: InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] = (
+         eqx.field(kw_only=True)
     )
     d: int = eqx.field(static=True, kw_only=True)
 
@@ -58,8 +57,8 @@ class SMLP(eqx.Module):
     ]
 
     def __call__(
-         self, inputs: Float[Array, "dim"] | Float[Array, "dim+1"]
-     ) -> Float[Array, "d embed_dim*output_dim"]:
+         self, inputs: Float[Array, " dim"] | Float[Array, " dim+1"]
+     ) -> Float[Array, " d embed_dim*output_dim"]:
         outputs = []
         for d in range(self.d):
             x_i = inputs[d : d + 1]
@@ -78,11 +77,11 @@ class SPINN_MLP(SPINN):
         key: Key,
         d: int,
         r: int,
-         eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
+         eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...],
         eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
         m: int = 1,
         filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
-     ) -> tuple[Self, PyTree]:
+     ) -> tuple[Self, SPINN]:
         """
         Utility function to create a SPINN neural network with the equinox
         library.
@@ -108,11 +107,11 @@ class SPINN_MLP(SPINN):
             The `key` argument need not be given.
             Thus typical example is
             `eqx_list=((eqx.nn.Linear, 1, 20),
-             jax.nn.tanh,
+             (jax.nn.tanh,),
+             (eqx.nn.Linear, 20, 20),
+             (jax.nn.tanh,),
             (eqx.nn.Linear, 20, 20),
-             jax.nn.tanh,
-             (eqx.nn.Linear, 20, 20),
-             jax.nn.tanh,
+             (jax.nn.tanh,),
             (eqx.nn.Linear, 20, r * m)
             )`.
         eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
@@ -158,24 +157,31 @@ class SPINN_MLP(SPINN):
         if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
             raise RuntimeError("Wrong parameter value for eq_type")
 
-         try:
+         def element_is_layer(element: tuple) -> TypeGuard[tuple[Callable, int, int]]:
+             return len(element) > 1
+
+         if element_is_layer(eqx_list[0]):
             nb_inputs_declared = eqx_list[0][
                 1
             ]  # normally we look for 2nd ele of 1st layer
-         except IndexError:
+         elif element_is_layer(eqx_list[1]):
             nb_inputs_declared = eqx_list[1][
                 1
             ]  # but we can have, eg, a flatten first layer
-         if nb_inputs_declared != 1:
+         else:
+             nb_inputs_declared = None
+         if nb_inputs_declared is None or nb_inputs_declared != 1:
             raise ValueError("Input dim must be set to 1 in SPINN!")
 
-         try:
+         if element_is_layer(eqx_list[-1]):
             nb_outputs_declared = eqx_list[-1][2]  # normally we look for 3rd ele of
             # last layer
-         except IndexError:
+         elif element_is_layer(eqx_list[-2]):
             nb_outputs_declared = eqx_list[-2][2]
             # but we can have, eg, a `jnp.exp` last layer
-         if nb_outputs_declared != r * m:
+         else:
+             nb_outputs_declared = None
+         if nb_outputs_declared is None or nb_outputs_declared != r * m:
             raise ValueError("Output dim must be set to r * m in SPINN!")
 
         if d > 24:
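
The `element_is_layer` TypeGuard replaces the old try/except probing: with every `eqx_list` entry now a tuple, a simple length check separates `(module, in, out)` layers from `(activation,)` entries. A `create` call sketch under the new convention (sizes and the import path are illustrative; input width 1 and output width `r * m` are the constraints enforced above):

    import jax
    import equinox as eqx
    from jinns.nn._spinn_mlp import SPINN_MLP  # module path as in this diff

    d, r, m = 3, 32, 1  # e.g. t plus two spatial axes, rank 32, one output field
    u, init_params = SPINN_MLP.create(
        key=jax.random.PRNGKey(0),
        d=d,
        r=r,
        eqx_list=(
            (eqx.nn.Linear, 1, 20),      # input dim must be 1 (one axis per body)
            (jax.nn.tanh,),              # activations are now 1-tuples
            (eqx.nn.Linear, 20, r * m),  # output dim must be r * m
        ),
        eq_type="nonstatio_PDE",
        m=m,
    )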
jinns/nn/_utils.py ADDED
@@ -0,0 +1,38 @@
+ from typing import Any, ParamSpec, Callable, Concatenate
+ from jaxtyping import PyTree, Array
+ from jinns.parameters._params import Params
+
+
+ P = ParamSpec("P")
+
+
+ def _PyTree_to_Params(
+     call_fun: Callable[
+         Concatenate[Any, Any, PyTree | Params[Array], P],
+         Any,
+     ],
+ ) -> Callable[
+     Concatenate[Any, Any, PyTree | Params[Array], P],
+     Any,
+ ]:
+     """
+     Decorator to be used around __call__ functions of PINNs, SPINNs, etc. It
+     authorizes the __call__ with `params` being directly the
+     PyTree (SPINN, PINN_MLP, ...) that we get out of `eqx.combine`
+
+     This generic approach enables to cleanly handle type hints, up to the small
+     effort required to understand type hints for decorators (ie ParamSpec).
+     """
+
+     def wrapper(
+         self: Any,
+         inputs: Any,
+         params: PyTree | Params[Array],
+         *args: P.args,
+         **kwargs: P.kwargs,
+     ):
+         if isinstance(params, PyTree) and not isinstance(params, Params):
+             params = Params(nn_params=params, eq_params={})
+         return call_fun(self, inputs, params, *args, **kwargs)
+
+     return wrapper
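
In isolation, the decorator normalizes any non-`Params` PyTree passed as `params` before the wrapped `__call__` runs. A toy sketch of the same pattern (the `Dummy` class is hypothetical; only the decorator and `Params` come from the package):

    import jax.numpy as jnp
    from jinns.nn._utils import _PyTree_to_Params
    from jinns.parameters._params import Params

    class Dummy:
        @_PyTree_to_Params
        def __call__(self, inputs, params, *args, **kwargs):
            assert isinstance(params, Params)  # holds for both call forms below
            return inputs

    Dummy()(jnp.zeros(2), {"w": jnp.ones(3)})  # bare PyTree gets wrapped
    Dummy()(jnp.zeros(2), Params(nn_params={"w": jnp.ones(3)}, eq_params={}))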
jinns/parameters/__init__.py CHANGED
@@ -1,6 +1,13 @@
- from ._params import Params, ParamsDict
+ from ._params import Params
  from ._derivative_keys import (
     DerivativeKeysODE,
     DerivativeKeysPDEStatio,
     DerivativeKeysPDENonStatio,
  )
+
+ __all__ = [
+     "Params",
+     "DerivativeKeysODE",
+     "DerivativeKeysPDEStatio",
+     "DerivativeKeysPDENonStatio",
+ ]
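
With the explicit `__all__`, `ParamsDict` is no longer part of the public `jinns.parameters` surface; the `Params` container is the single entry point. Downstream imports reduce to, e.g. (field values illustrative, `nn_pytree` assumed to exist):

    from jinns.parameters import Params

    p = Params(nn_params=nn_pytree, eq_params={"nu": 0.01})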