jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +953 -1182
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +321 -168
  9. jinns/loss/_LossODE.py +290 -307
  10. jinns/loss/_LossPDE.py +628 -1040
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +95 -96
  13. jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +94 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +193 -45
  22. jinns/solver/_solve.py +199 -144
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -43
  25. jinns/utils/_hyperpinn.py +226 -127
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +117 -84
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +52 -144
  32. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
  33. jinns-1.0.0.dist-info/RECORD +38 -0
  34. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
  35. jinns/experimental/_sinuspinn.py +0 -135
  36. jinns/experimental/_spectralpinn.py +0 -87
  37. jinns/solver/_seq2seq.py +0 -157
  38. jinns/utils/_optim.py +0 -147
  39. jinns/utils/_utils_uspinn.py +0 -727
  40. jinns-0.8.10.dist-info/RECORD +0 -36
  41. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
  42. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
jinns/loss/_loss_weights.py ADDED
@@ -0,0 +1,59 @@
+ """
+ Formalize the loss weights data structure
+ """
+
+ from typing import Dict
+ from jaxtyping import Array, Float
+ import equinox as eqx
+
+
+ class LossWeightsODE(eqx.Module):
+
+     dyn_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+     initial_condition: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+     observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+
+
+ class LossWeightsODEDict(eqx.Module):
+
+     dyn_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=None)
+     initial_condition: Dict[str, Array | Float | None] = eqx.field(
+         kw_only=True, default=None
+     )
+     observations: Dict[str, Array | Float | None] = eqx.field(
+         kw_only=True, default=None
+     )
+
+
+ class LossWeightsPDEStatio(eqx.Module):
+
+     dyn_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+     norm_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+     boundary_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+     observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+
+
+ class LossWeightsPDENonStatio(eqx.Module):
+
+     dyn_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+     norm_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+     boundary_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+     observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+     initial_condition: Array | Float | None = eqx.field(kw_only=True, default=1.0)
+
+
+ class LossWeightsPDEDict(eqx.Module):
+     """
+     Only one type of LossWeights data structure for the SystemLossPDE:
+     Include the initial condition always for the code to be more generic
+     """
+
+     dyn_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
+     norm_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
+     boundary_loss: Dict[str, Array | Float | None] = eqx.field(
+         kw_only=True, default=1.0
+     )
+     observations: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
+     initial_condition: Dict[str, Array | Float | None] = eqx.field(
+         kw_only=True, default=1.0
+     )
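Usage sketch (not part of the diff): the new loss-weight modules are keyword-only Equinox modules whose fields default to 1.0 and may hold scalars or arrays. The direct import from the private module added above is an assumption for illustration only.

    # Illustrative sketch only, not jinns code from this diff
    import jax.numpy as jnp
    from jinns.loss._loss_weights import LossWeightsODE, LossWeightsPDENonStatio

    # Keyword-only fields; anything left unspecified keeps its default of 1.0
    ode_weights = LossWeightsODE(dyn_loss=1.0, initial_condition=10.0)
    pde_weights = LossWeightsPDENonStatio(
        dyn_loss=jnp.array(0.5),  # weights may be scalars or JAX arrays
        boundary_loss=100.0,
    )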
jinns/loss/_operators.py CHANGED
@@ -5,15 +5,20 @@ Implements diverse operators for dynamic losses
  import jax
  import jax.numpy as jnp
  from jax import grad
+ import equinox as eqx
+ from jaxtyping import Float, Array
  from jinns.utils._pinn import PINN
  from jinns.utils._spinn import SPINN
+ from jinns.parameters._params import Params
 
 
- def _div_rev(t, x, u, params):
+ def _div_rev(
+     t: Float[Array, "1"], x: Float[Array, "dimension"], u: eqx.Module, params: Params
+ ) -> float:
      r"""
-     Compute the divergence of a vector field :math:`\mathbf{u}`, i.e.,
-     :math:`\nabla \cdot \mathbf{u}(\mathbf{x})` with :math:`\mathbf{u}` a vector
-     field from :math:`\mathbb{R}^d` to :math:`\mathbb{R}^d`.
+     Compute the divergence of a vector field $\mathbf{u}$, i.e.,
+     $\nabla \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
+     field from $\mathbb{R}^d$ to $\mathbb{R}^d$.
      The computation is done using backward AD
      """
 
@@ -28,15 +33,21 @@ def _div_rev(t, x, u, params):
      return jnp.sum(accu)
 
 
- def _div_fwd(t, x, u, params):
+ def _div_fwd(
+     t: Float[Array, "1"], x: Float[Array, "dimension"], u: eqx.Module, params: Params
+ ) -> float:
      r"""
-     Compute the divergence of a **batched** vector field :math:`\mathbf{u}`, i.e.,
-     :math:`\nabla \cdot \mathbf{u}(\mathbf{x})` with :math:`\mathbf{u}` a vector
-     field from :math:`\mathbb{R}^{b \times d}` to :math:`\mathbb{R}^{b \times b
-     \times d}`. The result is then in :math:`\mathbb{R}^{b\times b}`.
+     Compute the divergence of a **batched** vector field $\mathbf{u}$, i.e.,
+     $\nabla \cdot \mathbf{u}(\mathbf{x})$ with $\mathbf{u}$ a vector
+     field from $\mathbb{R}^{b \times d}$ to $\mathbb{R}^{b \times b
+     \times d}$. The result is then in $\mathbb{R}^{b\times b}$.
      Because of the embedding that happens in SPINNs the
-     computation is most efficient with forward AD. This is the idea behind Separable PINNs.
-     This function is to be used in the context of SPINNs only.
+     computation is most efficient with forward AD. This is the idea behind
+     Separable PINNs.
+
+     !!! warning "Warning"
+
+         This function is to be used in the context of SPINNs only.
      """
 
      def scan_fun(_, i):
@@ -55,11 +66,13 @@ def _div_fwd(t, x, u, params):
      return jnp.sum(accu, axis=0)
 
 
- def _laplacian_rev(t, x, u, params):
+ def _laplacian_rev(
+     t: Float[Array, "1"], x: Float[Array, "dimension"], u: eqx.Module, params: Params
+ ) -> float:
      r"""
-     Compute the Laplacian of a scalar field :math:`u` (from :math:`\mathbb{R}^d`
-     to :math:`\mathbb{R}`) for :math:`\mathbf{x}` of arbitrary dimension, i.e.,
-     :math:`\Delta u(\mathbf{x})=\nabla\cdot\nabla u(\mathbf{x})`.
+     Compute the Laplacian of a scalar field $u$ (from $\mathbb{R}^d$
+     to $\mathbb{R}$) for $\mathbf{x}$ of arbitrary dimension, i.e.,
+     $\Delta u(\mathbf{x})=\nabla\cdot\nabla u(\mathbf{x})$.
      The computation is done using backward AD.
      """
 
@@ -98,15 +111,24 @@ def _laplacian_rev(t, x, u, params):
      # return jnp.sum(trace_hessian)
 
 
- def _laplacian_fwd(t, x, u, params):
+ def _laplacian_fwd(
+     t: Float[Array, "batch_size 1"],
+     x: Float[Array, "batch_size dimension"],
+     u: eqx.Module,
+     params: Params,
+ ) -> Float[Array, "batch_size batch_size"]:
      r"""
-     Compute the Laplacian of a **batched** scalar field :math:`u`
-     (from :math:`\mathbb{R}^{b\times d}` to :math:`\mathbb{R}^{b\times b}`)
-     for :math:`\mathbf{x}` of arbitrary dimension :math:`d` with batch
-     dimension :math:`b`.
+     Compute the Laplacian of a **batched** scalar field $u$
+     (from $\mathbb{R}^{b\times d}$ to $\mathbb{R}^{b\times b}$)
+     for $\mathbf{x}$ of arbitrary dimension $d$ with batch
+     dimension $b$.
      Because of the embedding that happens in SPINNs the
-     computation is most efficient with forward AD. This is the idea behind Separable PINNs.
-     This function is to be used in the context of SPINNs only.
+     computation is most efficient with forward AD. This is the idea behind
+     Separable PINNs.
+
+     !!! warning "Warning"
+
+         This function is to be used in the context of SPINNs only.
      """
 
      def scan_fun(_, i):
@@ -134,22 +156,30 @@ def _laplacian_fwd(t, x, u, params):
      return jnp.sum(trace_hessian, axis=0)
 
 
- def _vectorial_laplacian(t, x, u, params, u_vec_ndim=None):
+ def _vectorial_laplacian(
+     t: Float[Array, "1"] | Float[Array, "batch_size 1"],
+     x: Float[Array, "dimension_in"] | Float[Array, "batch_size dimension"],
+     u: eqx.Module,
+     params: Params,
+     u_vec_ndim: int = None,
+ ) -> (
+     Float[Array, "dimension_out"] | Float[Array, "batch_size batch_size dimension_out"]
+ ):
      r"""
-     Compute the vectorial Laplacian of a vector field :math:`\mathbf{u}` (from
-     :math:`\mathbb{R}^d`
-     to :math:`\mathbb{R}^n`) for :math:`\mathbf{x}` of arbitrary dimension, i.e.,
-     :math:`\Delta \mathbf{u}(\mathbf{x})=\nabla\cdot\nabla
-     \mathbf{u}(\mathbf{x})`.
+     Compute the vectorial Laplacian of a vector field $\mathbf{u}$ (from
+     $\mathbb{R}^d$
+     to $\mathbb{R}^n$) for $\mathbf{x}$ of arbitrary dimension, i.e.,
+     $\Delta \mathbf{u}(\mathbf{x})=\nabla\cdot\nabla
+     \mathbf{u}(\mathbf{x})$.
 
      **Note:** We need to provide `u_vec_ndim` the dimension of the vector
-     :math:`\mathbf{u}(\mathbf{x})` if it is different than that of
-     :math:`\mathbf{x}`.
+     $\mathbf{u}(\mathbf{x})$ if it is different than that of
+     $\mathbf{x}$.
 
      **Note:** `u` can be a SPINN, in this case, it corresponds to a vector
-     field from (from :math:`\mathbb{R}^{b\times d}` to
-     :math:`\mathbb{R}^{b\times b\times n}`) and forward mode AD is used.
-     Technically, the return is of dimension :math:`n\times b \times b`.
+     field from (from $\mathbb{R}^{b\times d}$ to
+     $\mathbb{R}^{b\times b\times n}$) and forward mode AD is used.
+     Technically, the return is of dimension $n\times b \times b$.
      """
      if u_vec_ndim is None:
          u_vec_ndim = x.shape[0]
@@ -172,6 +202,8 @@ def _vectorial_laplacian(t, x, u, params, u_vec_ndim=None):
                  u(t, x, params)[..., j], axis=-1
              )
              lap_on_j = _laplacian_fwd(t, x, uj, params)
+         else:
+             raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
 
          return _, lap_on_j
 
@@ -179,12 +211,14 @@ def _vectorial_laplacian(t, x, u, params, u_vec_ndim=None):
      return vec_lap
 
 
- def _u_dot_nabla_times_u_rev(t, x, u, params):
+ def _u_dot_nabla_times_u_rev(
+     t: Float[Array, "1"], x: Float[Array, "2"], u: eqx.Module, params: Params
+ ) -> Float[Array, "2"]:
      r"""
-     Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
-     :math:`\mathbf{x}` of arbitrary
-     dimension. :math:`\mathbf{u}` is a vector field from :math:`\mathbb{R}^n`
-     to :math:`\mathbb{R}^n`. **Currently for** `x.ndim=2` **only**.
+     Implement $((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})$ for
+     $\mathbf{x}$ of arbitrary
+     dimension. $\mathbf{u}$ is a vector field from $\mathbb{R}^n$
+     to $\mathbb{R}^n$. **Currently for** `x.ndim=2` **only**.
      The computation is done using backward AD.
      We do not use loops but code explicitly the expression to avoid
      computing twice some terms
@@ -224,7 +258,12 @@ def _u_dot_nabla_times_u_rev(t, x, u, params):
      raise NotImplementedError("x.ndim must be 2")
 
 
- def _u_dot_nabla_times_u_fwd(t, x, u, params):
+ def _u_dot_nabla_times_u_fwd(
+     t: Float[Array, "batch_size 1"],
+     x: Float[Array, "batch_size 2"],
+     u: eqx.Module,
+     params: Params,
+ ) -> Float[Array, "batch_size batch_size 2"]:
      r"""
      Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
      :math:`\mathbf{x}` of arbitrary dimension **with a batch dimension**.
@@ -264,36 +303,3 @@ def _u_dot_nabla_times_u_fwd(t, x, u, params):
              axis=-1,
          )
      raise NotImplementedError("x.ndim must be 2")
-
-
- def _sobolev(u, m, statio=True):
-     r"""
-     Compute the Sobolev regularization of order :math:`m`
-     of a scalar field :math:`u` (from :math:`\mathbb{R}^{d}` to :math:`\mathbb{R}`)
-     for :math:`\mathbf{x}` of arbitrary dimension :math:`d`, i.e.,
-     :math:`\frac{1}{n_l}\sum_{l=1}^{n_l}\sum_{|\alpha|=1}^{m+1} ||\partial^{\alpha} u(x_l)||_2^2` where
-     :math:`m\geq\max(d_1 // 2, K)` with :math:`K` the order of the differential
-     operator.
-
-     This regularization is proposed in *Convergence and error analysis of
-     PINNs*, Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
-     """
-
-     def jac_recursive(u, order, start):
-         # Compute the derivative of order `start`
-         if order == 0:
-             return u
-         if start == 0:
-             return jac_recursive(jax.jacrev(u), order - 1, start + 1)
-         return jac_recursive(jax.jacfwd(u), order - 1, start + 1)
-
-     if statio:
-         return lambda x, params: jnp.sum(
-             jac_recursive(lambda x: u(x, params), m + 1, 0)(x) ** 2
-         )
-     return lambda t, x, params: jnp.sum(
-         jac_recursive(lambda tx: u(tx[0:1], tx[1:], params), m + 1, 0)(
-             jnp.concatenate([t, x], axis=0)
-         )
-         ** 2
-     )
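Side note (not part of the diff): the quantities these operator helpers compute can be written compactly in plain JAX. The toy sketch below illustrates the trace-of-Hessian Laplacian and the convective term $(\mathbf{u}\cdot\nabla)\mathbf{u}$, ignoring the `(t, x, u, params)` calling convention and the SPINN forward-mode batching used by jinns.

    # Illustrative sketch only, not the jinns implementation
    import jax
    import jax.numpy as jnp

    def laplacian(u, x):
        # trace of the Hessian: Delta u(x) = sum_i d^2 u / dx_i^2
        return jnp.trace(jax.jacfwd(jax.grad(u))(x))

    def u_dot_nabla_u(u, x):
        # ((u . nabla) u)(x) = J_u(x) @ u(x), with J_u the Jacobian of u
        return jax.jacrev(u)(x) @ u(x)

    u_scalar = lambda x: jnp.sum(x**2)                       # Laplacian is 2 * d
    u_vec = lambda x: jnp.array([x[0] * x[1], x[0] - x[1]])  # toy 2D vector field

    print(laplacian(u_scalar, jnp.ones(3)))   # 6.0
    print(u_dot_nabla_u(u_vec, jnp.ones(2)))  # [1. 1.]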
jinns/parameters/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from ._params import Params, ParamsDict
+ from ._derivative_keys import (
+     DerivativeKeysODE,
+     DerivativeKeysPDEStatio,
+     DerivativeKeysPDENonStatio,
+ )
jinns/parameters/_derivative_keys.py ADDED
@@ -0,0 +1,94 @@
+ """
+ Formalize the data structure for the derivative keys
+ """
+
+ from dataclasses import fields
+ from typing import Literal
+ import jax
+ import equinox as eqx
+
+ from jinns.parameters._params import Params
+
+
+ class DerivativeKeysODE(eqx.Module):
+     # we use static = True because all fields are string, hence should be
+     # invisible by JAX transforms (JIT, etc.)
+     dyn_loss: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
+         kw_only=True, default="nn_params", static=True
+     )
+     observations: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
+         kw_only=True, default="nn_params", static=True
+     )
+     initial_condition: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
+         kw_only=True, default="nn_params", static=True
+     )
+
+
+ class DerivativeKeysPDEStatio(eqx.Module):
+
+     dyn_loss: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
+         kw_only=True, default="nn_params", static=True
+     )
+     observations: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
+         kw_only=True, default="nn_params", static=True
+     )
+     boundary_loss: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
+         kw_only=True, default="nn_params", static=True
+     )
+     norm_loss: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
+         kw_only=True, default="nn_params", static=True
+     )
+
+
+ class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
+
+     initial_condition: Literal["nn_params", "eq_params", "both"] = eqx.field(
+         kw_only=True, default="nn_params", static=True
+     )
+
+
+ def _set_derivatives(params, derivative_keys):
+     """
+     We construct an eqx.Module with the fields of derivative_keys, each field
+     has a copy of the params with appropriate derivatives set
+     """
+
+     def _set_derivatives_(loss_term_derivative):
+         if loss_term_derivative == "both":
+             return params
+         # the next line put a stop_gradient around the fields that do not
+         # appear in loss_term_derivative. Currently there are only two possible
+         # values nn_params and eq_params but there might be more in the future
+         return eqx.tree_at(
+             lambda p: tuple(
+                 getattr(p, f.name)
+                 for f in fields(Params)
+                 if f.name != loss_term_derivative
+             ),
+             params,
+             replace_fn=jax.lax.stop_gradient,
+         )
+
+     def _set_derivatives_dict(loss_term_derivative):
+         if loss_term_derivative == "both":
+             return params
+         # the next line put a stop_gradient around the fields that do not
+         # appear in loss_term_derivative. Currently there are only two possible
+         # values nn_params and eq_params but there might be more in the future
+         return {
+             k: eqx.tree_at(
+                 lambda p: tuple(
+                     getattr(p, f.name)
+                     for f in fields(Params)
+                     if f.name != loss_term_derivative
+                 ),
+                 params_,
+                 replace_fn=jax.lax.stop_gradient,
+             )
+             for k, params_ in params
+         }
+
+     if not isinstance(params, dict):
+         return _set_derivatives_(derivative_keys)
+     else:
+         return _set_derivatives_dict(derivative_keys)
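Not part of the diff, but a hypothetical sketch of the mechanism `_set_derivatives` relies on: wrapping `jax.lax.stop_gradient` around the `Params` fields that are not named by the derivative key, so a key of `"nn_params"` freezes `eq_params` in the gradient. The parameter names and values below are made up for illustration.

    # Hedged sketch, not jinns code from this diff
    import jax
    import jax.numpy as jnp
    import equinox as eqx
    from jinns.parameters import Params

    # toy parameters: "nu" is an assumed equation-parameter name
    params = Params(nn_params={"w": jnp.ones(2)}, eq_params={"nu": jnp.array(0.3)})

    def loss(p):
        return p.eq_params["nu"] * jnp.sum(p.nn_params["w"] ** 2)

    def loss_nn_only(p):
        # freeze eq_params, as _set_derivatives_ does with eqx.tree_at
        p = eqx.tree_at(lambda q: q.eq_params, p, replace_fn=jax.lax.stop_gradient)
        return loss(p)

    grads = jax.grad(loss_nn_only)(params)
    # grads.nn_params["w"] is nonzero, grads.eq_params["nu"] is 0.0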
jinns/parameters/_params.py ADDED
@@ -0,0 +1,115 @@
+ """
+ Formalize the data structure for the parameters
+ """
+
+ import jax
+ import equinox as eqx
+ from typing import Dict
+ from jaxtyping import Array, PyTree
+
+
+ class Params(eqx.Module):
+     """
+     The equinox module for the parameters
+
+     Parameters
+     ----------
+     nn_params : Pytree
+         A PyTree of the non-static part of the PINN eqx.Module, i.e., the
+         parameters of the PINN
+     eq_params : Dict[str, Array]
+         A dictionary of the equation parameters. Keys are the parameter name,
+         values are their corresponding value
+     """
+
+     nn_params: PyTree = eqx.field(kw_only=True, default=None)
+     eq_params: Dict[str, Array] = eqx.field(kw_only=True, default=None)
+
+
+ class ParamsDict(eqx.Module):
+     """
+     The equinox module for a dictionnary of parameters with different keys
+     corresponding to different equations.
+
+     Parameters
+     ----------
+     nn_params : Dict[str, PyTree]
+         The neural network's parameters. Most of the time, it will be the
+         Array part of an `eqx.Module` obtained by
+         `eqx.partition(module, eqx.is_inexact_array)`.
+     eq_params : Dict[str, Array]
+         A dictionary of the equation parameters. Dict keys are the parameter name as defined your custom loss.
+     """
+
+     nn_params: Dict[str, PyTree] = eqx.field(kw_only=True, default=None)
+     eq_params: Dict[str, Array] = eqx.field(kw_only=True, default=None)
+
+     def extract_params(self, nn_key: str) -> Params:
+         """
+         Extract the corresponding `nn_params` and `eq_params` for `nn_key` and
+         return them in the form of a `Params` object.
+         """
+         try:
+             return Params(
+                 nn_params=self.nn_params[nn_key],
+                 eq_params=self.eq_params[nn_key],
+             )
+         except (KeyError, IndexError) as e:
+             return Params(
+                 nn_params=self.nn_params[nn_key],
+                 eq_params=self.eq_params,
+             )
+
+
+ def _update_eq_params_dict(
+     params: Params, param_batch_dict: Dict[str, Array]
+ ) -> Params:
+     """
+     Update params.eq_params with a batch of eq_params for given key(s)
+     """
+
+     # artificially "complete" `param_batch_dict` with None to match `params`
+     # PyTree structure
+     param_batch_dict_ = param_batch_dict | {
+         k: None for k in set(params.eq_params.keys()) - set(param_batch_dict.keys())
+     }
+
+     # Replace at non None leafs
+     params = eqx.tree_at(
+         lambda p: p.eq_params,
+         params,
+         jax.tree_util.tree_map(
+             lambda p, q: q if q is not None else p,
+             params.eq_params,
+             param_batch_dict_,
+         ),
+     )
+
+     return params
+
+
+ def _get_vmap_in_axes_params(
+     eq_params_batch_dict: Dict[str, Array], params: Params | ParamsDict
+ ) -> tuple[Params]:
+     """
+     Return the input vmap axes when there is batch(es) of parameters to vmap
+     over. The latter are designated by keys in eq_params_batch_dict.
+     If eq_params_batch_dict is None (i.e. no additional parameter batch), we
+     return (None,).
+     """
+     if eq_params_batch_dict is None:
+         return (None,)
+     # We use pytree indexing of vmapped axes and vmap on axis
+     # 0 of the eq_parameters for which we have a batch
+     # this is for a fine-grained vmaping
+     # scheme over the params
+     vmap_in_axes_params = (
+         type(params)(
+             nn_params=None,
+             eq_params={
+                 k: (0 if k in eq_params_batch_dict.keys() else None)
+                 for k in params.eq_params.keys()
+             },
+         ),
+     )
+     return vmap_in_axes_params
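Again not part of the diff: a hypothetical sketch of the fine-grained vmap scheme that `_get_vmap_in_axes_params` builds. The `Params`-shaped `in_axes` maps axis 0 of the equation parameters that come in a batch (here an assumed parameter `"nu"`) while broadcasting `nn_params`.

    # Hedged sketch, not jinns code from this diff
    import jax
    import jax.numpy as jnp
    from jinns.parameters import Params

    params = Params(
        nn_params={"w": jnp.ones(2)},                 # broadcast, no batch axis
        eq_params={"nu": jnp.linspace(0.1, 1.0, 8)},  # a batch of 8 values
    )
    # same structure as what _get_vmap_in_axes_params returns
    in_axes = (Params(nn_params=None, eq_params={"nu": 0}),)

    def eval_one(p):
        return p.eq_params["nu"] * jnp.sum(p.nn_params["w"])

    out = jax.vmap(eval_one, in_axes=in_axes)(params)  # shape (8,)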
jinns/plot/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from jinns.plot._plot import (
+     plot2d,
+     plot1d_slice,
+     plot1d_image,
+ )