jinns-1.3.0-py3-none-any.whl → jinns-1.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +74 -69
  14. jinns/loss/_LossODE.py +132 -348
  15. jinns/loss/_LossPDE.py +262 -549
  16. jinns/loss/__init__.py +32 -6
  17. jinns/loss/_abstract_loss.py +128 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_components.py +43 -0
  20. jinns/loss/_loss_utils.py +85 -179
  21. jinns/loss/_loss_weight_updates.py +202 -0
  22. jinns/loss/_loss_weights.py +64 -40
  23. jinns/loss/_operators.py +84 -74
  24. jinns/nn/__init__.py +15 -0
  25. jinns/nn/_abstract_pinn.py +22 -0
  26. jinns/nn/_hyperpinn.py +94 -57
  27. jinns/nn/_mlp.py +50 -25
  28. jinns/nn/_pinn.py +33 -19
  29. jinns/nn/_ppinn.py +70 -34
  30. jinns/nn/_save_load.py +21 -51
  31. jinns/nn/_spinn.py +33 -16
  32. jinns/nn/_spinn_mlp.py +28 -22
  33. jinns/nn/_utils.py +38 -0
  34. jinns/parameters/__init__.py +8 -1
  35. jinns/parameters/_derivative_keys.py +116 -177
  36. jinns/parameters/_params.py +18 -46
  37. jinns/plot/__init__.py +2 -0
  38. jinns/plot/_plot.py +35 -34
  39. jinns/solver/_rar.py +80 -63
  40. jinns/solver/_solve.py +207 -92
  41. jinns/solver/_utils.py +4 -6
  42. jinns/utils/__init__.py +2 -0
  43. jinns/utils/_containers.py +16 -10
  44. jinns/utils/_types.py +20 -54
  45. jinns/utils/_utils.py +4 -11
  46. jinns/validation/__init__.py +2 -0
  47. jinns/validation/_validation.py +20 -19
  48. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
  49. jinns-1.5.0.dist-info/RECORD +55 -0
  50. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
  51. jinns/data/_DataGenerators.py +0 -1634
  52. jinns-1.3.0.dist-info/RECORD +0 -44
  53. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
  54. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
  55. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/nn/_hyperpinn.py CHANGED
@@ -3,9 +3,11 @@ Implements utility function to create HyperPINNs
 https://arxiv.org/pdf/2111.01008.pdf
 """
 
+from __future__ import annotations
+
 import warnings
 from dataclasses import InitVar
-from typing import Callable, Literal, Self, Union, Any
+from typing import Callable, Literal, Self, Union, Any, cast, overload
 from math import prod
 import jax
 import jax.numpy as jnp
@@ -15,12 +17,13 @@ import numpy as onp
 
 from jinns.nn._pinn import PINN
 from jinns.nn._mlp import MLP
-from jinns.parameters._params import Params, ParamsDict
+from jinns.parameters._params import Params
+from jinns.nn._utils import _PyTree_to_Params
 
 
 def _get_param_nb(
-    params: Params,
-) -> tuple[int, list]:
+    params: PyTree[Array],
+) -> tuple[int, list[int]]:
     """Returns the number of parameters in a Params object and also
     the cumulative sum when parsing the object.
 
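The body of `_get_param_nb` is outside this hunk, but the loosened signature (any `PyTree[Array]`, not just `Params`) pins down its contract: the total number of array elements plus the cumulative offsets used later to slice the flat hypernetwork output. A minimal sketch of that contract, with an invented name and no claim about the real traversal details:

    import jax
    import numpy as onp

    def _get_param_nb_sketch(params) -> tuple[int, list[int]]:
        # Element count of every array leaf, in tree_leaves order, plus the
        # running cumulative sum used to cut a flat vector into per-leaf blocks.
        sizes = [leaf.size for leaf in jax.tree_util.tree_leaves(params)]
        return int(sum(sizes)), onp.cumsum(sizes).tolist()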
@@ -48,7 +51,7 @@ class HyperPINN(PINN):
 
     Parameters
     ----------
-    hyperparams: list = eqx.field(static=True)
+    hyperparams: list[str] = eqx.field(static=True)
        A list of keys from Params.eq_params that will be considered as
        hyperparameters for metamodeling.
    hypernet_input_size: int
@@ -72,12 +75,12 @@ class HyperPINN(PINN):
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function
-    input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
+    input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
        A function that will be called before entering the PINN. Its output(s)
        must match the PINN inputs (except for the parameters).
        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
        and the parameters. Default is no operation.
-    output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
+    output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
        A function with arguments begin the same input as the PINN, the PINN
        output and the parameter. This function will be called after exiting the PINN.
        Default is no operation.
@@ -100,10 +103,10 @@ class HyperPINN(PINN):
     eqx_hyper_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
 
     pinn_params_sum: int = eqx.field(init=False, static=True)
-    pinn_params_cumsum: list = eqx.field(init=False, static=True)
+    pinn_params_cumsum: list[int] = eqx.field(init=False, static=True)
 
-    init_params_hyper: PyTree = eqx.field(init=False)
-    static_hyper: PyTree = eqx.field(init=False, static=True)
+    init_params_hyper: HyperPINN = eqx.field(init=False)
+    static_hyper: HyperPINN = eqx.field(init=False, static=True)
 
     def __post_init__(self, eqx_network, eqx_hyper_network):
         super().__post_init__(
@@ -115,7 +118,7 @@ class HyperPINN(PINN):
         )
         self.pinn_params_sum, self.pinn_params_cumsum = _get_param_nb(self.init_params)
 
-    def _hyper_to_pinn(self, hyper_output: Float[Array, "output_dim"]) -> PyTree:
+    def _hyper_to_pinn(self, hyper_output: Float[Array, " output_dim"]) -> PINN:
        """
        From the output of the hypernetwork, transform to a well formed
        parameters for the pinn network (i.e. with the same PyTree structure as
@@ -142,15 +145,29 @@ class HyperPINN(PINN):
             is_leaf=lambda x: isinstance(x, jnp.ndarray),
         )
 
+    @overload
+    @_PyTree_to_Params
     def __call__(
         self,
-        inputs: Float[Array, "input_dim"],
-        params: Params | ParamsDict | PyTree,
+        inputs: Float[Array, " input_dim"],
+        params: PyTree,
         *args,
         **kwargs,
-    ) -> Float[Array, "output_dim"]:
+    ) -> Float[Array, " output_dim"]: ...
+
+    @_PyTree_to_Params
+    def __call__(
+        self,
+        inputs: Float[Array, " input_dim"],
+        params: Params[Array],
+        *args,
+        **kwargs,
+    ) -> Float[Array, " output_dim"]:
         """
         Evaluate the HyperPINN on some inputs with some params.
+
+        Note that that thanks to the decorator, params can also directly be the
+        PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
         """
         if len(inputs.shape) == 0:
             # This can happen often when the user directly provides some
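`_PyTree_to_Params` comes from the new `jinns/nn/_utils.py`, whose body is not shown in this diff. The overload pair and the docstring note suggest it normalizes a raw parameter PyTree into a `Params` before the concrete `__call__` runs, so the body can rely on `params.nn_params`. A hedged sketch of that idea (the wrapper internals and the empty `eq_params` default are assumptions):

    import functools

    from jinns.parameters import Params

    def _pytree_to_params_sketch(call):
        # Hypothetical stand-in for jinns.nn._utils._PyTree_to_Params: if the
        # caller passed a bare PyTree of network weights, wrap it in a Params
        # so the decorated __call__ can always access params.nn_params.
        @functools.wraps(call)
        def wrapper(self, inputs, params, *args, **kwargs):
            if not isinstance(params, Params):
                params = Params(nn_params=params, eq_params={})
            return call(self, inputs, params, *args, **kwargs)

        return wrapper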
@@ -158,16 +175,17 @@ class HyperPINN(PINN):
             # DataGenerators)
             inputs = inputs[None]
 
-        try:
-            hyper = eqx.combine(params.nn_params, self.static_hyper)
-        except (KeyError, AttributeError, TypeError) as e: # give more flexibility
-            hyper = eqx.combine(params, self.static_hyper)
+        # try:
+        hyper = eqx.combine(params.nn_params, self.static_hyper)
+        # except (KeyError, AttributeError, TypeError) as e: # give more flexibility
+        #     hyper = eqx.combine(params, self.static_hyper)
 
         eq_params_batch = jnp.concatenate(
-            [params.eq_params[k].flatten() for k in self.hyperparams], axis=0
+            [params.eq_params[k].flatten() for k in self.hyperparams],
+            axis=0,
         )
 
-        hyper_output = hyper(eq_params_batch)
+        hyper_output = hyper(eq_params_batch) # type: ignore
 
         pinn_params = self._hyper_to_pinn(hyper_output)
 
@@ -187,21 +205,34 @@ class HyperPINN(PINN):
         eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
         hyperparams: list[str],
         hypernet_input_size: int,
-        eqx_network: eqx.nn.MLP = None,
-        eqx_hyper_network: eqx.nn.MLP = None,
+        eqx_network: eqx.nn.MLP | MLP | None = None,
+        eqx_hyper_network: eqx.nn.MLP | MLP | None = None,
         key: Key = None,
-        eqx_list: tuple[tuple[Callable, int, int] | Callable, ...] = None,
-        eqx_list_hyper: tuple[tuple[Callable, int, int] | Callable, ...] = None,
-        input_transform: Callable[
-            [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
-        ] = None,
-        output_transform: Callable[
-            [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
-            Float[Array, "output_dim"],
-        ] = None,
-        slice_solution: slice = None,
+        eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
+        eqx_list_hyper: (
+            tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None
+        ) = None,
+        input_transform: (
+            Callable[
+                [Float[Array, " input_dim"], Params[Array]],
+                Float[Array, " output_dim"],
+            ]
+            | None
+        ) = None,
+        output_transform: (
+            Callable[
+                [
+                    Float[Array, " input_dim"],
+                    Float[Array, " output_dim"],
+                    Params[Array],
+                ],
+                Float[Array, " output_dim"],
+            ]
+            | None
+        ) = None,
+        slice_solution: slice | None = None,
         filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
-    ) -> tuple[Self, PyTree]:
+    ) -> tuple[Self, HyperPINN]:
         r"""
         Utility function to create a standard PINN neural network with the equinox
         library.
@@ -250,11 +281,11 @@ class HyperPINN(PINN):
            The `key` argument need not be given.
            Thus typical example is `eqx_list=
            ((eqx.nn.Linear, 2, 20),
-            jax.nn.tanh,
+            (jax.nn.tanh,),
            (eqx.nn.Linear, 20, 20),
-            jax.nn.tanh,
+            (jax.nn.tanh,),
            (eqx.nn.Linear, 20, 20),
-            jax.nn.tanh,
+            (jax.nn.tanh,),
            (eqx.nn.Linear, 20, 1)
            )`.
        eqx_list_hyper
@@ -268,11 +299,11 @@ class HyperPINN(PINN):
            The `key` argument need not be given.
            Thus typical example is `eqx_list=
            ((eqx.nn.Linear, 2, 20),
-            jax.nn.tanh,
+            (jax.nn.tanh,),
            (eqx.nn.Linear, 20, 20),
-            jax.nn.tanh,
+            (jax.nn.tanh,),
            (eqx.nn.Linear, 20, 20),
-            jax.nn.tanh,
+            (jax.nn.tanh,),
            (eqx.nn.Linear, 20, 1)
            )`.
        input_transform
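Both docstring examples now wrap bare activations as 1-tuples, `(jax.nn.tanh,)`, in line with the tightened `tuple[tuple[Callable, int, int] | tuple[Callable], ...]` annotation. Under the new convention a call could look like the sketch below ("nu" is a made-up `eq_params` key, and `HyperPINN` is assumed to be re-exported by the new `jinns/nn/__init__.py`); `create` then rewires `eqx_list_hyper` so its input size is `hypernet_input_size` and its output size matches the PINN parameter count, as the hunks further down show:

    import jax
    import equinox as eqx
    from jinns.nn import HyperPINN

    arch = (
        (eqx.nn.Linear, 2, 20),
        (jax.nn.tanh,),  # activations are now 1-tuples
        (eqx.nn.Linear, 20, 20),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 1),
    )
    u, init_params = HyperPINN.create(
        eq_type="nonstatio_PDE",
        hyperparams=["nu"],
        hypernet_input_size=1,
        key=jax.random.PRNGKey(0),
        eqx_list=arch,
        eqx_list_hyper=arch,  # reused here only for brevity
    )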
@@ -343,10 +374,13 @@ class HyperPINN(PINN):
                (eqx_list_hyper[-1][:2] + (pinn_params_sum,)),
            )
        else:
-            eqx_list_hyper = (
-                eqx_list_hyper[:-2]
-                + ((eqx_list_hyper[-2][:2] + (pinn_params_sum,)),)
-                + eqx_list_hyper[-1]
+            eqx_list_hyper = cast(
+                tuple[tuple[Callable, int, int] | tuple[Callable], ...],
+                (
+                    eqx_list_hyper[:-2]
+                    + ((eqx_list_hyper[-2][:2] + (pinn_params_sum,)),)
+                    + eqx_list_hyper[-1]
+                ),
            )
        if len(eqx_list_hyper[0]) > 1:
            eqx_list_hyper = (
@@ -357,21 +391,24 @@ class HyperPINN(PINN):
                ),
            ) + eqx_list_hyper[1:]
        else:
-            eqx_list_hyper = (
-                eqx_list_hyper[0]
-                + (
-                    (
-                        (eqx_list_hyper[1][0],)
-                        + (hypernet_input_size,)
-                        + (eqx_list_hyper[1][2],)
-                    ),
-                )
-                + eqx_list_hyper[2:]
+            eqx_list_hyper = cast(
+                tuple[tuple[Callable, int, int] | tuple[Callable], ...],
+                (
+                    eqx_list_hyper[0]
+                    + (
+                        (
+                            (eqx_list_hyper[1][0],)
+                            + (hypernet_input_size,)
+                            + (eqx_list_hyper[1][2],) # type: ignore because we suppose that the second element of tuple is nec.of length > 1 since we expect smth like eqx.nn.Linear
+                        ),
+                    )
+                    + eqx_list_hyper[2:]
+                ),
            )
        key, subkey = jax.random.split(key, 2)
        # with warnings.catch_warnings():
        # warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
-        eqx_hyper_network = MLP(key=subkey, eqx_list=eqx_list_hyper)
+        eqx_hyper_network = cast(MLP, MLP(key=subkey, eqx_list=eqx_list_hyper))
 
        ### End of finetuning the hypernetwork architecture
 
@@ -386,10 +423,10 @@ class HyperPINN(PINN):
        hyperpinn = cls(
            eqx_network=eqx_network,
            eqx_hyper_network=eqx_hyper_network,
-            slice_solution=slice_solution,
+            slice_solution=slice_solution, # type: ignore
            eq_type=eq_type,
-            input_transform=input_transform,
-            output_transform=output_transform,
+            input_transform=input_transform, # type: ignore
+            output_transform=output_transform, # type: ignore
            hyperparams=hyperparams,
            hypernet_input_size=hypernet_input_size,
            filter_spec=filter_spec,
jinns/nn/_mlp.py CHANGED
@@ -2,16 +2,30 @@
 Implements utility function to create PINNs
 """
 
-from typing import Callable, Literal, Self, Union, Any
+from __future__ import annotations
+
+from typing import Callable, Literal, Self, Union, Any, TYPE_CHECKING, cast
 from dataclasses import InitVar
 import jax
 import equinox as eqx
-
+from typing import Protocol
 from jaxtyping import Array, Key, PyTree, Float
 
 from jinns.parameters._params import Params
 from jinns.nn._pinn import PINN
 
+if TYPE_CHECKING:
+
+    class CallableMLPModule(Protocol):
+        """
+        Basically just a way to add a __call__ to an eqx.Module.
+        https://github.com/patrick-kidger/equinox/issues/1002
+        We chose the strutural subtyping of protocols instead of subclassing an
+        eqx.Module just to add a __call__ here
+        """
+
+        def __call__(self, *_, **__) -> Array: ...
+
 
 class MLP(eqx.Module):
     """
@@ -21,7 +35,7 @@ class MLP(eqx.Module):
    ----------
    key : InitVar[Key]
        A jax random key for the layer initializations.
-    eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
+    eqx_list : InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]]
        A tuple of tuples of successive equinox modules and activation functions to
        describe the PINN architecture. The inner tuples must have the eqx module or
        activation function as first item, other items represents arguments
@@ -29,23 +43,23 @@ class MLP(eqx.Module):
        The `key` argument need not be given.
        Thus typical example is `eqx_list=
        ((eqx.nn.Linear, 2, 20),
-        jax.nn.tanh,
+        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 20),
-        jax.nn.tanh,
+        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 20),
-        jax.nn.tanh,
+        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 1)
        )`.
    """
 
    key: InitVar[Key] = eqx.field(kw_only=True)
-    eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] = eqx.field(
-        kw_only=True
+    eqx_list: InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] = (
+        eqx.field(kw_only=True)
    )
 
    # NOTE that the following should NOT be declared as static otherwise the
    # eqx.partition that we use in the PINN module will misbehave
-    layers: list[eqx.Module] = eqx.field(init=False)
+    layers: list[CallableMLPModule | Callable[[Array], Array]] = eqx.field(init=False)
 
    def __post_init__(self, key, eqx_list):
        self.layers = []
@@ -63,7 +77,7 @@ class MLP(eqx.Module):
                self.layers.append(l[0](*l[1:], key=subkey))
                k += 1
 
-    def __call__(self, t: Float[Array, "input_dim"]) -> Float[Array, "output_dim"]:
+    def __call__(self, t: Float[Array, " input_dim"]) -> Float[Array, " output_dim"]:
        for layer in self.layers:
            t = layer(t)
        return t
@@ -81,19 +95,30 @@ class PINN_MLP(PINN):
    def create(
        cls,
        eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
-        eqx_network: eqx.nn.MLP = None,
+        eqx_network: eqx.nn.MLP | MLP | None = None,
        key: Key = None,
-        eqx_list: tuple[tuple[Callable, int, int] | Callable, ...] = None,
-        input_transform: Callable[
-            [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
-        ] = None,
-        output_transform: Callable[
-            [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
-            Float[Array, "output_dim"],
-        ] = None,
-        slice_solution: slice = None,
+        eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
+        input_transform: (
+            Callable[
+                [Float[Array, " input_dim"], Params[Array]],
+                Float[Array, " output_dim"],
+            ]
+            | None
+        ) = None,
+        output_transform: (
+            Callable[
+                [
+                    Float[Array, " input_dim"],
+                    Float[Array, " output_dim"],
+                    Params[Array],
+                ],
+                Float[Array, " output_dim"],
+            ]
+            | None
+        ) = None,
+        slice_solution: slice | None = None,
        filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
-    ) -> tuple[Self, PyTree]:
+    ) -> tuple[Self, PINN]:
        r"""
        Instanciate standard PINN MLP object. The actual NN is either passed as
        a eqx.nn.MLP (`eqx_network` argument) or constructed as a custom
@@ -179,14 +204,14 @@ class PINN_MLP(PINN):
                raise ValueError(
                    "If eqx_network is None, then key and eqx_list must be provided"
                )
-            eqx_network = MLP(key=key, eqx_list=eqx_list)
+            eqx_network = cast(MLP, MLP(key=key, eqx_list=eqx_list))
 
        mlp = cls(
            eqx_network=eqx_network,
-            slice_solution=slice_solution,
+            slice_solution=slice_solution, # type: ignore
            eq_type=eq_type,
-            input_transform=input_transform,
-            output_transform=output_transform,
+            input_transform=input_transform, # type: ignore
+            output_transform=output_transform, # type: ignore
            filter_spec=filter_spec,
        )
        return mlp, mlp.init_params
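The new `tuple[Self, PINN]` return annotation makes explicit that the second element, `mlp.init_params`, is the array-leaf half of the module itself rather than an opaque `PyTree` (see the retyped `init_params` field in `_pinn.py` below). The calling convention is unchanged; a sketch with the new-style `eqx_list`, assuming `PINN_MLP` is re-exported by the new `jinns/nn/__init__.py`:

    import jax
    import equinox as eqx
    from jinns.nn import PINN_MLP

    u, init_nn_params = PINN_MLP.create(
        eq_type="ODE",
        key=jax.random.PRNGKey(0),
        eqx_list=(
            (eqx.nn.Linear, 1, 20),
            (jax.nn.tanh,),  # activations are now 1-tuples
            (eqx.nn.Linear, 20, 1),
        ),
    )
    # init_nn_params is what gets stored as Params.nn_params for training.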
jinns/nn/_pinn.py CHANGED
@@ -2,15 +2,19 @@
 Implement abstract class for PINN architectures
 """
 
-from typing import Literal, Callable, Union, Any
+from __future__ import annotations
+
+from typing import Callable, Union, Any, Literal, overload
 from dataclasses import InitVar
 import equinox as eqx
 from jaxtyping import Float, Array, PyTree
 import jax.numpy as jnp
-from jinns.parameters._params import Params, ParamsDict
+from jinns.parameters._params import Params
+from jinns.nn._abstract_pinn import AbstractPINN
+from jinns.nn._utils import _PyTree_to_Params
 
 
-class PINN(eqx.Module):
+class PINN(AbstractPINN):
    r"""
    Base class for PINN objects. It can be seen as a wrapper on
    an `eqx.Module` which actually implement the NN architectures, with extra
@@ -57,12 +61,12 @@ class PINN(eqx.Module):
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
-    input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
+    input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
        A function that will be called before entering the PINN. Its output(s)
        must match the PINN inputs (except for the parameters).
        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
        and the parameters. Default is no operation.
-    output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
+    output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
        A function with arguments begin the same input as the PINN, the PINN
        output and the parameter. This function will be called after exiting the PINN.
        Default is no operation.
@@ -84,16 +88,16 @@ class PINN(eqx.Module):
        "nonstatio_PDE"]`
    """
 
-    slice_solution: slice = eqx.field(static=True, kw_only=True, default=None)
    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
        static=True, kw_only=True
    )
+    slice_solution: slice = eqx.field(static=True, kw_only=True, default=None)
    input_transform: Callable[
-        [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
+        [Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]
    ] = eqx.field(static=True, kw_only=True, default=None)
    output_transform: Callable[
-        [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
-        Float[Array, "output_dim"],
+        [Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]],
+        Float[Array, " output_dim"],
    ] = eqx.field(static=True, kw_only=True, default=None)
 
    eqx_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
@@ -101,11 +105,10 @@ class PINN(eqx.Module):
        static=True, kw_only=True, default=eqx.is_inexact_array
    )
 
-    init_params: PyTree = eqx.field(init=False)
-    static: PyTree = eqx.field(init=False, static=True)
+    init_params: PINN = eqx.field(init=False)
+    static: PINN = eqx.field(init=False, static=True)
 
    def __post_init__(self, eqx_network):
-
        if self.eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
            raise RuntimeError("Wrong parameter value for eq_type")
        # saving the static part of the model and initial parameters
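The retyped `init_params`/`static` pair is the usual `eqx.partition` split along `eqx.is_inexact_array` (the default `filter_spec` above): trainable array leaves on one side, static structure on the other, both with the `PINN` tree shape, hence the annotation change from `PyTree`. A self-contained sketch of the split/recombine round trip these fields support:

    import jax
    import jax.numpy as jnp
    import equinox as eqx

    net = eqx.nn.MLP("scalar", "scalar", 16, 2, key=jax.random.PRNGKey(0))

    # Same partition/combine that PINN.__post_init__ and PINN.__call__ use.
    params, static = eqx.partition(net, eqx.is_inexact_array)
    rebuilt = eqx.combine(params, static)
    out = rebuilt(jnp.array(0.5))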
@@ -154,18 +157,32 @@ class PINN(eqx.Module):
 
        return network(inputs)
 
+    @overload
+    @_PyTree_to_Params
+    def __call__(
+        self,
+        inputs: Float[Array, " input_dim"],
+        params: PyTree,
+        *args,
+        **kwargs,
+    ) -> Float[Array, " output_dim"]: ...
+
+    @_PyTree_to_Params
    def __call__(
        self,
-        inputs: Float[Array, "input_dim"],
-        params: Params | ParamsDict | PyTree,
+        inputs: Float[Array, " input_dim"],
+        params: Params[Array],
        *args,
        **kwargs,
-    ) -> Float[Array, "output_dim"]:
+    ) -> Float[Array, " output_dim"]:
        """
        A proper __call__ implementation performs an eqx.combine here with
        `params` and `self.static` to recreate the callable eqx.Module
        architecture. The rest of the content of this function is dependent on
        the network.
+
+        Note that that thanks to the decorator, params can also directly be the
+        PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
        """
 
        if len(inputs.shape) == 0:
@@ -174,10 +191,7 @@ class PINN(eqx.Module):
            # DataGenerators)
            inputs = inputs[None]
 
-        try:
-            model = eqx.combine(params.nn_params, self.static)
-        except (KeyError, AttributeError, TypeError) as e: # give more flexibility
-            model = eqx.combine(params, self.static)
+        model = eqx.combine(params.nn_params, self.static)
 
        # evaluate the model
        res = self.eval(model, self.input_transform(inputs, params), *args, **kwargs)
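With the `try`/`except` fallback removed (here and in `_hyperpinn.py` above), the body now unconditionally reads `params.nn_params`; bare PyTrees are handled upstream by the `_PyTree_to_Params` decorator instead. The explicit calling convention, reusing `u` and `init_nn_params` from the `PINN_MLP.create` sketch above ("theta" is a made-up equation parameter name):

    import jax.numpy as jnp
    from jinns.parameters import Params

    params = Params(nn_params=init_nn_params, eq_params={"theta": jnp.array(0.1)})
    out = u(jnp.array([0.5]), params)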