jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +116 -189
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +176 -513
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +22 -21
  19. jinns/loss/_loss_utils.py +98 -173
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -76
  22. jinns/nn/__init__.py +22 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +434 -0
  25. jinns/nn/_mlp.py +217 -0
  26. jinns/nn/_pinn.py +204 -0
  27. jinns/nn/_ppinn.py +239 -0
  28. jinns/{utils → nn}/_save_load.py +39 -53
  29. jinns/nn/_spinn.py +123 -0
  30. jinns/nn/_spinn_mlp.py +202 -0
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +38 -37
  37. jinns/solver/_rar.py +82 -65
  38. jinns/solver/_solve.py +111 -71
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -5
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns/utils/_hyperpinn.py +0 -420
  51. jinns/utils/_pinn.py +0 -324
  52. jinns/utils/_ppinn.py +0 -227
  53. jinns/utils/_spinn.py +0 -249
  54. jinns-1.2.0.dist-info/RECORD +0 -41
  55. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  56. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  57. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/loss/_operators.py CHANGED
@@ -2,24 +2,42 @@
2
2
  Implements diverse operators for dynamic losses
3
3
  """
4
4
 
5
- from typing import Literal
5
+ from __future__ import (
6
+ annotations,
7
+ )
8
+
9
+ from typing import Literal, cast, Callable
6
10
 
7
11
  import jax
8
12
  import jax.numpy as jnp
9
13
  from jax import grad
10
- import equinox as eqx
11
14
  from jaxtyping import Float, Array
12
- from jinns.utils._pinn import PINN
13
- from jinns.utils._spinn import SPINN
14
15
  from jinns.parameters._params import Params
16
+ from jinns.nn._abstract_pinn import AbstractPINN
17
+
18
+
19
+ def _get_eq_type(
20
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
21
+ eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None,
22
+ ) -> Literal["nonstatio_PDE", "statio_PDE"]:
23
+ """
24
+ But we filter out ODE from eq_type because we only have operators that do
25
+ not work with ODEs so far
26
+ """
27
+ if isinstance(u, AbstractPINN):
28
+ assert u.eq_type != "ODE", "Cannot compute the operator for ODE PINNs"
29
+ return u.eq_type
30
+ if eq_type is None:
31
+ raise ValueError("eq_type could not be set!")
32
+ return eq_type
15
33
 
16
34
 
17
35
  def divergence_rev(
18
- inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
19
- u: eqx.Module,
20
- params: Params,
21
- eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
22
- ) -> float:
36
+ inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
37
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
38
+ params: Params[Array],
39
+ eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
40
+ ) -> Float[Array, " "]:
23
41
  r"""
24
42
  Compute the divergence of a vector field $\mathbf{u}$, i.e.,
25
43
  $\nabla_\mathbf{x} \cdot \mathbf{u}(\mathrm{inputs})$ with $\mathbf{u}$ a vector
@@ -43,13 +61,7 @@ def divergence_rev(
43
61
  can know that by inspecting the `u` argument (PINN object). But if `u` is
44
62
  a function, we must set this attribute.
45
63
  """
46
-
47
- try:
48
- eq_type = u.eq_type
49
- except AttributeError:
50
- pass # use the value passed as argument
51
- if eq_type is None:
52
- raise ValueError("eq_type could not be set!")
64
+ eq_type = _get_eq_type(u, eq_type)
53
65
 
54
66
  def scan_fun(_, i):
55
67
  if eq_type == "nonstatio_PDE":
@@ -72,11 +84,11 @@ def divergence_rev(
72
84
 
73
85
 
74
86
  def divergence_fwd(
75
- inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
76
- u: eqx.Module,
77
- params: Params,
78
- eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
79
- ) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
87
+ inputs: Float[Array, " batch_size dim"] | Float[Array, " batch_size 1+dim"],
88
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
89
+ params: Params[Array],
90
+ eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
91
+ ) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
80
92
  r"""
81
93
  Compute the divergence of a **batched** vector field $\mathbf{u}$, i.e.,
82
94
  $\nabla_\mathbf{x} \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
@@ -105,13 +117,7 @@ def divergence_fwd(
105
117
  can know that by inspecting the `u` argument (PINN object). But if `u` is
106
118
  a function, we must set this attribute.
107
119
  """
108
-
109
- try:
110
- eq_type = u.eq_type
111
- except AttributeError:
112
- pass # use the value passed as argument
113
- if eq_type is None:
114
- raise ValueError("eq_type could not be set!")
120
+ eq_type = _get_eq_type(u, eq_type)
115
121
 
116
122
  def scan_fun(_, i):
117
123
  if eq_type == "nonstatio_PDE":
@@ -144,12 +150,12 @@ def divergence_fwd(
144
150
 
145
151
 
146
152
  def laplacian_rev(
147
- inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
148
- u: eqx.Module,
149
- params: Params,
153
+ inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
154
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
155
+ params: Params[Array],
150
156
  method: Literal["trace_hessian_x", "trace_hessian_t_x", "loop"] = "trace_hessian_x",
151
- eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
152
- ) -> float:
157
+ eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
158
+ ) -> Float[Array, " "]:
153
159
  r"""
154
160
  Compute the Laplacian of a scalar field $u$ from $\mathbb{R}^d$
155
161
  to $\mathbb{R}$ or from $\mathbb{R}^{1+d}$ to $\mathbb{R}$, i.e., this
@@ -182,13 +188,7 @@ def laplacian_rev(
182
188
  can know that by inspecting the `u` argument (PINN object). But if `u` is
183
189
  a function, we must set this attribute.
184
190
  """
185
-
186
- try:
187
- eq_type = u.eq_type
188
- except AttributeError:
189
- pass # use the value passed as argument
190
- if eq_type is None:
191
- raise ValueError("eq_type could not be set!")
191
+ eq_type = _get_eq_type(u, eq_type)
192
192
 
193
193
  if method == "trace_hessian_x":
194
194
  # NOTE we afford a concatenate here to avoid computing Hessian elements for
@@ -228,16 +228,12 @@ def laplacian_rev(
228
228
  if eq_type == "nonstatio_PDE":
229
229
  d2u_dxi2 = grad(
230
230
  lambda inputs: grad(u_)(inputs)[1 + i],
231
- )(
232
- inputs
233
- )[1 + i]
231
+ )(inputs)[1 + i]
234
232
  else:
235
233
  d2u_dxi2 = grad(
236
234
  lambda inputs: grad(u_, 0)(inputs)[i],
237
235
  0,
238
- )(
239
- inputs
240
- )[i]
236
+ )(inputs)[i]
241
237
  return _, d2u_dxi2
242
238
 
243
239
  if eq_type == "nonstatio_PDE":
@@ -253,12 +249,12 @@ def laplacian_rev(
253
249
 
254
250
 
255
251
  def laplacian_fwd(
256
- inputs: Float[Array, "batch_size 1+dim"] | Float[Array, "batch_size dim"],
257
- u: eqx.Module,
258
- params: Params,
252
+ inputs: Float[Array, " batch_size 1+dim"] | Float[Array, " batch_size dim"],
253
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
254
+ params: Params[Array],
259
255
  method: Literal["trace_hessian_t_x", "trace_hessian_x", "loop"] = "loop",
260
- eq_type: Literal["nonstatio_PDE", "statio_PDE"] = None,
261
- ) -> Float[Array, "batch_size * (1+dim) 1"] | Float[Array, "batch_size * (dim) 1"]:
256
+ eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
257
+ ) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
262
258
  r"""
263
259
  Compute the Laplacian of a **batched** scalar field $u$
264
260
  from $\mathbb{R}^{b\times d}$ to $\mathbb{R}^{b\times b}$ or
@@ -301,13 +297,7 @@ def laplacian_fwd(
301
297
  can know that by inspecting the `u` argument (PINN object). But if `u` is
302
298
  a function, we must set this attribute.
303
299
  """
304
-
305
- try:
306
- eq_type = u.eq_type
307
- except AttributeError:
308
- pass # use the value passed as argument
309
- if eq_type is None:
310
- raise ValueError("eq_type could not be set!")
300
+ eq_type = _get_eq_type(u, eq_type)
311
301
 
312
302
  if method == "loop":
313
303
 
@@ -400,11 +390,12 @@ def laplacian_fwd(
400
390
 
401
391
 
402
392
  def vectorial_laplacian_rev(
403
- inputs: Float[Array, "dim"] | Float[Array, "1+dim"],
404
- u: eqx.Module,
405
- params: Params,
406
- dim_out: int = None,
407
- ) -> Float[Array, "dim_out"]:
393
+ inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
394
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
395
+ params: Params[Array],
396
+ dim_out: int | None = None,
397
+ eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
398
+ ) -> Float[Array, " dim_out"]:
408
399
  r"""
409
400
  Compute the vectorial Laplacian of a vector field $\mathbf{u}$ from
410
401
  $\mathbb{R}^d$ to $\mathbb{R}^n$ or from $\mathbb{R}^{1+d}$ to
@@ -428,7 +419,12 @@ def vectorial_laplacian_rev(
428
419
  dim_out
429
420
  Dimension of the vector $\mathbf{u}(\mathrm{inputs})$. This needs to be
430
421
  provided if it is different than that of $\mathrm{inputs}$.
422
+ eq_type
423
+ whether we consider a stationary or non-stationary PINN. Most often we
424
+ can know that by inspecting the `u` argument (PINN object). But if `u` is
425
+ a function, we must set this attribute.
431
426
  """
427
+ eq_type = _get_eq_type(u, eq_type)
432
428
  if dim_out is None:
433
429
  dim_out = inputs.shape[0]
434
430
 
@@ -437,7 +433,9 @@ def vectorial_laplacian_rev(
437
433
  # each of these components
438
434
  # Note the jnp.expand_dims call
439
435
  uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[j], axis=-1)
440
- lap_on_j = laplacian_rev(inputs, uj, params, eq_type=u.eq_type)
436
+ lap_on_j = laplacian_rev(
437
+ inputs, cast(AbstractPINN, uj), params, eq_type=eq_type
438
+ )
441
439
 
442
440
  return _, lap_on_j
443
441
 
@@ -446,11 +444,12 @@ def vectorial_laplacian_rev(
446
444
 
447
445
 
448
446
  def vectorial_laplacian_fwd(
449
- inputs: Float[Array, "batch_size dim"] | Float[Array, "batch_size 1+dim"],
450
- u: eqx.Module,
451
- params: Params,
452
- dim_out: int = None,
453
- ) -> Float[Array, "batch_size * (1+dim) n"] | Float[Array, "batch_size * (dim) n"]:
447
+ inputs: Float[Array, " batch_size dim"] | Float[Array, " batch_size 1+dim"],
448
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
449
+ params: Params[Array],
450
+ dim_out: int | None = None,
451
+ eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
452
+ ) -> Float[Array, " batch_size * (1+dim) n"] | Float[Array, " batch_size * (dim) n"]:
454
453
  r"""
455
454
  Compute the vectorial Laplacian of a vector field $\mathbf{u}$ when
456
455
  `u` is a SPINN, in this case, it corresponds to a vector
@@ -476,7 +475,12 @@ def vectorial_laplacian_fwd(
476
475
  dim_out
477
476
  the value of the output dimension ($n$ in the formula above). Must be
478
477
  set if different from $d$.
478
+ eq_type
479
+ whether we consider a stationary or non-stationary PINN. Most often we
480
+ can know that by inspecting the `u` argument (PINN object). But if `u` is
481
+ a function, we must set this attribute.
479
482
  """
483
+ eq_type = _get_eq_type(u, eq_type)
480
484
  if dim_out is None:
481
485
  dim_out = inputs.shape[0]
482
486
 
@@ -485,7 +489,9 @@ def vectorial_laplacian_fwd(
485
489
  # each of these components
486
490
  # Note the expand_dims
487
491
  uj = lambda inputs, params: jnp.expand_dims(u(inputs, params)[..., j], axis=-1)
488
- lap_on_j = laplacian_fwd(inputs, uj, params, eq_type=u.eq_type)
492
+ lap_on_j = laplacian_fwd(
493
+ inputs, cast(AbstractPINN, uj), params, eq_type=eq_type
494
+ )
489
495
 
490
496
  return _, lap_on_j
491
497
 
@@ -494,8 +500,10 @@ def vectorial_laplacian_fwd(
494
500
 
495
501
 
496
502
  def _u_dot_nabla_times_u_rev(
497
- x: Float[Array, "2"], u: eqx.Module, params: Params
498
- ) -> Float[Array, "2"]:
503
+ x: Float[Array, " 2"],
504
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
505
+ params: Params[Array],
506
+ ) -> Float[Array, " 2"]:
499
507
  r"""
500
508
  Implement $((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})$ for
501
509
  $\mathbf{x}$ of arbitrary
@@ -524,10 +532,10 @@ def _u_dot_nabla_times_u_rev(
524
532
 
525
533
 
526
534
  def _u_dot_nabla_times_u_fwd(
527
- x: Float[Array, "batch_size 2"],
528
- u: eqx.Module,
529
- params: Params,
530
- ) -> Float[Array, "batch_size batch_size 2"]:
535
+ x: Float[Array, " batch_size 2"],
536
+ u: AbstractPINN | Callable[[Array, Params[Array]], Array],
537
+ params: Params[Array],
538
+ ) -> Float[Array, " batch_size batch_size 2"]:
531
539
  r"""
532
540
  Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
533
541
  :math:`\mathbf{x}` of arbitrary dimension **with a batch dimension**.
jinns/nn/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ from ._save_load import save_pinn, load_pinn
2
+ from ._abstract_pinn import AbstractPINN
3
+ from ._pinn import PINN
4
+ from ._spinn import SPINN
5
+ from ._mlp import PINN_MLP, MLP
6
+ from ._spinn_mlp import SPINN_MLP, SMLP
7
+ from ._hyperpinn import HyperPINN
8
+ from ._ppinn import PPINN_MLP
9
+
10
+ __all__ = [
11
+ "save_pinn",
12
+ "load_pinn",
13
+ "AbstractPINN",
14
+ "PINN",
15
+ "SPINN",
16
+ "PINN_MLP",
17
+ "MLP",
18
+ "SPINN_MLP",
19
+ "SMLP",
20
+ "HyperPINN",
21
+ "PPINN_MLP",
22
+ ]
@@ -0,0 +1,22 @@
1
+ import abc
2
+ from typing import Literal, Any
3
+ from jaxtyping import Array
4
+ import equinox as eqx
5
+
6
+ from jinns.nn._utils import _PyTree_to_Params
7
+ from jinns.parameters._params import Params
8
+
9
+
10
+ class AbstractPINN(eqx.Module):
11
+ """
12
+ Basically just a way to add a __call__ to an eqx.Module.
13
+ The way to go for correct type hints apparently
14
+ https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
15
+ """
16
+
17
+ eq_type: eqx.AbstractVar[Literal["ODE", "statio_PDE", "nonstatio_PDE"]]
18
+
19
+ @abc.abstractmethod
20
+ @_PyTree_to_Params
21
+ def __call__(self, inputs: Any, params: Params[Array], *args, **kwargs) -> Any:
22
+ pass