jinns 1.6.1__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/nn/_spinn.py CHANGED
@@ -21,12 +21,12 @@ class SPINN(AbstractPINN):
  used for non-stationnary equations.
  r : int
  An integer. The dimension of the embedding.
- eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
+ eq_type : Literal["ODE", "PDEStatio", "PDENonStatio"]
  A string with three possibilities.
  "ODE": the PINN is called with one input `t`.
- "statio_PDE": the PINN is called with one input `x`, `x`
+ "PDEStatio": the PINN is called with one input `x`, `x`
  can be high dimensional.
- "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
+ "PDENonStatio": the PINN is called with two inputs `t` and `x`, `x`
  can be high dimensional.
  **Note**: the input dimension as given in eqx_list has to match the sum
  of the dimension of `t` + the dimension of `x`.
@@ -49,7 +49,7 @@ class SPINN(AbstractPINN):

  """

- eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
+ eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"] = eqx.field(
  static=True, kw_only=True
  )
  d: int = eqx.field(static=True, kw_only=True)
jinns/nn/_spinn_mlp.py CHANGED
@@ -78,7 +78,7 @@ class SPINN_MLP(SPINN):
  d: int,
  r: int,
  eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...],
- eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+ eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"],
  m: int = 1,
  filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
  ) -> tuple[Self, SPINN]:
@@ -114,12 +114,12 @@ class SPINN_MLP(SPINN):
  (jax.nn.tanh,),
  (eqx.nn.Linear, 20, r * m)
  )`.
- eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
+ eq_type : Literal["ODE", "PDEStatio", "PDENonStatio"]
  A string with three possibilities.
  "ODE": the PINN is called with one input `t`.
- "statio_PDE": the PINN is called with one input `x`, `x`
+ "PDEStatio": the PINN is called with one input `x`, `x`
  can be high dimensional.
- "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
+ "PDENonStatio": the PINN is called with two inputs `t` and `x`, `x`
  can be high dimensional.
  **Note**: the input dimension as given in eqx_list has to match the sum
  of the dimension of `t` + the dimension of `x`.
@@ -150,11 +150,11 @@ class SPINN_MLP(SPINN):
  Raises
  ------
  RuntimeError
- If the parameter value for eq_type is not in `["ODE", "statio_PDE",
- "nonstatio_PDE"]` and for various failing checks
+ If the parameter value for eq_type is not in `["ODE", "PDEStatio",
+ "PDENonStatio"]` and for various failing checks
  """

- if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
+ if eq_type not in ["ODE", "PDEStatio", "PDENonStatio"]:
  raise RuntimeError("Wrong parameter value for eq_type")

  def element_is_layer(element: tuple) -> TypeGuard[tuple[Callable, int, int]]:
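The user-facing change in these two files is the renaming of the `eq_type` literals: `"statio_PDE"` becomes `"PDEStatio"` and `"nonstatio_PDE"` becomes `"PDENonStatio"`. Any old spelling now fails the `eq_type` check shown above. A minimal standalone sketch of the renamed literals and that validation (illustration only, not the library's own code):

```python
from typing import Literal

# Accepted values in 1.7.1 (1.6.1 used "ODE", "statio_PDE", "nonstatio_PDE")
EqType = Literal["ODE", "PDEStatio", "PDENonStatio"]

def validate_eq_type(eq_type: str) -> None:
    # Mirrors the check in SPINN_MLP: old spellings are rejected
    if eq_type not in ["ODE", "PDEStatio", "PDENonStatio"]:
        raise RuntimeError("Wrong parameter value for eq_type")

validate_eq_type("PDENonStatio")    # passes
# validate_eq_type("nonstatio_PDE") # would raise RuntimeError in 1.7.1
```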
@@ -47,9 +47,9 @@ class DerivativeKeysODE(eqx.Module):
  [`DynamicLoss`][jinns.loss.DynamicLoss] should be differentiated both with
  respect to the neural network parameters *and* the equation parameters, or only some of them.

- To do so, user can either use strings or a `Params` object
- with PyTree structure matching the parameters of the problem at
- hand, and booleans indicating if gradient is to be taken or not. Internally,
+ To do so, user can either use strings or a `Params[bool]` object
+ with PyTree structure matching the parameters of the problem (`Params[Array]`) at
+ hand, and leaves being booleans indicating if gradient is to be taken or not. Internally,
  a `jax.lax.stop_gradient()` is appropriately set to each `True` node when
  computing each loss term.

@@ -156,12 +156,12 @@ class DerivativeKeysODE(eqx.Module):
  """
  Construct the DerivativeKeysODE from strings. For each term of the
  loss, specify whether to differentiate wrt the neural network
- parameters, the equation parameters or both. The `Params` object, which
+ parameters, the equation parameters or both. The `Params[Array]` object, which
  contains the actual array of parameters must be passed to
  construct the fields with the appropriate PyTree structure.

  !!! note
- You can mix strings and `Params` if you need granularity.
+ You can mix strings and `Params[bool]` if you need granularity.

  Parameters
  ----------
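The updated docstrings spell out the mask convention: a `Params[bool]` mirrors the PyTree structure of the actual `Params[Array]`, with boolean leaves marking which parameters are differentiated. A small illustrative sketch, assuming `Params` is importable from `jinns.parameters` and using made-up parameter names:

```python
import jax.numpy as jnp
from jinns.parameters import Params  # import path assumed

# Actual problem parameters: a Params[Array]
params = Params(
    nn_params={"w": jnp.ones((3, 3))},  # stand-in for neural-network parameters
    eq_params={"alpha": jnp.array(1.0), "beta": jnp.array(0.5)},
)

# Matching Params[bool]: same tree structure, boolean leaves
derivative_mask = Params(
    nn_params=True,                            # differentiate wrt NN parameters
    eq_params={"alpha": True, "beta": False},  # only wrt "alpha"
)
```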
@@ -498,7 +498,14 @@ def _set_derivatives(
  `Params(nn_params=True | False, eq_params={"alpha":True | False,
  "beta":True | False})`.
  """
-
+ assert jax.tree.structure(params_.eq_params) == jax.tree.structure(
+ derivative_mask.eq_params
+ ), (
+ "The derivative "
+ "mask for eq_params does not have the same tree structure as "
+ "Params.eq_params. This is often due to a wrong Params[bool] "
+ "passed when initializing the derivative key object."
+ )
  return Params(
  nn_params=jax.lax.cond(
  derivative_mask.nn_params,
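The new assertion catches a `Params[bool]` whose `eq_params` subtree does not match the structure of the actual parameters. A standalone illustration of the structural comparison it performs (toy dictionaries, not jinns objects):

```python
import jax

eq_params = {"alpha": 1.0, "beta": 0.5}

mask_ok = {"alpha": True, "beta": False}  # same keys -> same tree structure
mask_bad = {"alpha": True}                # missing "beta" -> different structure

assert jax.tree.structure(eq_params) == jax.tree.structure(mask_ok)
assert jax.tree.structure(eq_params) != jax.tree.structure(mask_bad)
```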
@@ -2,6 +2,7 @@
  Formalize the data structure for the parameters
  """

+ from __future__ import annotations
  from dataclasses import fields
  from typing import Generic, TypeVar
  import equinox as eqx
@@ -60,6 +61,15 @@ class Params(eqx.Module, Generic[T]):
  else:
  self.eq_params = eq_params

+ def partition(self, mask: Params[bool] | None):
+ """
+ following the boolean mask, partition into two Params
+ """
+ if mask is not None:
+ return eqx.partition(self, mask)
+ else:
+ return self, None
+

  def update_eq_params(
  params: Params[Array],
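The new `Params.partition` helper is a thin wrapper around `eqx.partition`: given a `Params[bool]` mask it splits the parameters into the selected leaves and their complement, and with `None` it returns the object unchanged alongside `None`. A short usage sketch under the same assumptions as the previous example (import path and parameter names are illustrative):

```python
import jax.numpy as jnp
from jinns.parameters import Params  # import path assumed

params = Params(
    nn_params={"w": jnp.ones((2, 2))},
    eq_params={"alpha": jnp.array(1.0), "beta": jnp.array(0.5)},
)
mask = Params(nn_params=True, eq_params={"alpha": True, "beta": False})

# With a mask: behaves like eqx.partition -> (leaves where mask is True, the rest)
selected, remainder = params.partition(mask)

# Without a mask: the Params object comes back untouched, paired with None
unchanged, nothing = params.partition(None)
```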
jinns/solver/_rar.py CHANGED
@@ -10,6 +10,7 @@ from jax import vmap
  import jax.numpy as jnp
  import equinox as eqx

+ from jinns.loss._DynamicLossAbstract import ODE, PDEStatio, PDENonStatio
  from jinns.data._DataGeneratorODE import DataGeneratorODE
  from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
  from jinns.data._CubicMeshPDENonStatio import CubicMeshPDENonStatio
@@ -176,16 +177,25 @@ def _rar_step_init(
  )

  data = eqx.tree_at(lambda m: m.key, data, new_key)
-
- v_dyn_loss = vmap(
- lambda inputs: loss.dynamic_loss.evaluate(inputs, loss.u, params),
- )
- dyn_on_s = v_dyn_loss(new_samples)
-
- if dyn_on_s.ndim > 1:
- mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
  else:
- mse_on_s = dyn_on_s**2
+ raise ValueError("Wrong DataGenerator type")
+
+ v_dyn_loss = jax.tree.map(
+ lambda d: vmap(
+ lambda inputs: d.evaluate(inputs, loss.u, params),
+ ),
+ loss.dynamic_loss,
+ is_leaf=lambda x: isinstance(x, (ODE, PDEStatio, PDENonStatio)),
+ )
+ dyn_on_s = jax.tree.map(lambda d: d(new_samples), v_dyn_loss)
+
+ mse_on_s = jax.tree.reduce(
+ jnp.add,
+ jax.tree.map(
+ lambda v: (jnp.linalg.norm(v, axis=-1) ** 2).flatten(), dyn_on_s
+ ),
+ 0,
+ )

  ## Select the m points with higher dynamic loss
  higher_residual_idx = jax.lax.dynamic_slice(
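The reworked residual computation treats `loss.dynamic_loss` as a PyTree of dynamic-loss objects rather than a single one: each leaf is vmapped over the candidate collocation points, and the squared residual norms are summed across leaves with `jax.tree.reduce`. A standalone sketch of that reduction pattern, with toy residual functions standing in for jinns dynamic losses:

```python
import jax
import jax.numpy as jnp

# Toy residual functions standing in for dynamic-loss `evaluate` calls
dyn_losses = {
    "heat": lambda x: x ** 2 - 1.0,
    "mass": lambda x: jnp.sin(x),
}

new_samples = jnp.linspace(0.0, 1.0, 8).reshape(-1, 1)  # 8 candidate points in 1D

# vmap each residual over the candidate points
dyn_on_s = jax.tree.map(lambda f: jax.vmap(f)(new_samples), dyn_losses)

# Sum the squared norms of every residual term, as in the new _rar_step code
mse_on_s = jax.tree.reduce(
    jnp.add,
    jax.tree.map(lambda v: (jnp.linalg.norm(v, axis=-1) ** 2).flatten(), dyn_on_s),
    0,
)
assert mse_on_s.shape == (8,)  # one aggregated residual value per candidate point
```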