jinns 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +146 -520
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_utils.py +78 -159
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -74
  22. jinns/nn/__init__.py +15 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +94 -57
  25. jinns/nn/_mlp.py +50 -25
  26. jinns/nn/_pinn.py +33 -19
  27. jinns/nn/_ppinn.py +70 -34
  28. jinns/nn/_save_load.py +21 -51
  29. jinns/nn/_spinn.py +33 -16
  30. jinns/nn/_spinn_mlp.py +28 -22
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +35 -34
  37. jinns/solver/_rar.py +80 -63
  38. jinns/solver/_solve.py +89 -63
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -0
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns-1.3.0.dist-info/RECORD +0 -44
  51. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  52. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  53. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/loss/_operators.py CHANGED
@@ -2,22 +2,42 @@
2
2
  Implements diverse operators for dynamic losses
3
3
  """
4
4
 
5
- from typing import Literal
5
+ from __future__ import (
6
+ annotations,
7
+ )
8
+
9
+ from typing import Literal, cast, Callable
6
10
 
7
11
  import jax
8
12
  import jax.numpy as jnp
9
13
  from jax import grad
10
- import equinox as eqx
11
14
  from jaxtyping import Float, Array
12
15
  from jinns.parameters._params import Params
16
+ from jinns.nn._abstract_pinn import AbstractPINN
17
+
18
+
19
+ def _get_eq_type(
20
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
21
+ eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None,
22
+ ) -> Literal["nonstatio_PDE", "statio_PDE"]:
23
+ """
24
+ But we filter out ODE from eq_type because we only have operators that does
25
+ not work with ODEs so far
26
+ """
27
+ if isinstance(u, AbstractPINN):
28
+ assert u.eq_type != "ODE", "Cannot compute the operator for ODE PINNs"
29
+ return u.eq_type
30
+ if eq_type is None:
31
+ raise ValueError("eq_type could not be set!")
32
+ return eq_type
13
33
 
14
34
 
15
35
  def divergence_rev(
16
- inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
17
- u: eqx.Module,
18
- params: Params,
19
- eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
20
- ) -> float:
36
+ inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
37
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
38
+ params: Params[Array],
39
+ eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
40
+ ) -> Float[Array, " "]:
21
41
  r"""
22
42
  Compute the divergence of a vector field $\mathbf{u}$, i.e.,
23
43
  $\nabla_\mathbf{x} \cdot \mathbf{u}(\mathrm{inputs})$ with $\mathbf{u}$ a vector
@@ -41,13 +61,7 @@ def divergence_rev(
41
61
  can know that by inspecting the `u` argument (PINN object). But if `u` is
42
62
  a function, we must set this attribute.
43
63
  """
44
-
45
- try:
46
- eq_type = u.eq_type
47
- except AttributeError:
48
- pass # use the value passed as argument
49
- if eq_type is None:
50
- raise ValueError("eq_type could not be set!")
64
+ eq_type = _get_eq_type(u, eq_type)
51
65
 
52
66
  def scan_fun(_, i):
53
67
  if eq_type == "nonstatio_PDE":
@@ -70,11 +84,11 @@ def divergence_rev(
70
84
 
71
85
 
72
86
  def divergence_fwd(
73
- inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
74
- u: eqx.Module,
75
- params: Params,
76
- eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
77
- ) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
87
+ inputs: Float[Array, " batch_size dim"] | Float[Array, " batch_size 1+dim"],
88
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
89
+ params: Params[Array],
90
+ eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
91
+ ) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
78
92
  r"""
79
93
  Compute the divergence of a **batched** vector field $\mathbf{u}$, i.e.,
80
94
  $\nabla_\mathbf{x} \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
@@ -103,13 +117,7 @@ def divergence_fwd(
103
117
  can know that by inspecting the `u` argument (PINN object). But if `u` is
104
118
  a function, we must set this attribute.
105
119
  """
106
-
107
- try:
108
- eq_type = u.eq_type
109
- except AttributeError:
110
- pass # use the value passed as argument
111
- if eq_type is None:
112
- raise ValueError("eq_type could not be set!")
120
+ eq_type = _get_eq_type(u, eq_type)
113
121
 
114
122
  def scan_fun(_, i):
115
123
  if eq_type == "nonstatio_PDE":
@@ -142,12 +150,12 @@ def divergence_fwd(
142
150
 
143
151
 
144
152
  def laplacian_rev(
145
- inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
146
- u: eqx.Module,
147
- params: Params,
153
+ inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
154
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
155
+ params: Params[Array],
148
156
  method: Literal["trace_hessian_x", "trace_hessian_t_x", "loop"] = "trace_hessian_x",
149
- eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
150
- ) -> float:
157
+ eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
158
+ ) -> Float[Array, " "]:
151
159
  r"""
152
160
  Compute the Laplacian of a scalar field $u$ from $\mathbb{R}^d$
153
161
  to $\mathbb{R}$ or from $\mathbb{R}^{1+d}$ to $\mathbb{R}$, i.e., this
@@ -180,13 +188,7 @@ def laplacian_rev(
180
188
  can know that by inspecting the `u` argument (PINN object). But if `u` is
181
189
  a function, we must set this attribute.
182
190
  """
183
-
184
- try:
185
- eq_type = u.eq_type
186
- except AttributeError:
187
- pass # use the value passed as argument
188
- if eq_type is None:
189
- raise ValueError("eq_type could not be set!")
191
+ eq_type = _get_eq_type(u, eq_type)
190
192
 
191
193
  if method == "trace_hessian_x":
192
194
  # NOTE we afford a concatenate here to avoid computing Hessian elements for
@@ -226,16 +228,12 @@ def laplacian_rev(
226
228
  if eq_type == "nonstatio_PDE":
227
229
  d2u_dxi2 = grad(
228
230
  lambda inputs: grad(u_)(inputs)[1 + i],
229
- )(
230
- inputs
231
- )[1 + i]
231
+ )(inputs)[1 + i]
232
232
  else:
233
233
  d2u_dxi2 = grad(
234
234
  lambda inputs: grad(u_, 0)(inputs)[i],
235
235
  0,
236
- )(
237
- inputs
238
- )[i]
236
+ )(inputs)[i]
239
237
  return _, d2u_dxi2
240
238
 
241
239
  if eq_type == "nonstatio_PDE":
@@ -251,12 +249,12 @@ def laplacian_rev(
251
249
 
252
250
 
253
251
  def laplacian_fwd(
254
- inputs: Float[Array, "batch_size 1+dim"] | Float[Array, "batch_size dim"],
255
- u: eqx.Module,
256
- params: Params,
252
+ inputs: Float[Array, " batch_size 1+dim"] | Float[Array, " batch_size dim"],
253
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
254
+ params: Params[Array],
257
255
  method: Literal["trace_hessian_t_x", "trace_hessian_x", "loop"] = "loop",
258
- eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
259
- ) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
256
+ eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
257
+ ) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
260
258
  r"""
261
259
  Compute the Laplacian of a **batched** scalar field $u$
262
260
  from $\mathbb{R}^{b\times d}$ to $\mathbb{R}^{b\times b}$ or
@@ -299,13 +297,7 @@ def laplacian_fwd(
299
297
  can know that by inspecting the `u` argument (PINN object). But if `u` is
300
298
  a function, we must set this attribute.
301
299
  """
302
-
303
- try:
304
- eq_type = u.eq_type
305
- except AttributeError:
306
- pass # use the value passed as argument
307
- if eq_type is None:
308
- raise ValueError("eq_type could not be set!")
300
+ eq_type = _get_eq_type(u, eq_type)
309
301
 
310
302
  if method == "loop":
311
303
 
@@ -398,11 +390,12 @@ def laplacian_fwd(
398
390
 
399
391
 
400
392
  def vectorial_laplacian_rev(
401
- inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
402
- u: eqx.Module,
403
- params: Params,
404
- dim_out: int = None,
405
- ) -> Float[Array, "dim_out"]:
393
+ inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
394
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
395
+ params: Params[Array],
396
+ dim_out: int | None = None,
397
+ eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
398
+ ) -> Float[Array, " dim_out"]:
406
399
  r"""
407
400
  Compute the vectorial Laplacian of a vector field $\mathbf{u}$ from
408
401
  $\mathbb{R}^d$ to $\mathbb{R}^n$ or from $\mathbb{R}^{1+d}$ to
@@ -426,7 +419,12 @@ def vectorial_laplacian_rev(
426
419
  dim_out
427
420
  Dimension of the vector $\mathbf{u}(\mathrm{inputs})$. This needs to be
428
421
  provided if it is different than that of $\mathrm{inputs}$.
422
+ eq_type
423
+ whether we consider a stationary or non stationary PINN. Most often we
424
+ can know that by inspecting the `u` argument (PINN object). But if `u` is
425
+ a function, we must set this attribute.
429
426
  """
427
+ eq_type = _get_eq_type(u, eq_type)
430
428
  if dim_out is None:
431
429
  dim_out = inputs.shape[0]
432
430
 
@@ -435,7 +433,9 @@ def vectorial_laplacian_rev(
435
433
  # each of these components
436
434
  # Note the jnp.expand_dims call
437
435
  uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[j], axis=-1)
438
- lap_on_j = laplacian_rev(inputs, uj, params, eq_type=u.eq_type)
436
+ lap_on_j = laplacian_rev(
437
+ inputs, cast(AbstractPINN, uj), params, eq_type=eq_type
438
+ )
439
439
 
440
440
  return _, lap_on_j
441
441
 
@@ -444,11 +444,12 @@ def vectorial_laplacian_rev(
444
444
 
445
445
 
446
446
  def vectorial_laplacian_fwd(
447
- inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
448
- u: eqx.Module,
449
- params: Params,
450
- dim_out: int = None,
451
- ) -> Float[Array, "batch_size * (1+dim) n"] | Float[Array, "batch_size * (dim) n"]:
447
+ inputs: Float[Array, " batch_size dim"] | Float[Array, " batch_size 1+dim"],
448
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
449
+ params: Params[Array],
450
+ dim_out: int | None = None,
451
+ eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
452
+ ) -> Float[Array, " batch_size * (1+dim) n"] | Float[Array, " batch_size * (dim) n"]:
452
453
  r"""
453
454
  Compute the vectorial Laplacian of a vector field $\mathbf{u}$ when
454
455
  `u` is a SPINN, in this case, it corresponds to a vector
@@ -474,7 +475,12 @@ def vectorial_laplacian_fwd(
474
475
  dim_out
475
476
  the value of the output dimension ($n$ in the formula above). Must be
476
477
  set if different from $d$.
478
+ eq_type
479
+ whether we consider a stationary or non stationary PINN. Most often we
480
+ can know that by inspecting the `u` argument (PINN object). But if `u` is
481
+ a function, we must set this attribute.
477
482
  """
483
+ eq_type = _get_eq_type(u, eq_type)
478
484
  if dim_out is None:
479
485
  dim_out = inputs.shape[0]
480
486
 
@@ -483,7 +489,9 @@ def vectorial_laplacian_fwd(
483
489
  # each of these components
484
490
  # Note the expand_dims
485
491
  uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[..., j], axis=-1)
486
- lap_on_j = laplacian_fwd(inputs, uj, params, eq_type=u.eq_type)
492
+ lap_on_j = laplacian_fwd(
493
+ inputs, cast(AbstractPINN, uj), params, eq_type=eq_type
494
+ )
487
495
 
488
496
  return _, lap_on_j
489
497
 
@@ -492,8 +500,10 @@ def vectorial_laplacian_fwd(
492
500
 
493
501
 
494
502
  def _u_dot_nabla_times_u_rev(
495
- x: Float[Array, "2"], u: eqx.Module, params: Params
496
- ) -> Float[Array, "2"]:
503
+ x: Float[Array, " 2"],
504
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
505
+ params: Params[Array],
506
+ ) -> Float[Array, " 2"]:
497
507
  r"""
498
508
  Implement $((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})$ for
499
509
  $\mathbf{x}$ of arbitrary
@@ -522,10 +532,10 @@ def _u_dot_nabla_times_u_rev(
522
532
 
523
533
 
524
534
  def _u_dot_nabla_times_u_fwd(
525
- x: Float[Array, "batch_size 2"],
526
- u: eqx.Module,
527
- params: Params,
528
- ) -> Float[Array, "batch_size batch_size 2"]:
535
+ x: Float[Array, " batch_size 2"],
536
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
537
+ params: Params[Array],
538
+ ) -> Float[Array, " batch_size batch_size 2"]:
529
539
  r"""
530
540
  Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
531
541
  :math:`\mathbf{x}` of arbitrary dimension **with a batch dimension**.
jinns/nn/__init__.py CHANGED
@@ -1,7 +1,22 @@
1
1
  from ._save_load import save_pinn, load_pinn
2
+ from ._abstract_pinn import AbstractPINN
2
3
  from ._pinn import PINN
3
4
  from ._spinn import SPINN
4
5
  from ._mlp import PINN_MLP, MLP
5
6
  from ._spinn_mlp import SPINN_MLP, SMLP
6
7
  from ._hyperpinn import HyperPINN
7
8
  from ._ppinn import PPINN_MLP
9
+
10
+ __all__ = [
11
+ "save_pinn",
12
+ "load_pinn",
13
+ "AbstractPINN",
14
+ "PINN",
15
+ "SPINN",
16
+ "PINN_MLP",
17
+ "MLP",
18
+ "SPINN_MLP",
19
+ "SMLP",
20
+ "HyperPINN",
21
+ "PPINN_MLP",
22
+ ]
@@ -0,0 +1,22 @@
1
+ import abc
2
+ from typing import Literal, Any
3
+ from jaxtyping import Array
4
+ import equinox as eqx
5
+
6
+ from jinns.nn._utils import _PyTree_to_Params
7
+ from jinns.parameters._params import Params
8
+
9
+
10
+ class AbstractPINN(eqx.Module):
11
+ """
12
+ Basically just a way to add a __call__ to an eqx.Module.
13
+ The way to go for correct type hints apparently
14
+ https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
15
+ """
16
+
17
+ eq_type: eqx.AbstractVar[Literal["ODE", "statio_PDE", "nonstatio_PDE"]]
18
+
19
+ @abc.abstractmethod
20
+ @_PyTree_to_Params
21
+ def __call__(self, inputs: Any, params: Params[Array], *args, **kwargs) -> Any:
22
+ pass
jinns/nn/_hyperpinn.py CHANGED
@@ -3,9 +3,11 @@ Implements utility function to create HyperPINNs
3
3
  https://arxiv.org/pdf/2111.01008.pdf
4
4
  """
5
5
 
6
+ from __future__ import annotations
7
+
6
8
  import warnings
7
9
  from dataclasses import InitVar
8
- from typing import Callable, Literal, Self, Union, Any
10
+ from typing import Callable, Literal, Self, Union, Any, cast, overload
9
11
  from math import prod
10
12
  import jax
11
13
  import jax.numpy as jnp
@@ -15,12 +17,13 @@ import numpy as onp
15
17
 
16
18
  from jinns.nn._pinn import PINN
17
19
  from jinns.nn._mlp import MLP
18
- from jinns.parameters._params import Params, ParamsDict
20
+ from jinns.parameters._params import Params
21
+ from jinns.nn._utils import _PyTree_to_Params
19
22
 
20
23
 
21
24
  def _get_param_nb(
22
- params: Params,
23
- ) -> tuple[int, list]:
25
+ params: PyTree[Array],
26
+ ) -> tuple[int, list[int]]:
24
27
  """Returns the number of parameters in a Params object and also
25
28
  the cumulative sum when parsing the object.
26
29
 
@@ -48,7 +51,7 @@ class HyperPINN(PINN):
48
51
 
49
52
  Parameters
50
53
  ----------
51
- hyperparams: list = eqx.field(static=True)
54
+ hyperparams: list[str] = eqx.field(static=True)
52
55
  A list of keys from Params.eq_params that will be considered as
53
56
  hyperparameters for metamodeling.
54
57
  hypernet_input_size: int
@@ -72,12 +75,12 @@ class HyperPINN(PINN):
72
75
  **Note**: the input dimension as given in eqx_list has to match the sum
73
76
  of the dimension of `t` + the dimension of `x` or the output dimension
74
77
  after the `input_transform` function
75
- input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
78
+ input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
76
79
  A function that will be called before entering the PINN. Its output(s)
77
80
  must match the PINN inputs (except for the parameters).
78
81
  Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
79
82
  and the parameters. Default is no operation.
80
- output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
83
+ output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
81
84
  A function with arguments begin the same input as the PINN, the PINN
82
85
  output and the parameter. This function will be called after exiting the PINN.
83
86
  Default is no operation.
@@ -100,10 +103,10 @@ class HyperPINN(PINN):
100
103
  eqx_hyper_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
101
104
 
102
105
  pinn_params_sum: int = eqx.field(init=False, static=True)
103
- pinn_params_cumsum: list = eqx.field(init=False, static=True)
106
+ pinn_params_cumsum: list[int] = eqx.field(init=False, static=True)
104
107
 
105
- init_params_hyper: PyTree = eqx.field(init=False)
106
- static_hyper: PyTree = eqx.field(init=False, static=True)
108
+ init_params_hyper: HyperPINN = eqx.field(init=False)
109
+ static_hyper: HyperPINN = eqx.field(init=False, static=True)
107
110
 
108
111
  def __post_init__(self, eqx_network, eqx_hyper_network):
109
112
  super().__post_init__(
@@ -115,7 +118,7 @@ class HyperPINN(PINN):
115
118
  )
116
119
  self.pinn_params_sum, self.pinn_params_cumsum = _get_param_nb(self.init_params)
117
120
 
118
- def _hyper_to_pinn(self, hyper_output: Float[Array, "output_dim"]) -> PyTree:
121
+ def _hyper_to_pinn(self, hyper_output: Float[Array, " output_dim"]) -> PINN:
119
122
  """
120
123
  From the output of the hypernetwork, transform to a well formed
121
124
  parameters for the pinn network (i.e. with the same PyTree structure as
@@ -142,15 +145,29 @@ class HyperPINN(PINN):
142
145
  is_leaf=lambda x: isinstance(x, jnp.ndarray),
143
146
  )
144
147
 
148
+ @overload
149
+ @_PyTree_to_Params
145
150
  def __call__(
146
151
  self,
147
- inputs: Float[Array, "input_dim"],
148
- params: Params | ParamsDict | PyTree,
152
+ inputs: Float[Array, " input_dim"],
153
+ params: PyTree,
149
154
  *args,
150
155
  **kwargs,
151
- ) -> Float[Array, "output_dim"]:
156
+ ) -> Float[Array, " output_dim"]: ...
157
+
158
+ @_PyTree_to_Params
159
+ def __call__(
160
+ self,
161
+ inputs: Float[Array, " input_dim"],
162
+ params: Params[Array],
163
+ *args,
164
+ **kwargs,
165
+ ) -> Float[Array, " output_dim"]:
152
166
  """
153
167
  Evaluate the HyperPINN on some inputs with some params.
168
+
169
+ Note that that thanks to the decorator, params can also directly be the
170
+ PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
154
171
  """
155
172
  if len(inputs.shape) == 0:
156
173
  # This can happen often when the user directly provides some
@@ -158,16 +175,17 @@ class HyperPINN(PINN):
158
175
  # DataGenerators)
159
176
  inputs = inputs[None]
160
177
 
161
- try:
162
- hyper = eqx.combine(params.nn_params, self.static_hyper)
163
- except (KeyError, AttributeError, TypeError) as e: # give more flexibility
164
- hyper = eqx.combine(params, self.static_hyper)
178
+ # try:
179
+ hyper = eqx.combine(params.nn_params, self.static_hyper)
180
+ # except (KeyError, AttributeError, TypeError) as e: # give more flexibility
181
+ # hyper = eqx.combine(params, self.static_hyper)
165
182
 
166
183
  eq_params_batch = jnp.concatenate(
167
- [params.eq_params[k].flatten() for k in self.hyperparams], axis=0
184
+ [params.eq_params[k].flatten() for k in self.hyperparams],
185
+ axis=0,
168
186
  )
169
187
 
170
- hyper_output = hyper(eq_params_batch)
188
+ hyper_output = hyper(eq_params_batch) # type: ignore
171
189
 
172
190
  pinn_params = self._hyper_to_pinn(hyper_output)
173
191
 
@@ -187,21 +205,34 @@ class HyperPINN(PINN):
187
205
  eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
188
206
  hyperparams: list[str],
189
207
  hypernet_input_size: int,
190
- eqx_network: eqx.nn.MLP = None,
191
- eqx_hyper_network: eqx.nn.MLP = None,
208
+ eqx_network: eqx.nn.MLP | MLP | None = None,
209
+ eqx_hyper_network: eqx.nn.MLP | MLP | None = None,
192
210
  key: Key = None,
193
- eqx_list: tuple[tuple[Callable, int, int] | Callable, ...] = None,
194
- eqx_list_hyper: tuple[tuple[Callable, int, int] | Callable, ...] = None,
195
- input_transform: Callable[
196
- [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
197
- ] = None,
198
- output_transform: Callable[
199
- [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
200
- Float[Array, "output_dim"],
201
- ] = None,
202
- slice_solution: slice = None,
211
+ eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
212
+ eqx_list_hyper: (
213
+ tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None
214
+ ) = None,
215
+ input_transform: (
216
+ Callable[
217
+ [Float[Array, " input_dim"], Params[Array]],
218
+ Float[Array, " output_dim"],
219
+ ]
220
+ | None
221
+ ) = None,
222
+ output_transform: (
223
+ Callable[
224
+ [
225
+ Float[Array, " input_dim"],
226
+ Float[Array, " output_dim"],
227
+ Params[Array],
228
+ ],
229
+ Float[Array, " output_dim"],
230
+ ]
231
+ | None
232
+ ) = None,
233
+ slice_solution: slice | None = None,
203
234
  filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
204
- ) -> tuple[Self, PyTree]:
235
+ ) -> tuple[Self, HyperPINN]:
205
236
  r"""
206
237
  Utility function to create a standard PINN neural network with the equinox
207
238
  library.
@@ -250,11 +281,11 @@ class HyperPINN(PINN):
250
281
  The `key` argument need not be given.
251
282
  Thus typical example is `eqx_list=
252
283
  ((eqx.nn.Linear, 2, 20),
253
- jax.nn.tanh,
284
+ (jax.nn.tanh,),
254
285
  (eqx.nn.Linear, 20, 20),
255
- jax.nn.tanh,
286
+ (jax.nn.tanh,),
256
287
  (eqx.nn.Linear, 20, 20),
257
- jax.nn.tanh,
288
+ (jax.nn.tanh,),
258
289
  (eqx.nn.Linear, 20, 1)
259
290
  )`.
260
291
  eqx_list_hyper
@@ -268,11 +299,11 @@ class HyperPINN(PINN):
268
299
  The `key` argument need not be given.
269
300
  Thus typical example is `eqx_list=
270
301
  ((eqx.nn.Linear, 2, 20),
271
- jax.nn.tanh,
302
+ (jax.nn.tanh,),
272
303
  (eqx.nn.Linear, 20, 20),
273
- jax.nn.tanh,
304
+ (jax.nn.tanh,),
274
305
  (eqx.nn.Linear, 20, 20),
275
- jax.nn.tanh,
306
+ (jax.nn.tanh,),
276
307
  (eqx.nn.Linear, 20, 1)
277
308
  )`.
278
309
  input_transform
@@ -343,10 +374,13 @@ class HyperPINN(PINN):
343
374
  (eqx_list_hyper[-1][:2] + (pinn_params_sum,)),
344
375
  )
345
376
  else:
346
- eqx_list_hyper = (
347
- eqx_list_hyper[:-2]
348
- + ((eqx_list_hyper[-2][:2] + (pinn_params_sum,)),)
349
- + eqx_list_hyper[-1]
377
+ eqx_list_hyper = cast(
378
+ tuple[tuple[Callable, int, int] | tuple[Callable], ...],
379
+ (
380
+ eqx_list_hyper[:-2]
381
+ + ((eqx_list_hyper[-2][:2] + (pinn_params_sum,)),)
382
+ + eqx_list_hyper[-1]
383
+ ),
350
384
  )
351
385
  if len(eqx_list_hyper[0]) > 1:
352
386
  eqx_list_hyper = (
@@ -357,21 +391,24 @@ class HyperPINN(PINN):
357
391
  ),
358
392
  ) + eqx_list_hyper[1:]
359
393
  else:
360
- eqx_list_hyper = (
361
- eqx_list_hyper[0]
362
- + (
363
- (
364
- (eqx_list_hyper[1][0],)
365
- + (hypernet_input_size,)
366
- + (eqx_list_hyper[1][2],)
367
- ),
368
- )
369
- + eqx_list_hyper[2:]
394
+ eqx_list_hyper = cast(
395
+ tuple[tuple[Callable, int, int] | tuple[Callable], ...],
396
+ (
397
+ eqx_list_hyper[0]
398
+ + (
399
+ (
400
+ (eqx_list_hyper[1][0],)
401
+ + (hypernet_input_size,)
402
+ + (eqx_list_hyper[1][2],) # type: ignore because we suppose that the second element of tuple is nec.of length > 1 since we expect smth like eqx.nn.Linear
403
+ ),
404
+ )
405
+ + eqx_list_hyper[2:]
406
+ ),
370
407
  )
371
408
  key, subkey = jax.random.split(key, 2)
372
409
  # with warnings.catch_warnings():
373
410
  # warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
374
- eqx_hyper_network = MLP(key=subkey, eqx_list=eqx_list_hyper)
411
+ eqx_hyper_network = cast(MLP, MLP(key=subkey, eqx_list=eqx_list_hyper))
375
412
 
376
413
  ### End of finetuning the hypernetwork architecture
377
414
 
@@ -386,10 +423,10 @@ class HyperPINN(PINN):
386
423
  hyperpinn = cls(
387
424
  eqx_network=eqx_network,
388
425
  eqx_hyper_network=eqx_hyper_network,
389
- slice_solution=slice_solution,
426
+ slice_solution=slice_solution, # type: ignore
390
427
  eq_type=eq_type,
391
- input_transform=input_transform,
392
- output_transform=output_transform,
428
+ input_transform=input_transform, # type: ignore
429
+ output_transform=output_transform, # type: ignore
393
430
  hyperparams=hyperparams,
394
431
  hypernet_input_size=hypernet_input_size,
395
432
  filter_spec=filter_spec,