jinns 1.7.0__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/loss/_operators.py CHANGED
@@ -18,8 +18,8 @@ from jinns.nn._abstract_pinn import AbstractPINN
18
18
 
19
19
  def _get_eq_type(
20
20
  u: AbstractPINN | Callable[[Array, Params[Array]], Array],
21
- eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None,
22
- ) -> Literal["nonstatio_PDE", "statio_PDE"]:
21
+ eq_type: Literal["PDENonStatio", "PDEStatio"] | None,
22
+ ) -> Literal["PDENonStatio", "PDEStatio"]:
23
23
  """
24
24
  But we filter out ODE from eq_type because we only have operators that does
25
25
  not work with ODEs so far
@@ -36,7 +36,7 @@ def divergence_rev(
36
36
  inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
37
37
  u: AbstractPINN | Callable[[Array, Params[Array]], Array],
38
38
  params: Params[Array],
39
- eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
39
+ eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
40
40
  ) -> Float[Array, " "]:
41
41
  r"""
42
42
  Compute the divergence of a vector field $\mathbf{u}$, i.e.,
@@ -64,7 +64,7 @@ def divergence_rev(
64
64
  eq_type = _get_eq_type(u, eq_type)
65
65
 
66
66
  def scan_fun(_, i):
67
- if eq_type == "nonstatio_PDE":
67
+ if eq_type == "PDENonStatio":
68
68
  du_dxi = grad(lambda inputs, params: u(inputs, params)[1 + i])(
69
69
  inputs, params
70
70
  )[1 + i]
@@ -74,9 +74,9 @@ def divergence_rev(
74
74
  ]
75
75
  return _, du_dxi
76
76
 
77
- if eq_type == "nonstatio_PDE":
77
+ if eq_type == "PDENonStatio":
78
78
  _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0] - 1))
79
- elif eq_type == "statio_PDE":
79
+ elif eq_type == "PDEStatio":
80
80
  _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0]))
81
81
  else:
82
82
  raise ValueError("Unexpected u.eq_type!")
@@ -87,7 +87,7 @@ def divergence_fwd(
87
87
  inputs: Float[Array, " batch_size dim"] | Float[Array, " batch_size 1+dim"],
88
88
  u: AbstractPINN | Callable[[Array, Params[Array]], Array],
89
89
  params: Params[Array],
90
- eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
90
+ eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
91
91
  ) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
92
92
  r"""
93
93
  Compute the divergence of a **batched** vector field $\mathbf{u}$, i.e.,
@@ -120,7 +120,7 @@ def divergence_fwd(
120
120
  eq_type = _get_eq_type(u, eq_type)
121
121
 
122
122
  def scan_fun(_, i):
123
- if eq_type == "nonstatio_PDE":
123
+ if eq_type == "PDENonStatio":
124
124
  tangent_vec = jnp.repeat(
125
125
  jax.nn.one_hot(i + 1, inputs.shape[-1])[None],
126
126
  inputs.shape[0],
@@ -140,9 +140,9 @@ def divergence_fwd(
140
140
  )
141
141
  return _, du_dxi
142
142
 
143
- if eq_type == "nonstatio_PDE":
143
+ if eq_type == "PDENonStatio":
144
144
  _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[1] - 1))
145
- elif eq_type == "statio_PDE":
145
+ elif eq_type == "PDEStatio":
146
146
  _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[1]))
147
147
  else:
148
148
  raise ValueError("Unexpected u.eq_type!")
@@ -154,7 +154,7 @@ def laplacian_rev(
154
154
  u: AbstractPINN | Callable[[Array, Params[Array]], Array],
155
155
  params: Params[Array],
156
156
  method: Literal["trace_hessian_x", "trace_hessian_t_x", "loop"] = "trace_hessian_x",
157
- eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
157
+ eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
158
158
  ) -> Float[Array, " "]:
159
159
  r"""
160
160
  Compute the Laplacian of a scalar field $u$ from $\mathbb{R}^d$
@@ -196,22 +196,22 @@ def laplacian_rev(
196
196
  # computation and then discarding elements but for higher order derivatives
197
197
  # it might not be worth it. See other options below for computating the
198
198
  # Laplacian
199
- if eq_type == "nonstatio_PDE":
199
+ if eq_type == "PDENonStatio":
200
200
  u_ = lambda x: jnp.squeeze(
201
201
  u(jnp.concatenate([inputs[:1], x], axis=0), params)
202
202
  )
203
203
  return jnp.sum(jnp.diag(jax.hessian(u_)(inputs[1:])))
204
- if eq_type == "statio_PDE":
204
+ if eq_type == "PDEStatio":
205
205
  u_ = lambda inputs: jnp.squeeze(u(inputs, params))
206
206
  return jnp.sum(jnp.diag(jax.hessian(u_)(inputs)))
207
207
  raise ValueError("Unexpected eq_type!")
208
208
  if method == "trace_hessian_t_x":
209
209
  # NOTE that it is unclear whether it is better to vectorially compute the
210
210
  # Hessian (despite a useless time dimension) as below
211
- if eq_type == "nonstatio_PDE":
211
+ if eq_type == "PDENonStatio":
212
212
  u_ = lambda inputs: jnp.squeeze(u(inputs, params))
213
213
  return jnp.sum(jnp.diag(jax.hessian(u_)(inputs))[1:])
214
- if eq_type == "statio_PDE":
214
+ if eq_type == "PDEStatio":
215
215
  u_ = lambda inputs: jnp.squeeze(u(inputs, params))
216
216
  return jnp.sum(jnp.diag(jax.hessian(u_)(inputs)))
217
217
  raise ValueError("Unexpected eq_type!")
@@ -225,7 +225,7 @@ def laplacian_rev(
225
225
  u_ = lambda inputs: u(inputs, params).squeeze()
226
226
 
227
227
  def scan_fun(_, i):
228
- if eq_type == "nonstatio_PDE":
228
+ if eq_type == "PDENonStatio":
229
229
  d2u_dxi2 = grad(
230
230
  lambda inputs: grad(u_)(inputs)[1 + i],
231
231
  )(inputs)[1 + i]
@@ -236,11 +236,11 @@ def laplacian_rev(
236
236
  )(inputs)[i]
237
237
  return _, d2u_dxi2
238
238
 
239
- if eq_type == "nonstatio_PDE":
239
+ if eq_type == "PDENonStatio":
240
240
  _, trace_hessian = jax.lax.scan(
241
241
  scan_fun, {}, jnp.arange(inputs.shape[0] - 1)
242
242
  )
243
- elif eq_type == "statio_PDE":
243
+ elif eq_type == "PDEStatio":
244
244
  _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0]))
245
245
  else:
246
246
  raise ValueError("Unexpected eq_type!")
@@ -253,7 +253,7 @@ def laplacian_fwd(
253
253
  u: AbstractPINN | Callable[[Array, Params[Array]], Array],
254
254
  params: Params[Array],
255
255
  method: Literal["trace_hessian_t_x", "trace_hessian_x", "loop"] = "loop",
256
- eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
256
+ eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
257
257
  ) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
258
258
  r"""
259
259
  Compute the Laplacian of a **batched** scalar field $u$
@@ -302,7 +302,7 @@ def laplacian_fwd(
302
302
  if method == "loop":
303
303
 
304
304
  def scan_fun(_, i):
305
- if eq_type == "nonstatio_PDE":
305
+ if eq_type == "PDENonStatio":
306
306
  tangent_vec = jnp.repeat(
307
307
  jax.nn.one_hot(i + 1, inputs.shape[-1])[None],
308
308
  inputs.shape[0],
@@ -323,17 +323,17 @@ def laplacian_fwd(
323
323
  __, d2u_dxi2 = jax.jvp(du_dxi_fun, (inputs,), (tangent_vec,))
324
324
  return _, d2u_dxi2
325
325
 
326
- if eq_type == "nonstatio_PDE":
326
+ if eq_type == "PDENonStatio":
327
327
  _, trace_hessian = jax.lax.scan(
328
328
  scan_fun, {}, jnp.arange(inputs.shape[-1] - 1)
329
329
  )
330
- elif eq_type == "statio_PDE":
330
+ elif eq_type == "PDEStatio":
331
331
  _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[-1]))
332
332
  else:
333
333
  raise ValueError("Unexpected eq_type!")
334
334
  return jnp.sum(trace_hessian, axis=0)
335
335
  if method == "trace_hessian_t_x":
336
- if eq_type == "nonstatio_PDE":
336
+ if eq_type == "PDENonStatio":
337
337
  # compute the Hessian including the batch dimension, get rid of the
338
338
  # (..,1,..) axis that is here because of the scalar output
339
339
  # if inputs.shape==(10,3) (1 for time, 2 for x_dim)
@@ -351,7 +351,7 @@ def laplacian_fwd(
351
351
  res_dims = "".join([f"{chr(97 + d)}" for d in range(inputs.shape[-1])])
352
352
  lap = jnp.einsum(res_dims + "ii->" + res_dims, r)
353
353
  return lap[..., None]
354
- if eq_type == "statio_PDE":
354
+ if eq_type == "PDEStatio":
355
355
  # compute the Hessian including the batch dimension, get rid of the
356
356
  # (..,1,..) axis that is here because of the scalar output
357
357
  # if inputs.shape==(10,2), r.shape=(10,10,1,10,2,10,2)
@@ -369,7 +369,7 @@ def laplacian_fwd(
369
369
  return lap[..., None]
370
370
  raise ValueError("Unexpected eq_type!")
371
371
  if method == "trace_hessian_x":
372
- if eq_type == "statio_PDE":
372
+ if eq_type == "PDEStatio":
373
373
  # compute the Hessian including the batch dimension, get rid of the
374
374
  # (..,1,..) axis that is here because of the scalar output
375
375
  # if inputs.shape==(10,2), r.shape=(10,10,1,10,2,10,2)
@@ -394,7 +394,7 @@ def vectorial_laplacian_rev(
394
394
  u: AbstractPINN | Callable[[Array, Params[Array]], Array],
395
395
  params: Params[Array],
396
396
  dim_out: int | None = None,
397
- eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
397
+ eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
398
398
  ) -> Float[Array, " dim_out"]:
399
399
  r"""
400
400
  Compute the vectorial Laplacian of a vector field $\mathbf{u}$ from
@@ -448,7 +448,7 @@ def vectorial_laplacian_fwd(
448
448
  u: AbstractPINN | Callable[[Array, Params[Array]], Array],
449
449
  params: Params[Array],
450
450
  dim_out: int | None = None,
451
- eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
451
+ eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
452
452
  ) -> Float[Array, " batch_size * (1+dim) n"] | Float[Array, " batch_size * (dim) n"]:
453
453
  r"""
454
454
  Compute the vectorial Laplacian of a vector field $\mathbf{u}$ when
@@ -13,7 +13,7 @@ class AbstractPINN(eqx.Module):
13
13
  https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
14
14
  """
15
15
 
16
- eq_type: eqx.AbstractVar[Literal["ODE", "statio_PDE", "nonstatio_PDE"]]
16
+ eq_type: eqx.AbstractVar[Literal["ODE", "PDEStatio", "PDENonStatio"]]
17
17
 
18
18
  @abc.abstractmethod
19
19
  def __call__(self, inputs: Any, params: Params[Array], *args, **kwargs) -> Any:
jinns/nn/_hyperpinn.py CHANGED
@@ -67,9 +67,9 @@ class HyperPINN(PINN):
67
67
  eq_type : str
68
68
  A string with three possibilities.
69
69
  "ODE": the HyperPINN is called with one input `t`.
70
- "statio_PDE": the HyperPINN is called with one input `x`, `x`
70
+ "PDEStatio": the HyperPINN is called with one input `x`, `x`
71
71
  can be high dimensional.
72
- "nonstatio_PDE": the HyperPINN is called with two inputs `t` and `x`, `x`
72
+ "PDENonStatio": the HyperPINN is called with two inputs `t` and `x`, `x`
73
73
  can be high dimensional.
74
74
  **Note**: the input dimension as given in eqx_list has to match the sum
75
75
  of the dimension of `t` + the dimension of `x` or the output dimension
@@ -192,7 +192,7 @@ class HyperPINN(PINN):
192
192
  hyper = eqx.combine(params.nn_params, self.static_hyper)
193
193
 
194
194
  eq_params_batch = jnp.concatenate(
195
- [getattr(params.eq_params, k).flatten() for k in self.hyperparams],
195
+ [getattr(params.eq_params, k).flatten() for k in self.hyperparams], # pylint: disable=E1133
196
196
  axis=0,
197
197
  )
198
198
 
@@ -214,7 +214,7 @@ class HyperPINN(PINN):
214
214
  def create(
215
215
  cls,
216
216
  *,
217
- eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
217
+ eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"],
218
218
  hyperparams: list[str],
219
219
  hypernet_input_size: int,
220
220
  key: PRNGKeyArray | None = None,
@@ -257,9 +257,9 @@ class HyperPINN(PINN):
257
257
  eq_type
258
258
  A string with three possibilities.
259
259
  "ODE": the HyperPINN is called with one input `t`.
260
- "statio_PDE": the HyperPINN is called with one input `x`, `x`
260
+ "PDEStatio": the HyperPINN is called with one input `x`, `x`
261
261
  can be high dimensional.
262
- "nonstatio_PDE": the HyperPINN is called with two inputs `t` and `x`, `x`
262
+ "PDENonStatio": the HyperPINN is called with two inputs `t` and `x`, `x`
263
263
  can be high dimensional.
264
264
  **Note**: the input dimension as given in eqx_list has to match the sum
265
265
  of the dimension of `t` + the dimension of `x` or the output dimension
jinns/nn/_mlp.py CHANGED
@@ -95,7 +95,7 @@ class PINN_MLP(PINN):
95
95
  def create(
96
96
  cls,
97
97
  *,
98
- eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
98
+ eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"],
99
99
  key: PRNGKeyArray | None = None,
100
100
  eqx_network: eqx.nn.MLP | MLP | None = None,
101
101
  eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
@@ -130,9 +130,9 @@ class PINN_MLP(PINN):
130
130
  eq_type
131
131
  A string with three possibilities.
132
132
  "ODE": the MLP is called with one input `t`.
133
- "statio_PDE": the MLP is called with one input `x`, `x`
133
+ "PDEStatio": the MLP is called with one input `x`, `x`
134
134
  can be high dimensional.
135
- "nonstatio_PDE": the MLP is called with two inputs `t` and `x`, `x`
135
+ "PDENonStatio": the MLP is called with two inputs `t` and `x`, `x`
136
136
  can be high dimensional.
137
137
  **Note**: the input dimension as given in eqx_list has to match the sum
138
138
  of the dimension of `t` + the dimension of `x` or the output dimension
jinns/nn/_pinn.py CHANGED
@@ -50,12 +50,12 @@ class PINN(AbstractPINN):
50
50
  when the PINN is also used to output equation parameters for example
51
51
  Note that it must be a slice and not an integer (a preprocessing of the
52
52
  user provided argument takes care of it).
53
- eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
53
+ eq_type : Literal["ODE", "PDEStatio", "PDENonStatio"]
54
54
  A string with three possibilities.
55
55
  "ODE": the PINN is called with one input `t`.
56
- "statio_PDE": the PINN is called with one input `x`, `x`
56
+ "PDEStatio": the PINN is called with one input `x`, `x`
57
57
  can be high dimensional.
58
- "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
58
+ "PDENonStatio": the PINN is called with two inputs `t` and `x`, `x`
59
59
  can be high dimensional.
60
60
  **Note**: the input dimension as given in eqx_list has to match the sum
61
61
  of the dimension of `t` + the dimension of `x` or the output dimension
@@ -83,11 +83,11 @@ class PINN(AbstractPINN):
83
83
  Raises
84
84
  ------
85
85
  RuntimeError
86
- If the parameter value for eq_type is not in `["ODE", "statio_PDE",
87
- "nonstatio_PDE"]`
86
+ If the parameter value for eq_type is not in `["ODE", "PDEStatio",
87
+ "PDENonStatio"]`
88
88
  """
89
89
 
90
- eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
90
+ eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"] = eqx.field(
91
91
  static=True, kw_only=True
92
92
  )
93
93
  slice_solution: slice = eqx.field(static=True, kw_only=True, default=None)
@@ -108,7 +108,7 @@ class PINN(AbstractPINN):
108
108
  static: PINN = eqx.field(init=False, static=True)
109
109
 
110
110
  def __post_init__(self, eqx_network):
111
- if self.eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
111
+ if self.eq_type not in ["ODE", "PDEStatio", "PDENonStatio"]:
112
112
  raise RuntimeError("Wrong parameter value for eq_type")
113
113
  # saving the static part of the model and initial parameters
114
114
 
jinns/nn/_ppinn.py CHANGED
@@ -31,12 +31,12 @@ class PPINN_MLP(PINN):
31
31
  when the PINN is also used to output equation parameters for example
32
32
  Note that it must be a slice and not an integer (a preprocessing of the
33
33
  user provided argument takes care of it).
34
- eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
34
+ eq_type : Literal["ODE", "PDEStatio", "PDENonStatio"]
35
35
  A string with three possibilities.
36
36
  "ODE": the PPINN is called with one input `t`.
37
- "statio_PDE": the PPINN is called with one input `x`, `x`
37
+ "PDEStatio": the PPINN is called with one input `x`, `x`
38
38
  can be high dimensional.
39
- "nonstatio_PDE": the PPINN is called with two inputs `t` and `x`, `x`
39
+ "PDENonStatio": the PPINN is called with two inputs `t` and `x`, `x`
40
40
  can be high dimensional.
41
41
  **Note**: the input dimension as given in eqx_list has to match the sum
42
42
  of the dimension of `t` + the dimension of `x` or the output dimension
@@ -125,7 +125,7 @@ class PPINN_MLP(PINN):
125
125
  cls,
126
126
  *,
127
127
  key: PRNGKeyArray | None = None,
128
- eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
128
+ eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"],
129
129
  eqx_network_list: list[eqx.nn.MLP | MLP] | None = None,
130
130
  eqx_list_list: (
131
131
  list[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] | None
@@ -158,9 +158,9 @@ class PPINN_MLP(PINN):
158
158
  eq_type
159
159
  A string with three possibilities.
160
160
  "ODE": the PPINN MLP is called with one input `t`.
161
- "statio_PDE": the PPINN MLP is called with one input `x`, `x`
161
+ "PDEStatio": the PPINN MLP is called with one input `x`, `x`
162
162
  can be high dimensional.
163
- "nonstatio_PDE": the PPINN MLP is called with two inputs `t` and `x`, `x`
163
+ "PDENonStatio": the PPINN MLP is called with two inputs `t` and `x`, `x`
164
164
  can be high dimensional.
165
165
  **Note**: the input dimension as given in eqx_list has to match the sum
166
166
  of the dimension of `t` + the dimension of `x` or the output dimension
jinns/nn/_spinn.py CHANGED
@@ -21,12 +21,12 @@ class SPINN(AbstractPINN):
21
21
  used for non-stationnary equations.
22
22
  r : int
23
23
  An integer. The dimension of the embedding.
24
- eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
24
+ eq_type : Literal["ODE", "PDEStatio", "PDENonStatio"]
25
25
  A string with three possibilities.
26
26
  "ODE": the PINN is called with one input `t`.
27
- "statio_PDE": the PINN is called with one input `x`, `x`
27
+ "PDEStatio": the PINN is called with one input `x`, `x`
28
28
  can be high dimensional.
29
- "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
29
+ "PDENonStatio": the PINN is called with two inputs `t` and `x`, `x`
30
30
  can be high dimensional.
31
31
  **Note**: the input dimension as given in eqx_list has to match the sum
32
32
  of the dimension of `t` + the dimension of `x`.
@@ -49,7 +49,7 @@ class SPINN(AbstractPINN):
49
49
 
50
50
  """
51
51
 
52
- eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
52
+ eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"] = eqx.field(
53
53
  static=True, kw_only=True
54
54
  )
55
55
  d: int = eqx.field(static=True, kw_only=True)
jinns/nn/_spinn_mlp.py CHANGED
@@ -78,7 +78,7 @@ class SPINN_MLP(SPINN):
78
78
  d: int,
79
79
  r: int,
80
80
  eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...],
81
- eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
81
+ eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"],
82
82
  m: int = 1,
83
83
  filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
84
84
  ) -> tuple[Self, SPINN]:
@@ -114,12 +114,12 @@ class SPINN_MLP(SPINN):
114
114
  (jax.nn.tanh,),
115
115
  (eqx.nn.Linear, 20, r * m)
116
116
  )`.
117
- eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
117
+ eq_type : Literal["ODE", "PDEStatio", "PDENonStatio"]
118
118
  A string with three possibilities.
119
119
  "ODE": the PINN is called with one input `t`.
120
- "statio_PDE": the PINN is called with one input `x`, `x`
120
+ "PDEStatio": the PINN is called with one input `x`, `x`
121
121
  can be high dimensional.
122
- "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
122
+ "PDENonStatio": the PINN is called with two inputs `t` and `x`, `x`
123
123
  can be high dimensional.
124
124
  **Note**: the input dimension as given in eqx_list has to match the sum
125
125
  of the dimension of `t` + the dimension of `x`.
@@ -150,11 +150,11 @@ class SPINN_MLP(SPINN):
150
150
  Raises
151
151
  ------
152
152
  RuntimeError
153
- If the parameter value for eq_type is not in `["ODE", "statio_PDE",
154
- "nonstatio_PDE"]` and for various failing checks
153
+ If the parameter value for eq_type is not in `["ODE", "PDEStatio",
154
+ "PDENonStatio"]` and for various failing checks
155
155
  """
156
156
 
157
- if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
157
+ if eq_type not in ["ODE", "PDEStatio", "PDENonStatio"]:
158
158
  raise RuntimeError("Wrong parameter value for eq_type")
159
159
 
160
160
  def element_is_layer(element: tuple) -> TypeGuard[tuple[Callable, int, int]]:
jinns/solver/_rar.py CHANGED
@@ -10,6 +10,7 @@ from jax import vmap
10
10
  import jax.numpy as jnp
11
11
  import equinox as eqx
12
12
 
13
+ from jinns.loss._DynamicLossAbstract import ODE, PDEStatio, PDENonStatio
13
14
  from jinns.data._DataGeneratorODE import DataGeneratorODE
14
15
  from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
15
16
  from jinns.data._CubicMeshPDENonStatio import CubicMeshPDENonStatio
@@ -176,16 +177,25 @@ def _rar_step_init(
176
177
  )
177
178
 
178
179
  data = eqx.tree_at(lambda m: m.key, data, new_key)
179
-
180
- v_dyn_loss = vmap(
181
- lambda inputs: loss.dynamic_loss.evaluate(inputs, loss.u, params),
182
- )
183
- dyn_on_s = v_dyn_loss(new_samples)
184
-
185
- if dyn_on_s.ndim > 1:
186
- mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
187
180
  else:
188
- mse_on_s = dyn_on_s**2
181
+ raise ValueError("Wrong DataGenerator type")
182
+
183
+ v_dyn_loss = jax.tree.map(
184
+ lambda d: vmap(
185
+ lambda inputs: d.evaluate(inputs, loss.u, params),
186
+ ),
187
+ loss.dynamic_loss,
188
+ is_leaf=lambda x: isinstance(x, (ODE, PDEStatio, PDENonStatio)),
189
+ )
190
+ dyn_on_s = jax.tree.map(lambda d: d(new_samples), v_dyn_loss)
191
+
192
+ mse_on_s = jax.tree.reduce(
193
+ jnp.add,
194
+ jax.tree.map(
195
+ lambda v: (jnp.linalg.norm(v, axis=-1) ** 2).flatten(), dyn_on_s
196
+ ),
197
+ 0,
198
+ )
189
199
 
190
200
  ## Select the m points with higher dynamic loss
191
201
  higher_residual_idx = jax.lax.dynamic_slice(
jinns/solver/_solve.py CHANGED
@@ -272,6 +272,7 @@ def solve(
272
272
  params=init_params,
273
273
  last_non_nan_params=init_params,
274
274
  opt_state=opt_state,
275
+ # params_mask=params_mask,
275
276
  )
276
277
  optimization_extra = OptimizationExtraContainer(
277
278
  curr_seq=curr_seq,
@@ -430,7 +431,9 @@ def solve(
430
431
  return (
431
432
  i,
432
433
  loss,
433
- OptimizationContainer(params, last_non_nan_params, opt_state),
434
+ OptimizationContainer(
435
+ params, last_non_nan_params, opt_state
436
+ ), # , params_mask),
434
437
  OptimizationExtraContainer(
435
438
  curr_seq,
436
439
  best_iter_id,
jinns/solver/_utils.py CHANGED
@@ -606,18 +606,24 @@ def _check_batch_size(other_data, main_data, attr_name):
606
606
  " vectorization"
607
607
  )
608
608
  if isinstance(main_data, DataGeneratorParameter):
609
- if main_data.param_batch_size is not None:
610
- if getattr(other_data, attr_name) != main_data.param_batch_size:
611
- raise ValueError(
612
- f"{other_data.__class__}.{attr_name} must be equal"
613
- f" to {main_data.__class__}.param_batch_size for correct"
614
- " vectorization"
615
- )
616
- else:
617
- if main_data.n is not None:
618
- if getattr(other_data, attr_name) != main_data.n:
609
+ batch_size = getattr(other_data, attr_name) # this can be a tuple with
610
+ # DataGeneratorObservations
611
+ if not isinstance(batch_size, tuple):
612
+ batch_size = (batch_size,)
613
+
614
+ for bs in batch_size:
615
+ if main_data.param_batch_size is not None:
616
+ if bs != main_data.param_batch_size:
619
617
  raise ValueError(
620
618
  f"{other_data.__class__}.{attr_name} must be equal"
621
- f" to {main_data.__class__}.n for correct"
619
+ f" to {main_data.__class__}.param_batch_size for correct"
622
620
  " vectorization"
623
621
  )
622
+ else:
623
+ if main_data.n is not None:
624
+ if bs != main_data.n:
625
+ raise ValueError(
626
+ f"{other_data.__class__}.{attr_name} must be equal"
627
+ f" to {main_data.__class__}.n for correct"
628
+ " vectorization"
629
+ )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jinns
3
- Version: 1.7.0
3
+ Version: 1.7.1
4
4
  Summary: Physics Informed Neural Network with JAX
5
5
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
6
6
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -20,14 +20,24 @@ Requires-Dist: optax>=0.2.6
20
20
  Requires-Dist: equinox>=0.13.2
21
21
  Requires-Dist: matplotlib
22
22
  Requires-Dist: jaxtyping
23
+ Requires-Dist: pyright>=1.1.407
24
+ Requires-Dist: pytest>=9.0.2
23
25
  Provides-Extra: notebook
24
26
  Requires-Dist: jupyter; extra == "notebook"
25
27
  Requires-Dist: seaborn; extra == "notebook"
26
28
  Requires-Dist: pandas; extra == "notebook"
27
- Requires-Dist: pytest; extra == "notebook"
28
- Requires-Dist: pre-commit; extra == "notebook"
29
- Requires-Dist: pyright; extra == "notebook"
30
29
  Requires-Dist: diffrax; extra == "notebook"
30
+ Provides-Extra: gpu-cuda13
31
+ Requires-Dist: jax[cuda13]>=0.8.1; extra == "gpu-cuda13"
32
+ Provides-Extra: gpu-cuda12
33
+ Requires-Dist: jax[cuda12]>=0.8.1; extra == "gpu-cuda12"
34
+ Provides-Extra: cpu
35
+ Requires-Dist: jax>=0.8.1; extra == "cpu"
36
+ Provides-Extra: dev
37
+ Requires-Dist: pre-commit; extra == "dev"
38
+ Requires-Dist: pytest; extra == "dev"
39
+ Requires-Dist: pylint; extra == "dev"
40
+ Requires-Dist: ruff; extra == "dev"
31
41
  Dynamic: license-file
32
42
 
33
43
  jinns
@@ -1,36 +1,36 @@
1
1
  jinns/__init__.py,sha256=f8ZCZH95C7U8NfJ3M-00kf1IBbdgbwRNBhlluiOCGJE,530
2
2
  jinns/data/_AbstractDataGenerator.py,sha256=le5plNhOE7hV72SC6p2xhWxljnGeStPc2kitzeMR_ts,535
3
- jinns/data/_Batchs.py,sha256=UZfJWIKAFLQJxOtWN1ybJdoSQCO6PWk3SAgV9YmNVnI,2809
3
+ jinns/data/_Batchs.py,sha256=mtHDlsen1vBBmFgGzynPmFfmBFC7zKcnnVEg_JnNlig,2857
4
4
  jinns/data/_CubicMeshPDENonStatio.py,sha256=dtev6j7HuVZ_IBGnKVKOlJTyZVjnW_8G1VRGpcL40VU,22922
5
5
  jinns/data/_CubicMeshPDEStatio.py,sha256=g4H9wI0FkA9sdnlSxp00QArq9IjhPrmsGgSleLdNpI4,23314
6
- jinns/data/_DataGeneratorODE.py,sha256=XOGeX9m5J9CkkGx2alJ5G7FZYGn5cBUaFpO-3-MKvdc,7724
7
- jinns/data/_DataGeneratorObservations.py,sha256=h2bS-oOUi1rYvhDHy8xT-N9voqtb_IOelL2asnXCybw,9765
6
+ jinns/data/_DataGeneratorODE.py,sha256=7m3mhMk-o02Jq_ZieWktFG5w6mVwnWPs1XYiYqoJ0DA,7731
7
+ jinns/data/_DataGeneratorObservations.py,sha256=Lw1Fno4UDJ0LRkhFJ0tDtkkW-yWJLeswCWja504pUDo,27781
8
8
  jinns/data/_DataGeneratorParameter.py,sha256=m7WCWBtmNptF8qQSWGPTQEFVzkJb8kJpidKAFBmoBiQ,10205
9
9
  jinns/data/__init__.py,sha256=DEzEmoD5TmjHCaG6pS2jmFDzuhZn1ZDpFdVa2v0jCe0,703
10
10
  jinns/data/_utils.py,sha256=xKHf7NksJ_AmrtEcpJsh7WSEvI3Yk98_cM5kmXSfmx0,5596
11
11
  jinns/experimental/__init__.py,sha256=DT9e57zbjfzPeRnXemGUqnGd--MhV77FspChT0z4YrE,410
12
12
  jinns/experimental/_diffrax_solver.py,sha256=upMr3kTTNrxEiSUO_oLvCXcjS9lPxSjvbB81h3qlhaU,6813
13
13
  jinns/loss/_DynamicLoss.py,sha256=FXVraH97tIWA5ByO-V-r54Nhwd_EvHOfr0e2qjrfdf0,22667
14
- jinns/loss/_DynamicLossAbstract.py,sha256=bTiDZHqkHDdvfj6hcjPcESYzwzqOd2LtR_h7GPAqQpY,13602
15
- jinns/loss/_LossODE.py,sha256=c2SBstJgy1yTTGcLZrnrM3XFYS67AZ8SwbcPj0cdfXE,15894
16
- jinns/loss/_LossPDE.py,sha256=uhOGJ9440WaCaLasMbqp8pxf9obXYbvFWbqJMYuxJUU,39719
14
+ jinns/loss/_DynamicLossAbstract.py,sha256=mm687ly-r7ONZPJaQN6kQa9CwM05msiZt9cCkQ4Ue54,13699
15
+ jinns/loss/_LossODE.py,sha256=mbsK6n61J3Fx4zl1VvJCLigowV-UPxIJoxn-Zf_IwH0,18039
16
+ jinns/loss/_LossPDE.py,sha256=rjy5mKEkFOrTqpRf7W51JEVnIysnD7RhQ33aAvZMcfg,43056
17
17
  jinns/loss/__init__.py,sha256=z5xYgBipNFf66__5BqQc6R_8r4F6A3TXL60YjsM8Osk,1287
18
- jinns/loss/_abstract_loss.py,sha256=wBpB2mw_JizimiNHYIs-431jZ87mxvzamTFkX_rnQZ4,6872
19
- jinns/loss/_boundary_conditions.py,sha256=9HGw1cGLfmEilP4V4B2T0zl0YP1kNtrtXVLQNiBmWgc,12464
18
+ jinns/loss/_abstract_loss.py,sha256=MG-nC_ApYNhnTGLvW1_vx8VXXJLJ2q3Ird9sbUZN5fk,9000
19
+ jinns/loss/_boundary_conditions.py,sha256=WqPAIcG5DbY-K_sYYvi_cr4TPoi0ajIGvnmz9KSw3IE,12458
20
20
  jinns/loss/_loss_components.py,sha256=Y5GGNuAS_tHl-AzIMBxf3yKj0N2QCdVdTY4jC7NDuic,442
21
- jinns/loss/_loss_utils.py,sha256=eJ4JcBm396LHx7Tti88ZQrLcKqVL1oSfFGT23VNkytQ,11949
22
- jinns/loss/_loss_weight_updates.py,sha256=Q5RJkulMjP5G3twKEGmc46QeQJSCb7yyS7OudzRqQkY,6775
23
- jinns/loss/_loss_weights.py,sha256=x0URuCaFXULlVDwlL9ZEy45L33tj1Fa0VJUQlSPhGkA,2416
24
- jinns/loss/_operators.py,sha256=Ds5yRH7hu-jaGRp7PYbt821BgYuEvgWHufWhYgdMjw0,22909
21
+ jinns/loss/_loss_utils.py,sha256=thnV4rA77tH-exwe5Y2dbPOTDAmw44sQJmYyRGe8cRY,12007
22
+ jinns/loss/_loss_weight_updates.py,sha256=WaQZl4TE1eXybLGjMXaXRQhUtA_7Cnyc1r5vTyr4_98,7683
23
+ jinns/loss/_loss_weights.py,sha256=mPfIXdCoTAcnxsgNK_KjMTBUbk_Q3m4BmFF229ghEpQ,2640
24
+ jinns/loss/_operators.py,sha256=RsBJzvvADB1uhJQSbvg1e90KPw8csQR5GKIidu5Xsp8,22874
25
25
  jinns/nn/__init__.py,sha256=gwE48oqB_FsSIE-hUvCLz0jPaqX350LBxzH6ueFWYk4,456
26
- jinns/nn/_abstract_pinn.py,sha256=E17TFjxgs63BHl3dUKwajqkBuggVbyNVtovX0ePPzr4,602
27
- jinns/nn/_hyperpinn.py,sha256=x_hBWqokBbURMOdIG_0EBNGJjxAzjlUZ4__fyb-uasA,20245
28
- jinns/nn/_mlp.py,sha256=ajVWRTtGVnmRjuaylEMlsK-3H5LMqKDvhUsMNnH_rIM,8935
29
- jinns/nn/_pinn.py,sha256=nRBIdjTir1ViI9hBO4_4DY6uS-1P8Xv-TDLU-O_qzTk,8332
30
- jinns/nn/_ppinn.py,sha256=qBMenzj0TpapGxyl7nSWb6mER-Ny42JrcOd7Wld19hU,10144
26
+ jinns/nn/_abstract_pinn.py,sha256=jp1Wi2LDBvSkJempliQq4m7OaIbYr8d16elnJX8movk,600
27
+ jinns/nn/_hyperpinn.py,sha256=A4nNnID7XVsU7kDafIZo7OchRx9EmoSHseoH4EmY0m8,20264
28
+ jinns/nn/_mlp.py,sha256=VqObd7gLGIO25N4XVfHK_XwvhR2dgeTR-kOcRLTH0Os,8931
29
+ jinns/nn/_pinn.py,sha256=f5OyiZCgyq_wubI10rNYSlY00letMsXYDSHcISuvvwM,8322
30
+ jinns/nn/_ppinn.py,sha256=wvzeT8CyYeiQ2q6BTvNTz3L7g7FPemOQcPIGVDy34Rs,10136
31
31
  jinns/nn/_save_load.py,sha256=4pGDPZPKzDmY5UzRtA_zxpDK4hFL2LuDxexWW2PnHtw,7656
32
- jinns/nn/_spinn.py,sha256=a0ZsnzyXi-oNd6i8pAwF1tdjJQU51pbrGg608O6ykQw,4352
33
- jinns/nn/_spinn_mlp.py,sha256=7w-WsdcoLHVWwkK2DtXRKRtSrcJkEXJ2_n27VmXI2j4,7651
32
+ jinns/nn/_spinn.py,sha256=MGaPQRA-qVltshESB7IP_kft0kINjijPhEV5kBfKDvI,4346
33
+ jinns/nn/_spinn_mlp.py,sha256=NWSNqVOUPz-Zk2JNYZKbmcYWtXHDrxN2aIYiJH4jK1k,7641
34
34
  jinns/nn/_utils.py,sha256=SoaYvNM6V7M3GcXK3diqO7B9_zoBXqSux2fsTOcouSA,1018
35
35
  jinns/parameters/__init__.py,sha256=T_BhKT77X5nei2pYj-QIOKCyD6j2UCgSgAY3LnNdkCo,333
36
36
  jinns/parameters/_derivative_keys.py,sha256=laMZoT7ZZa3ZSTuvA-TKM6-VX32vKZFfA5ThDWhj4v4,21084
@@ -38,10 +38,10 @@ jinns/parameters/_params.py,sha256=A6F5HAbxaKPonsH5IrT3ERNSy60OGLUD0eV9hYNj28U,4
38
38
  jinns/plot/__init__.py,sha256=KPHX0Um4FbciZO1yD8kjZbkaT8tT964Y6SE2xCQ4eDU,135
39
39
  jinns/plot/_plot.py,sha256=-A5auNeElaz2_8UzVQJQE4143ZFg0zgMjStU7kwttEY,11565
40
40
  jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
41
- jinns/solver/_rar.py,sha256=vSVTnCGCusI1vTZCvIkP2_G8we44G_42yZHx2sOK9DE,10291
42
- jinns/solver/_solve.py,sha256=QX1vBec9pk9UZdjA4oqcqWNnvwb8p10Qzm24szwoZJs,20046
41
+ jinns/solver/_rar.py,sha256=IFaH0-b-uPp7fs1rX14LY2oy3cXH2NN6QqjC4rdm6n8,10659
42
+ jinns/solver/_solve.py,sha256=3JDlX19o3UCpIfrTyFLgWh3ihAwGbl1oB0NsdbBke30,20130
43
43
  jinns/solver/_solve_alternate.py,sha256=9R1idmQMhkLz_QsxwjsfWcS_f_Cvg71NS43MfNbpXSo,31664
44
- jinns/solver/_utils.py,sha256=dNr-XipxSy3Vxo_DqudMKUHNgc6rzvZ_efTy--GnesQ,22766
44
+ jinns/solver/_utils.py,sha256=UMq7jhrShVeVUbcJqvjNdLvczDOP1x6XolF0jWjNs0A,23002
45
45
  jinns/utils/_DictToModuleMeta.py,sha256=kdCAZ6U7L_ZNsGnvTfk3r-85RaV19Qq6y6zdhvI_hGM,3043
46
46
  jinns/utils/_ItemizableModule.py,sha256=KdjsisFddGOUTA7-1jXlrkFxnX5sIVdtyNDxOaEcEYw,634
47
47
  jinns/utils/__init__.py,sha256=k2kWDahXsLhSlsbeB_W73WLaVSvPG_S6mrNG_5fW_Mo,121
@@ -50,9 +50,9 @@ jinns/utils/_types.py,sha256=PftmBr-hV33U-aB_TzTYNLD42ogRWmZKA02CPdxQWBg,2482
50
50
  jinns/utils/_utils.py,sha256=M7NXX9ok-BkH5o_xo74PB1_Cc8XiDipSl51rq82dTH4,2821
51
51
  jinns/validation/__init__.py,sha256=FTyUO-v1b8Tv-FDSQsntrH7zl9E0ENexqKMT_dFRkYo,124
52
52
  jinns/validation/_validation.py,sha256=8p6sMKiBAvA6JNm65hjkMj0997LJ0BkyCREEh0AnPVE,4803
53
- jinns-1.7.0.dist-info/licenses/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
54
- jinns-1.7.0.dist-info/licenses/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
55
- jinns-1.7.0.dist-info/METADATA,sha256=nv5KCuvdeLV2I3kWjckhMCsKe7KwRJ3NoPgiwmK1WBk,5651
56
- jinns-1.7.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
57
- jinns-1.7.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
58
- jinns-1.7.0.dist-info/RECORD,,
53
+ jinns-1.7.1.dist-info/licenses/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
54
+ jinns-1.7.1.dist-info/licenses/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
55
+ jinns-1.7.1.dist-info/METADATA,sha256=BuWK-l7JNmKFJ9silTxgpeJQqNgX3JRLZfo0WgmGb4Y,5982
56
+ jinns-1.7.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
57
+ jinns-1.7.1.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
58
+ jinns-1.7.1.dist-info/RECORD,,