jinns-0.9.0-py3-none-any.whl → jinns-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +322 -167
  9. jinns/loss/_LossODE.py +324 -322
  10. jinns/loss/_LossPDE.py +652 -1027
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +521 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. jinns-1.1.0.dist-info/AUTHORS +2 -0
  33. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
  34. jinns-1.1.0.dist-info/RECORD +39 -0
  35. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
  36. jinns/experimental/_sinuspinn.py +0 -135
  37. jinns/experimental/_spectralpinn.py +0 -87
  38. jinns/solver/_seq2seq.py +0 -157
  39. jinns/utils/_optim.py +0 -147
  40. jinns/utils/_utils_uspinn.py +0 -727
  41. jinns-0.9.0.dist-info/RECORD +0 -36
  42. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
  43. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/loss/_operators.py CHANGED
@@ -5,15 +5,20 @@ Implements diverse operators for dynamic losses
 import jax
 import jax.numpy as jnp
 from jax import grad
+import equinox as eqx
+from jaxtyping import Float, Array
 from jinns.utils._pinn import PINN
 from jinns.utils._spinn import SPINN
+from jinns.parameters._params import Params


-def _div_rev(t, x, u, params):
+def _div_rev(
+    t: Float[Array, "1"], x: Float[Array, "dimension"], u: eqx.Module, params: Params
+) -> float:
     r"""
-    Compute the divergence of a vector field :math:`\mathbf{u}`, i.e.,
-    :math:`\nabla \cdot \mathbf{u}(\mathbf{x})` with :math:`\mathbf{u}` a vector
-    field from :math:`\mathbb{R}^d` to :math:`\mathbb{R}^d`.
+    Compute the divergence of a vector field $\mathbf{u}$, i.e.,
+    $\nabla \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
+    field from $\mathbb{R}^d$ to $\mathbb{R}^d$.
     The computation is done using backward AD
     """

@@ -28,15 +33,21 @@ def _div_rev(t, x, u, params):
     return jnp.sum(accu)


-def _div_fwd(t, x, u, params):
+def _div_fwd(
+    t: Float[Array, "1"], x: Float[Array, "dimension"], u: eqx.Module, params: Params
+) -> float:
     r"""
-    Compute the divergence of a **batched** vector field :math:`\mathbf{u}`, i.e.,
-    :math:`\nabla \cdot \mathbf{u}(\mathbf{x})` with :math:`\mathbf{u}` a vector
-    field from :math:`\mathbb{R}^{b \times d}` to :math:`\mathbb{R}^{b \times b
-    \times d}`. The result is then in :math:`\mathbb{R}^{b\times b}`.
+    Compute the divergence of a **batched** vector field $\mathbf{u}$, i.e.,
+    $\nabla \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
+    field from $\mathbb{R}^{b \times d}$ to $\mathbb{R}^{b \times b
+    \times d}$. The result is then in $\mathbb{R}^{b\times b}$.
     Because of the embedding that happens in SPINNs the
-    computation is most efficient with forward AD. This is the idea behind Separable PINNs.
-    This function is to be used in the context of SPINNs only.
+    computation is most efficient with forward AD. This is the idea behind
+    Separable PINNs.
+
+    !!! warning "Warning"
+
+        This function is to be used in the context of SPINNs only.
     """

     def scan_fun(_, i):
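For orientation, here is a minimal standalone sketch (not jinns code) of the reverse-mode divergence that `_div_rev` implements above, assuming `u(t, x, params)` returns a vector with one entry per component of `x`:

```python
import jax
import jax.numpy as jnp

def div_rev_sketch(t, x, u, params):
    # sum_i d u_i / d x_i : one reverse-mode gradient per output component
    return sum(
        jax.grad(lambda x_: u(t, x_, params)[i])(x)[i] for i in range(x.shape[0])
    )

# e.g. with u = lambda t, x, p: x**2, the divergence at x is 2 * x.sum()
```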
@@ -55,11 +66,13 @@ def _div_fwd(t, x, u, params):
     return jnp.sum(accu, axis=0)


-def _laplacian_rev(t, x, u, params):
+def _laplacian_rev(
+    t: Float[Array, "1"], x: Float[Array, "dimension"], u: eqx.Module, params: Params
+) -> float:
     r"""
-    Compute the Laplacian of a scalar field :math:`u` (from :math:`\mathbb{R}^d`
-    to :math:`\mathbb{R}`) for :math:`\mathbf{x}` of arbitrary dimension, i.e.,
-    :math:`\Delta u(\mathbf{x})=\nabla\cdot\nabla u(\mathbf{x})`.
+    Compute the Laplacian of a scalar field $u$ (from $\mathbb{R}^d$
+    to $\mathbb{R}$) for $\mathbf{x}$ of arbitrary dimension, i.e.,
+    $\Delta u(\mathbf{x})=\nabla\cdot\nabla u(\mathbf{x})$.
     The computation is done using backward AD.
     """

@@ -98,15 +111,24 @@ def _laplacian_rev(t, x, u, params):
     # return jnp.sum(trace_hessian)


-def _laplacian_fwd(t, x, u, params):
+def _laplacian_fwd(
+    t: Float[Array, "batch_size 1"],
+    x: Float[Array, "batch_size dimension"],
+    u: eqx.Module,
+    params: Params,
+) -> Float[Array, "batch_size batch_size"]:
     r"""
-    Compute the Laplacian of a **batched** scalar field :math:`u`
-    (from :math:`\mathbb{R}^{b\times d}` to :math:`\mathbb{R}^{b\times b}`)
-    for :math:`\mathbf{x}` of arbitrary dimension :math:`d` with batch
-    dimension :math:`b`.
+    Compute the Laplacian of a **batched** scalar field $u$
+    (from $\mathbb{R}^{b\times d}$ to $\mathbb{R}^{b\times b}$)
+    for $\mathbf{x}$ of arbitrary dimension $d$ with batch
+    dimension $b$.
     Because of the embedding that happens in SPINNs the
-    computation is most efficient with forward AD. This is the idea behind Separable PINNs.
-    This function is to be used in the context of SPINNs only.
+    computation is most efficient with forward AD. This is the idea behind
+    Separable PINNs.
+
+    !!! warning "Warning"
+
+        This function is to be used in the context of SPINNs only.
     """

     def scan_fun(_, i):
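As a rough illustration of the reverse-mode variant above (illustrative only; the actual jinns implementation differs, as the commented-out `trace_hessian` line suggests): the Laplacian of a scalar field is the trace of its Hessian with respect to `x`, which reverse-mode AD gives directly, assuming `u(t, x, params)` returns a scalar.

```python
import jax
import jax.numpy as jnp

def laplacian_rev_sketch(t, x, u, params):
    # Delta u(x) = trace of the Hessian of u with respect to x
    hess = jax.hessian(lambda x_: u(t, x_, params))(x)
    return jnp.trace(hess)
```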
@@ -134,22 +156,30 @@ def _laplacian_fwd(t, x, u, params):
     return jnp.sum(trace_hessian, axis=0)


-def _vectorial_laplacian(t, x, u, params, u_vec_ndim=None):
+def _vectorial_laplacian(
+    t: Float[Array, "1"] | Float[Array, "batch_size 1"],
+    x: Float[Array, "dimension_in"] | Float[Array, "batch_size dimension"],
+    u: eqx.Module,
+    params: Params,
+    u_vec_ndim: int = None,
+) -> (
+    Float[Array, "dimension_out"] | Float[Array, "batch_size batch_size dimension_out"]
+):
     r"""
-    Compute the vectorial Laplacian of a vector field :math:`\mathbf{u}` (from
-    :math:`\mathbb{R}^d`
-    to :math:`\mathbb{R}^n`) for :math:`\mathbf{x}` of arbitrary dimension, i.e.,
-    :math:`\Delta \mathbf{u}(\mathbf{x})=\nabla\cdot\nabla
-    \mathbf{u}(\mathbf{x})`.
+    Compute the vectorial Laplacian of a vector field $\mathbf{u}$ (from
+    $\mathbb{R}^d$
+    to $\mathbb{R}^n$) for $\mathbf{x}$ of arbitrary dimension, i.e.,
+    $\Delta \mathbf{u}(\mathbf{x})=\nabla\cdot\nabla
+    \mathbf{u}(\mathbf{x})$.

     **Note:** We need to provide `u_vec_ndim` the dimension of the vector
-    :math:`\mathbf{u}(\mathbf{x})` if it is different than that of
-    :math:`\mathbf{x}`.
+    $\mathbf{u}(\mathbf{x})$ if it is different than that of
+    $\mathbf{x}$.

     **Note:** `u` can be a SPINN, in this case, it corresponds to a vector
-    field from (from :math:`\mathbb{R}^{b\times d}` to
-    :math:`\mathbb{R}^{b\times b\times n}`) and forward mode AD is used.
-    Technically, the return is of dimension :math:`n\times b \times b`.
+    field from (from $\mathbb{R}^{b\times d}$ to
+    $\mathbb{R}^{b\times b\times n}$) and forward mode AD is used.
+    Technically, the return is of dimension $n\times b \times b$.
     """
     if u_vec_ndim is None:
         u_vec_ndim = x.shape[0]
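Conceptually, the vectorial Laplacian stacks the scalar Laplacian of each output component of `u`. A naive, hypothetical sketch (a plain Python loop instead of the `scan_fun` used in the real code, and not handling the SPINN case):

```python
import jax
import jax.numpy as jnp

def vectorial_laplacian_sketch(t, x, u, params, u_vec_ndim=None):
    n = x.shape[0] if u_vec_ndim is None else u_vec_ndim
    # scalar Laplacian (trace of the Hessian w.r.t. x) of each output component
    lap_j = lambda j: jnp.trace(jax.hessian(lambda x_: u(t, x_, params)[j])(x))
    return jnp.stack([lap_j(j) for j in range(n)])
```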
@@ -172,6 +202,8 @@ def _vectorial_laplacian(t, x, u, params, u_vec_ndim=None):
                u(t, x, params)[..., j], axis=-1
            )
            lap_on_j = _laplacian_fwd(t, x, uj, params)
+        else:
+            raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")

         return _, lap_on_j

@@ -179,12 +211,14 @@ def _vectorial_laplacian(t, x, u, params, u_vec_ndim=None):
     return vec_lap


-def _u_dot_nabla_times_u_rev(t, x, u, params):
+def _u_dot_nabla_times_u_rev(
+    t: Float[Array, "1"], x: Float[Array, "2"], u: eqx.Module, params: Params
+) -> Float[Array, "2"]:
     r"""
-    Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
-    :math:`\mathbf{x}` of arbitrary
-    dimension. :math:`\mathbf{u}` is a vector field from :math:`\mathbb{R}^n`
-    to :math:`\mathbb{R}^n`. **Currently for** `x.ndim=2` **only**.
+    Implement $((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})$ for
+    $\mathbf{x}$ of arbitrary
+    dimension. $\mathbf{u}$ is a vector field from $\mathbb{R}^n$
+    to $\mathbb{R}^n$. **Currently for** `x.ndim=2` **only**.
     The computation is done using backward AD.
     We do not use loops but code explicitly the expression to avoid
     computing twice some terms
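For intuition, $((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})$ is the Jacobian of $\mathbf{u}$ at $\mathbf{x}$ applied to $\mathbf{u}(\mathbf{x})$. A generic, hypothetical sketch is below; unlike the hand-unrolled 2-D code in jinns, it makes no effort to avoid recomputing shared terms:

```python
import jax

def u_dot_nabla_u_sketch(t, x, u, params):
    # ((u . nabla) u)_i = sum_j u_j * d u_i / d x_j, i.e. the Jacobian of u times u
    jac = jax.jacrev(lambda x_: u(t, x_, params))(x)  # (n, n), entry (i, j) = d u_i / d x_j
    return jac @ u(t, x, params)
```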
@@ -224,7 +258,12 @@ def _u_dot_nabla_times_u_rev(t, x, u, params):
     raise NotImplementedError("x.ndim must be 2")


-def _u_dot_nabla_times_u_fwd(t, x, u, params):
+def _u_dot_nabla_times_u_fwd(
+    t: Float[Array, "batch_size 1"],
+    x: Float[Array, "batch_size 2"],
+    u: eqx.Module,
+    params: Params,
+) -> Float[Array, "batch_size batch_size 2"]:
     r"""
     Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
     :math:`\mathbf{x}` of arbitrary dimension **with a batch dimension**.
@@ -264,36 +303,3 @@ def _u_dot_nabla_times_u_fwd(t, x, u, params):
            axis=-1,
        )
    raise NotImplementedError("x.ndim must be 2")
-
-
-def _sobolev(u, m, statio=True):
-    r"""
-    Compute the Sobolev regularization of order :math:`m`
-    of a scalar field :math:`u` (from :math:`\mathbb{R}^{d}` to :math:`\mathbb{R}`)
-    for :math:`\mathbf{x}` of arbitrary dimension :math:`d`, i.e.,
-    :math:`\frac{1}{n_l}\sum_{l=1}^{n_l}\sum_{|\alpha|=1}^{m+1} ||\partial^{\alpha} u(x_l)||_2^2` where
-    :math:`m\geq\max(d_1 // 2, K)` with :math:`K` the order of the differential
-    operator.
-
-    This regularization is proposed in *Convergence and error analysis of
-    PINNs*, Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
-    """
-
-    def jac_recursive(u, order, start):
-        # Compute the derivative of order `start`
-        if order == 0:
-            return u
-        if start == 0:
-            return jac_recursive(jax.jacrev(u), order - 1, start + 1)
-        return jac_recursive(jax.jacfwd(u), order - 1, start + 1)
-
-    if statio:
-        return lambda x, params: jnp.sum(
-            jac_recursive(lambda x: u(x, params), m + 1, 0)(x) ** 2
-        )
-    return lambda t, x, params: jnp.sum(
-        jac_recursive(lambda tx: u(tx[0:1], tx[1:], params), m + 1, 0)(
-            jnp.concatenate([t, x], axis=0)
-        )
-        ** 2
-    )
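The removed `_sobolev` helper built higher-order derivatives by applying `jax.jacrev` once and `jax.jacfwd` for every subsequent order. A small self-contained illustration of that recursion (hypothetical names, not part of jinns):

```python
import jax
import jax.numpy as jnp

def nth_derivative(f, order):
    # first pass in reverse mode, remaining passes in forward mode,
    # mirroring jac_recursive(u, order, start) above
    for k in range(order):
        f = jax.jacrev(f) if k == 0 else jax.jacfwd(f)
    return f

f = lambda x: jnp.sum(x**3)              # scalar field on R^d
d3 = nth_derivative(f, 3)(jnp.ones(2))   # third-derivative tensor, shape (2, 2, 2)
```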
jinns/parameters/__init__.py ADDED
@@ -0,0 +1,6 @@
+from ._params import Params, ParamsDict
+from ._derivative_keys import (
+    DerivativeKeysODE,
+    DerivativeKeysPDEStatio,
+    DerivativeKeysPDENonStatio,
+)