jinns-1.3.0-py3-none-any.whl → jinns-1.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +74 -69
  14. jinns/loss/_LossODE.py +132 -348
  15. jinns/loss/_LossPDE.py +262 -549
  16. jinns/loss/__init__.py +32 -6
  17. jinns/loss/_abstract_loss.py +128 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_components.py +43 -0
  20. jinns/loss/_loss_utils.py +85 -179
  21. jinns/loss/_loss_weight_updates.py +202 -0
  22. jinns/loss/_loss_weights.py +64 -40
  23. jinns/loss/_operators.py +84 -74
  24. jinns/nn/__init__.py +15 -0
  25. jinns/nn/_abstract_pinn.py +22 -0
  26. jinns/nn/_hyperpinn.py +94 -57
  27. jinns/nn/_mlp.py +50 -25
  28. jinns/nn/_pinn.py +33 -19
  29. jinns/nn/_ppinn.py +70 -34
  30. jinns/nn/_save_load.py +21 -51
  31. jinns/nn/_spinn.py +33 -16
  32. jinns/nn/_spinn_mlp.py +28 -22
  33. jinns/nn/_utils.py +38 -0
  34. jinns/parameters/__init__.py +8 -1
  35. jinns/parameters/_derivative_keys.py +116 -177
  36. jinns/parameters/_params.py +18 -46
  37. jinns/plot/__init__.py +2 -0
  38. jinns/plot/_plot.py +35 -34
  39. jinns/solver/_rar.py +80 -63
  40. jinns/solver/_solve.py +207 -92
  41. jinns/solver/_utils.py +4 -6
  42. jinns/utils/__init__.py +2 -0
  43. jinns/utils/_containers.py +16 -10
  44. jinns/utils/_types.py +20 -54
  45. jinns/utils/_utils.py +4 -11
  46. jinns/validation/__init__.py +2 -0
  47. jinns/validation/_validation.py +20 -19
  48. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
  49. jinns-1.5.0.dist-info/RECORD +55 -0
  50. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
  51. jinns/data/_DataGenerators.py +0 -1634
  52. jinns-1.3.0.dist-info/RECORD +0 -44
  53. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
  54. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
  55. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/loss/_loss_weights.py CHANGED
@@ -2,58 +2,82 @@
 Formalize the loss weights data structure
 """
 
-from typing import Dict
-from jaxtyping import Array, Float
+from __future__ import annotations
+from dataclasses import fields
+
+from jaxtyping import Array
+import jax.numpy as jnp
 import equinox as eqx
 
 
-class LossWeightsODE(eqx.Module):
+def lw_converter(x):
+    if x is None:
+        return x
+    else:
+        return jnp.asarray(x)
 
-    dyn_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
-    initial_condition: Array | Float | None = eqx.field(kw_only=True, default=1.0)
-    observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
 
+class AbstractLossWeights(eqx.Module):
+    """
+    An abstract class, currently only useful for type hints
 
-class LossWeightsODEDict(eqx.Module):
+    TODO in the future maybe loss weights could be subclasses of
+    XDEComponentsAbstract?
+    """
 
-    dyn_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=None)
-    initial_condition: Dict[str, Array | Float | None] = eqx.field(
-        kw_only=True, default=None
+    def items(self):
+        """
+        For the dataclass to be iterated like a dictionary.
+        Practical and retrocompatible with old code when loss components were
+        dictionaries
+        """
+        return {
+            field.name: getattr(self, field.name)
+            for field in fields(self)
+            if getattr(self, field.name) is not None
+        }.items()
+
+
+class LossWeightsODE(AbstractLossWeights):
+    dyn_loss: Array | float | None = eqx.field(
+        kw_only=True, default=None, converter=lw_converter
    )
-    observations: Dict[str, Array | Float | None] = eqx.field(
-        kw_only=True, default=None
+    initial_condition: Array | float | None = eqx.field(
+        kw_only=True, default=None, converter=lw_converter
+    )
+    observations: Array | float | None = eqx.field(
+        kw_only=True, default=None, converter=lw_converter
    )
 
 
-class LossWeightsPDEStatio(eqx.Module):
-
-    dyn_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
-    norm_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
-    boundary_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
-    observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
-
-
-class LossWeightsPDENonStatio(eqx.Module):
-
-    dyn_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
-    norm_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
-    boundary_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
-    observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
-    initial_condition: Array | Float | None = eqx.field(kw_only=True, default=1.0)
-
+class LossWeightsPDEStatio(AbstractLossWeights):
+    dyn_loss: Array | float | None = eqx.field(
+        kw_only=True, default=None, converter=lw_converter
+    )
+    norm_loss: Array | float | None = eqx.field(
+        kw_only=True, default=None, converter=lw_converter
+    )
+    boundary_loss: Array | float | None = eqx.field(
+        kw_only=True, default=None, converter=lw_converter
+    )
+    observations: Array | float | None = eqx.field(
+        kw_only=True, default=None, converter=lw_converter
+    )
 
-class LossWeightsPDEDict(eqx.Module):
-    """
-    Only one type of LossWeights data structure for the SystemLossPDE:
-    Include the initial condition always for the code to be more generic
-    """
 
-    dyn_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
-    norm_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
-    boundary_loss: Dict[str, Array | Float | None] = eqx.field(
-        kw_only=True, default=1.0
+class LossWeightsPDENonStatio(AbstractLossWeights):
+    dyn_loss: Array | float | None = eqx.field(
+        kw_only=True, default=None, converter=lw_converter
+    )
+    norm_loss: Array | float | None = eqx.field(
+        kw_only=True, default=None, converter=lw_converter
+    )
+    boundary_loss: Array | float | None = eqx.field(
+        kw_only=True, default=None, converter=lw_converter
+    )
+    observations: Array | float | None = eqx.field(
+        kw_only=True, default=None, converter=lw_converter
    )
-    observations: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
-    initial_condition: Dict[str, Array | Float | None] = eqx.field(
-        kw_only=True, default=1.0
+    initial_condition: Array | float | None = eqx.field(
+        kw_only=True, default=None, converter=lw_converter
    )
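
Note: for callers, the hunk above changes two things. Per-term defaults move from `1.0` to `None`, and `items()` skips unset terms, so only the weights you actually pass are iterated; any numeric weight is converted to a JAX array by `lw_converter`. A minimal usage sketch, assuming `LossWeightsODE` is re-exported from `jinns.loss` (plausible given the enlarged `jinns/loss/__init__.py` above, but not confirmed by this diff):

    import jax.numpy as jnp
    from jinns.loss import LossWeightsODE  # assumed re-export path

    # lw_converter turns both floats into JAX arrays at construction
    weights = LossWeightsODE(dyn_loss=1.0, observations=0.5)

    # items() mirrors the removed dict-based LossWeightsODEDict interface;
    # initial_condition was left at its default None, so it is skipped
    for name, w in weights.items():
        print(name, w)  # dyn_loss 1.0, then observations 0.5
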
jinns/loss/_operators.py CHANGED
@@ -2,22 +2,42 @@
 Implements diverse operators for dynamic losses
 """
 
-from typing import Literal
+from __future__ import (
+    annotations,
+)
+
+from typing import Literal, cast, Callable
 
 import jax
 import jax.numpy as jnp
 from jax import grad
-import equinox as eqx
 from jaxtyping import Float, Array
 from jinns.parameters._params import Params
+from jinns.nn._abstract_pinn import AbstractPINN
+
+
+def _get_eq_type(
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None,
+) -> Literal["nonstatio_PDE", "statio_PDE"]:
+    """
+    But we filter out ODE from eq_type because we only have operators that does
+    not work with ODEs so far
+    """
+    if isinstance(u, AbstractPINN):
+        assert u.eq_type != "ODE", "Cannot compute the operator for ODE PINNs"
+        return u.eq_type
+    if eq_type is None:
+        raise ValueError("eq_type could not be set!")
+    return eq_type
 
 
 def divergence_rev(
-    inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
-    u: eqx.Module,
-    params: Params,
-    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
-) -> float:
+    inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    params: Params[Array],
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+) -> Float[Array, " "]:
     r"""
     Compute the divergence of a vector field $\mathbf{u}$, i.e.,
     $\nabla_\mathbf{x} \cdot \mathbf{u}(\mathrm{inputs})$ with $\mathbf{u}$ a vector
@@ -41,13 +61,7 @@ def divergence_rev(
     can know that by inspecting the `u` argument (PINN object). But if `u` is
     a function, we must set this attribute.
     """
-
-    try:
-        eq_type = u.eq_type
-    except AttributeError:
-        pass  # use the value passed as argument
-    if eq_type is None:
-        raise ValueError("eq_type could not be set!")
+    eq_type = _get_eq_type(u, eq_type)
 
     def scan_fun(_, i):
         if eq_type == "nonstatio_PDE":
@@ -70,11 +84,11 @@
 
 
 def divergence_fwd(
-    inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
-    u: eqx.Module,
-    params: Params,
-    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
-) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
+    inputs: Float[Array, " batch_size dim"] | Float[Array, " batch_size 1+dim"],
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    params: Params[Array],
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
     r"""
     Compute the divergence of a **batched** vector field $\mathbf{u}$, i.e.,
     $\nabla_\mathbf{x} \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
@@ -103,13 +117,7 @@ def divergence_fwd(
     can know that by inspecting the `u` argument (PINN object). But if `u` is
     a function, we must set this attribute.
     """
-
-    try:
-        eq_type = u.eq_type
-    except AttributeError:
-        pass  # use the value passed as argument
-    if eq_type is None:
-        raise ValueError("eq_type could not be set!")
+    eq_type = _get_eq_type(u, eq_type)
 
     def scan_fun(_, i):
         if eq_type == "nonstatio_PDE":
@@ -142,12 +150,12 @@
 
 
 def laplacian_rev(
-    inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
-    u: eqx.Module,
-    params: Params,
+    inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    params: Params[Array],
     method: Literal["trace_hessian_x", "trace_hessian_t_x", "loop"] = "trace_hessian_x",
-    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
-) -> float:
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+) -> Float[Array, " "]:
     r"""
     Compute the Laplacian of a scalar field $u$ from $\mathbb{R}^d$
     to $\mathbb{R}$ or from $\mathbb{R}^{1+d}$ to $\mathbb{R}$, i.e., this
@@ -180,13 +188,7 @@ def laplacian_rev(
     can know that by inspecting the `u` argument (PINN object). But if `u` is
     a function, we must set this attribute.
     """
-
-    try:
-        eq_type = u.eq_type
-    except AttributeError:
-        pass  # use the value passed as argument
-    if eq_type is None:
-        raise ValueError("eq_type could not be set!")
+    eq_type = _get_eq_type(u, eq_type)
 
     if method == "trace_hessian_x":
         # NOTE we afford a concatenate here to avoid computing Hessian elements for
@@ -226,16 +228,12 @@
         if eq_type == "nonstatio_PDE":
             d2u_dxi2 = grad(
                 lambda inputs: grad(u_)(inputs)[1 + i],
-            )(
-                inputs
-            )[1 + i]
+            )(inputs)[1 + i]
         else:
             d2u_dxi2 = grad(
                 lambda inputs: grad(u_, 0)(inputs)[i],
                 0,
-            )(
-                inputs
-            )[i]
+            )(inputs)[i]
         return _, d2u_dxi2
 
     if eq_type == "nonstatio_PDE":
@@ -251,12 +249,12 @@
 
 
 def laplacian_fwd(
-    inputs: Float[Array, "batch_size 1+dim"] | Float[Array, "batch_size dim"],
-    u: eqx.Module,
-    params: Params,
+    inputs: Float[Array, " batch_size 1+dim"] | Float[Array, " batch_size dim"],
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    params: Params[Array],
     method: Literal["trace_hessian_t_x", "trace_hessian_x", "loop"] = "loop",
-    eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
-) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
     r"""
     Compute the Laplacian of a **batched** scalar field $u$
     from $\mathbb{R}^{b\times d}$ to $\mathbb{R}^{b\times b}$ or
@@ -299,13 +297,7 @@ def laplacian_fwd(
     can know that by inspecting the `u` argument (PINN object). But if `u` is
     a function, we must set this attribute.
     """
-
-    try:
-        eq_type = u.eq_type
-    except AttributeError:
-        pass  # use the value passed as argument
-    if eq_type is None:
-        raise ValueError("eq_type could not be set!")
+    eq_type = _get_eq_type(u, eq_type)
 
     if method == "loop":
 
@@ -398,11 +390,12 @@ def laplacian_fwd(
 
 
 def vectorial_laplacian_rev(
-    inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
-    u: eqx.Module,
-    params: Params,
-    dim_out: int = None,
-) -> Float[Array, "dim_out"]:
+    inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    params: Params[Array],
+    dim_out: int | None = None,
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+) -> Float[Array, " dim_out"]:
     r"""
     Compute the vectorial Laplacian of a vector field $\mathbf{u}$ from
     $\mathbb{R}^d$ to $\mathbb{R}^n$ or from $\mathbb{R}^{1+d}$ to
@@ -426,7 +419,12 @@ def vectorial_laplacian_rev(
     dim_out
         Dimension of the vector $\mathbf{u}(\mathrm{inputs})$. This needs to be
        provided if it is different than that of $\mathrm{inputs}$.
+    eq_type
+        whether we consider a stationary or non stationary PINN. Most often we
+        can know that by inspecting the `u` argument (PINN object). But if `u` is
+        a function, we must set this attribute.
     """
+    eq_type = _get_eq_type(u, eq_type)
     if dim_out is None:
         dim_out = inputs.shape[0]
 
@@ -435,7 +433,9 @@
         # each of these components
         # Note the jnp.expand_dims call
         uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[j], axis=-1)
-        lap_on_j = laplacian_rev(inputs, uj, params, eq_type=u.eq_type)
+        lap_on_j = laplacian_rev(
+            inputs, cast(AbstractPINN, uj), params, eq_type=eq_type
+        )
 
         return _, lap_on_j
 
@@ -444,11 +444,12 @@
 
 
 def vectorial_laplacian_fwd(
-    inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
-    u: eqx.Module,
-    params: Params,
-    dim_out: int = None,
-) -> Float[Array, "batch_size * (1+dim) n"] | Float[Array, "batch_size * (dim) n"]:
+    inputs: Float[Array, " batch_size dim"] | Float[Array, " batch_size 1+dim"],
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    params: Params[Array],
+    dim_out: int | None = None,
+    eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+) -> Float[Array, " batch_size * (1+dim) n"] | Float[Array, " batch_size * (dim) n"]:
     r"""
     Compute the vectorial Laplacian of a vector field $\mathbf{u}$ when
     `u` is a SPINN, in this case, it corresponds to a vector
@@ -474,7 +475,12 @@ def vectorial_laplacian_fwd(
     dim_out
        the value of the output dimension ($n$ in the formula above). Must be
        set if different from $d$.
+    eq_type
+        whether we consider a stationary or non stationary PINN. Most often we
+        can know that by inspecting the `u` argument (PINN object). But if `u` is
+        a function, we must set this attribute.
     """
+    eq_type = _get_eq_type(u, eq_type)
     if dim_out is None:
         dim_out = inputs.shape[0]
 
@@ -483,7 +489,9 @@
         # each of these components
         # Note the expand_dims
         uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[..., j], axis=-1)
-        lap_on_j = laplacian_fwd(inputs, uj, params, eq_type=u.eq_type)
+        lap_on_j = laplacian_fwd(
+            inputs, cast(AbstractPINN, uj), params, eq_type=eq_type
+        )
 
         return _, lap_on_j
 
@@ -492,8 +500,10 @@ def vectorial_laplacian_fwd(
 
 
 def _u_dot_nabla_times_u_rev(
-    x: Float[Array, "2"], u: eqx.Module, params: Params
-) -> Float[Array, "2"]:
+    x: Float[Array, " 2"],
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    params: Params[Array],
+) -> Float[Array, " 2"]:
     r"""
     Implement $((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})$ for
     $\mathbf{x}$ of arbitrary
@@ -522,10 +532,10 @@
 
 
 def _u_dot_nabla_times_u_fwd(
-    x: Float[Array, "batch_size 2"],
-    u: eqx.Module,
-    params: Params,
-) -> Float[Array, "batch_size batch_size 2"]:
+    x: Float[Array, " batch_size 2"],
+    u: AbstractPINN | Callable[[Array, Params[Array]], Array],
+    params: Params[Array],
+) -> Float[Array, " batch_size batch_size 2"]:
     r"""
     Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
     :math:`\mathbf{x}` of arbitrary dimension **with a batch dimension**.
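
Note: the practical consequence of centralizing the old try/except into `_get_eq_type` is stricter dispatch: an `AbstractPINN` supplies its own `eq_type` (and ODE PINNs are now rejected by an assert), while a plain callable must be given `eq_type` explicitly or a `ValueError` is raised. A sketch of the callable case (the scalar field and the `jinns.loss` import path are illustrative, not taken from this diff):

    import jax.numpy as jnp
    from jinns.loss import laplacian_rev  # assumed re-export path

    # a plain callable has no eq_type attribute, so the operator cannot
    # infer stationarity; params is only forwarded to u and is unused here
    u = lambda inputs, params: jnp.sum(inputs**2)
    x = jnp.array([1.0, 2.0])

    lap = laplacian_rev(x, u, params=None, eq_type="statio_PDE")
    # the Laplacian of |x|^2 in d=2 is 2*d, i.e. about 4.0

    # laplacian_rev(x, u, params=None)  # would raise: "eq_type could not be set!"
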
jinns/nn/__init__.py CHANGED
@@ -1,7 +1,22 @@
 from ._save_load import save_pinn, load_pinn
+from ._abstract_pinn import AbstractPINN
 from ._pinn import PINN
 from ._spinn import SPINN
 from ._mlp import PINN_MLP, MLP
 from ._spinn_mlp import SPINN_MLP, SMLP
 from ._hyperpinn import HyperPINN
 from ._ppinn import PPINN_MLP
+
+__all__ = [
+    "save_pinn",
+    "load_pinn",
+    "AbstractPINN",
+    "PINN",
+    "SPINN",
+    "PINN_MLP",
+    "MLP",
+    "SPINN_MLP",
+    "SMLP",
+    "HyperPINN",
+    "PPINN_MLP",
+]
jinns/nn/_abstract_pinn.py ADDED
@@ -0,0 +1,22 @@
+import abc
+from typing import Literal, Any
+from jaxtyping import Array
+import equinox as eqx
+
+from jinns.nn._utils import _PyTree_to_Params
+from jinns.parameters._params import Params
+
+
+class AbstractPINN(eqx.Module):
+    """
+    Basically just a way to add a __call__ to an eqx.Module.
+    The way to go for correct type hints apparently
+    https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
+    """
+
+    eq_type: eqx.AbstractVar[Literal["ODE", "statio_PDE", "nonstatio_PDE"]]
+
+    @abc.abstractmethod
+    @_PyTree_to_Params
+    def __call__(self, inputs: Any, params: Params[Array], *args, **kwargs) -> Any:
+        pass
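
Note: with `AbstractPINN` now public (see the `jinns/nn/__init__.py` hunk above), custom networks can participate in the operators' `isinstance(u, AbstractPINN)` dispatch by subclassing it and giving the abstract `eq_type` field a concrete value. A minimal sketch of a conforming subclass; the static field marking and the trivial forward pass are illustrative choices, not jinns API:

    from typing import Literal
    import equinox as eqx
    import jax.numpy as jnp
    from jaxtyping import Array
    from jinns.nn import AbstractPINN
    from jinns.parameters import Params  # assumed public path

    class TinyPINN(AbstractPINN):
        # concrete value for the eqx.AbstractVar declared on AbstractPINN;
        # marked static so the string is not treated as a pytree leaf
        eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
            static=True, default="statio_PDE"
        )

        def __call__(self, inputs: Array, params: Params[Array], *args, **kwargs) -> Array:
            # placeholder forward pass, just to make the class instantiable
            return jnp.tanh(inputs)

Since `eq_type` is `"statio_PDE"` here, an instance of such a subclass would dispatch through `_get_eq_type` without needing the explicit `eq_type` keyword.
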