jinns 1.6.1__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,15 +1,27 @@
  from __future__ import annotations

  import abc
- from typing import Self, Literal, Callable, TypeVar, Generic, Any
- from jaxtyping import PRNGKeyArray, Array, PyTree, Float
+ import warnings
+ from typing import Self, Literal, Callable, TypeVar, Generic, Any, get_args
+ from dataclasses import InitVar
+ from jaxtyping import Array, PyTree, Float, PRNGKeyArray
  import equinox as eqx
  import jax
  import jax.numpy as jnp
  import optax
  from jinns.parameters._params import Params
- from jinns.loss._loss_weight_updates import soft_adapt, lr_annealing, ReLoBRaLo
- from jinns.utils._types import AnyLossComponents, AnyBatch, AnyLossWeights
+ from jinns.loss._loss_weight_updates import (
+     soft_adapt,
+     lr_annealing,
+     ReLoBRaLo,
+     prior_loss,
+ )
+ from jinns.utils._types import (
+     AnyLossComponents,
+     AnyBatch,
+     AnyLossWeights,
+     AnyDerivativeKeys,
+ )

  L = TypeVar(
      "L", bound=AnyLossWeights
@@ -25,31 +37,85 @@ C = TypeVar(
      "C", bound=AnyLossComponents[Array | None]
  ) # The above comment also works with Unions (https://docs.python.org/3/library/typing.html#typing.TypeVar)

+ DK = TypeVar("DK", bound=AnyDerivativeKeys)
+
  # In the cases above, without the bound, we could not have covariance on
  # the type because it would break LSP. Note that covariance on the return type
  # is authorized in LSP hence we do not need the same TypeVar instruction for
  # the return types of evaluate_by_terms for example!


- class AbstractLoss(eqx.Module, Generic[L, B, C]):
+ AvailableUpdateWeightMethods = Literal[
+     "softadapt", "soft_adapt", "prior_loss", "lr_annealing", "ReLoBRaLo"
+ ]
+
+
+ class AbstractLoss(eqx.Module, Generic[L, B, C, DK]):
      """
      About the call:
      https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
      """

+     derivative_keys: eqx.AbstractVar[DK]
      loss_weights: eqx.AbstractVar[L]
-     update_weight_method: Literal["soft_adapt", "lr_annealing", "ReLoBRaLo"] | None = (
-         eqx.field(kw_only=True, default=None, static=True)
+     loss_weight_scales: L = eqx.field(init=False)
+     update_weight_method: AvailableUpdateWeightMethods | None = eqx.field(
+         kw_only=True, default=None, static=True
+     )
+     vmap_in_axes: tuple[int] = eqx.field(static=True)
+     keep_initial_loss_weight_scales: InitVar[bool] = eqx.field(
+         default=True, kw_only=True
      )

+     def __init__(
+         self,
+         *,
+         loss_weights,
+         derivative_keys,
+         vmap_in_axes,
+         update_weight_method=None,
+         keep_initial_loss_weight_scales: bool = False,
+     ):
+         if update_weight_method is not None and update_weight_method not in get_args(
+             AvailableUpdateWeightMethods
+         ):
+             raise ValueError(f"{update_weight_method=} is not a valid method")
+         self.update_weight_method = update_weight_method
+         self.loss_weights = loss_weights
+         self.derivative_keys = derivative_keys
+         self.vmap_in_axes = vmap_in_axes
+         if keep_initial_loss_weight_scales:
+             self.loss_weight_scales = self.loss_weights
+             if self.update_weight_method is not None:
+                 warnings.warn(
+                     "Loss weights out from update_weight_method will still be"
+                     " multiplied by the initial input loss_weights"
+                 )
+         else:
+             self.loss_weight_scales = optax.tree_utils.tree_ones_like(self.loss_weights)
+             # self.loss_weight_scales will contain None where self.loss_weights
+             # has None
+
      def __call__(self, *args: Any, **kwargs: Any) -> Any:
          return self.evaluate(*args, **kwargs)

      @abc.abstractmethod
-     def evaluate_by_terms(self, params: Params[Array], batch: B) -> tuple[C, C]:
+     def evaluate_by_terms(
+         self,
+         opt_params: Params[Array],
+         batch: B,
+         *,
+         non_opt_params: Params[Array] | None = None,
+     ) -> tuple[C, C]:
          pass

-     def evaluate(self, params: Params[Array], batch: B) -> tuple[Float[Array, " "], C]:
+     def evaluate(
+         self,
+         opt_params: Params[Array],
+         batch: B,
+         *,
+         non_opt_params: Params[Array] | None = None,
+     ) -> tuple[Float[Array, " "], C]:
          """
          Evaluate the loss function at a batch of points for given parameters.

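For reference, the runtime check added in `__init__` above relies on `typing.get_args` returning the allowed strings of the `Literal` alias. A minimal self-contained sketch of that pattern (illustrative, not jinns code):

```python
from typing import Literal, get_args

AvailableUpdateWeightMethods = Literal[
    "softadapt", "soft_adapt", "prior_loss", "lr_annealing", "ReLoBRaLo"
]

def check_method(update_weight_method: str | None) -> None:
    # get_args(...) yields the tuple of allowed literal values at runtime
    if update_weight_method is not None and update_weight_method not in get_args(
        AvailableUpdateWeightMethods
    ):
        raise ValueError(f"{update_weight_method=} is not a valid method")

check_method("prior_loss")   # passes silently
check_method(None)           # passes silently (no adaptive weighting)
# check_method("soft-adapt") # would raise ValueError
```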
@@ -57,16 +123,20 @@ class AbstractLoss(eqx.Module, Generic[L, B, C]):

          Parameters
          ---------
-         params
-             Parameters at which the loss is evaluated
+         opt_params
+             Parameters, which are optimized, at which the loss is evaluated
          batch
              Composed of a batch of points in the
              domain, a batch of points in the domain
              border and an optional additional batch of parameters (eg. for
              metamodeling) and an optional additional batch of observed
              inputs/outputs/parameters
+         non_opt_params
+             Parameters, which are non optimized, at which the loss is evaluated
          """
-         loss_terms, _ = self.evaluate_by_terms(params, batch)
+         loss_terms, _ = self.evaluate_by_terms(
+             opt_params, batch, non_opt_params=non_opt_params
+         )

          loss_val = self.ponderate_and_sum_loss(loss_terms)

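A hypothetical call-site sketch of the new signature (the `loss` object and the parameter containers are placeholders, not taken from this diff): parameters to be optimized are passed positionally, while non-optimized parameters go through the keyword-only `non_opt_params`.

```python
# jinns 1.6.x style: a single Params object
# total_loss, loss_terms = loss.evaluate(params, batch)

# jinns 1.7.x style: optimized and non-optimized parameters are separated
# total_loss, loss_terms = loss.evaluate(
#     opt_params, batch, non_opt_params=frozen_params
# )
```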
@@ -102,10 +172,14 @@ class AbstractLoss(eqx.Module, Generic[L, B, C]):
              raise ValueError(
                  "The numbers of declared loss weights and "
                  "declared loss terms do not concord "
-                 f" got {len(weights)} and {len(terms_list)}"
+                 f" got {len(weights)} and {len(terms_list)}. "
+                 "If you passed tuple of dyn_loss, make sure to pass "
+                 "tuple of loss weights at LossWeights.dyn_loss."
+                 "If you passed tuple of obs datasets, make sure to pass "
+                 "tuple of loss weights at LossWeights.observations."
              )

-     def ponderate_and_sum_gradient(self, terms: C) -> C:
+     def ponderate_and_sum_gradient(self, terms: C) -> Params[Array | None]:
          """
          Get total gradients from individual loss gradients and weights
          for each parameter
@@ -146,6 +220,8 @@ class AbstractLoss(eqx.Module, Generic[L, B, C]):
              new_weights = soft_adapt(
                  self.loss_weights, iteration_nb, loss_terms, stored_loss_terms
              )
+         elif self.update_weight_method == "prior_loss":
+             new_weights = prior_loss(self.loss_weights, iteration_nb, stored_loss_terms)
          elif self.update_weight_method == "lr_annealing":
              new_weights = lr_annealing(self.loss_weights, grad_terms)
          elif self.update_weight_method == "ReLoBRaLo":
@@ -158,6 +234,13 @@ class AbstractLoss(eqx.Module, Generic[L, B, C]):
          # Below we update the non None entry in the PyTree self.loss_weights
          # we directly get the non None entries because None is not treated as a
          # leaf
+
+         new_weights = jax.lax.cond(
+             iteration_nb == 0,
+             lambda nw: nw,
+             lambda nw: jnp.array(jax.tree.leaves(self.loss_weight_scales)) * nw,
+             new_weights,
+         )
          return eqx.tree_at(
              lambda pt: jax.tree.leaves(pt.loss_weights), self, new_weights
          )
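The `jax.lax.cond` added above keeps the update method's proposal untouched at iteration 0 and otherwise multiplies it by the stored `loss_weight_scales`. A minimal sketch of that behaviour with toy values (illustrative, not jinns code):

```python
import jax
import jax.numpy as jnp

loss_weight_scales = (2.0, 0.5)          # e.g. the user-provided initial weights
new_weights = jnp.array([0.3, 0.7])      # weights proposed by an update method

def rescale(iteration_nb, new_weights):
    return jax.lax.cond(
        iteration_nb == 0,
        lambda nw: nw,  # keep the proposal as-is at iteration 0
        lambda nw: jnp.array(jax.tree.leaves(loss_weight_scales)) * nw,
        new_weights,
    )

print(rescale(0, new_weights))  # [0.3 0.7]
print(rescale(5, new_weights))  # [0.6 0.35]
```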
@@ -227,7 +227,7 @@ def boundary_neumann(
      if isinstance(u, PINN):
          u_ = lambda inputs, params: jnp.squeeze(u(inputs, params)[dim_to_apply])

-         if u.eq_type == "statio_PDE":
+         if u.eq_type == "PDEStatio":
              v_neumann = vmap(
                  lambda inputs, params: _subtract_with_check(
                      f(inputs),
@@ -240,7 +240,7 @@ def boundary_neumann(
                  vmap_in_axes,
                  0,
              )
-         elif u.eq_type == "nonstatio_PDE":
+         elif u.eq_type == "PDENonStatio":
              v_neumann = vmap(
                  lambda inputs, params: _subtract_with_check(
                      f(inputs),
@@ -274,14 +274,14 @@ def boundary_neumann(
          if (batch_array.shape[0] == 1 and isinstance(batch, PDEStatioBatch)) or (
              batch_array.shape[-1] == 2 and isinstance(batch, PDENonStatioBatch)
          ):
-             if u.eq_type == "statio_PDE":
+             if u.eq_type == "PDEStatio":
                  _, du_dx = jax.jvp(
                      lambda inputs: u(inputs, params)[..., dim_to_apply],
                      (batch_array,),
                      (jnp.ones_like(batch_array),),
                  )
                  values = du_dx * n[facet]
-             if u.eq_type == "nonstatio_PDE":
+             if u.eq_type == "PDENonStatio":
                  _, du_dx = jax.jvp(
                      lambda inputs: u(inputs, params)[..., dim_to_apply],
                      (batch_array,),
@@ -291,7 +291,7 @@ def boundary_neumann(
          elif (batch_array.shape[-1] == 2 and isinstance(batch, PDEStatioBatch)) or (
              batch_array.shape[-1] == 3 and isinstance(batch, PDENonStatioBatch)
          ):
-             if u.eq_type == "statio_PDE":
+             if u.eq_type == "PDEStatio":
                  tangent_vec_0 = jnp.repeat(
                      jnp.array([1.0, 0.0])[None], batch_array.shape[0], axis=0
                  )
@@ -309,7 +309,7 @@
                      (tangent_vec_1,),
                  )
                  values = du_dx1 * n[0, facet] + du_dx2 * n[1, facet] # dot product
-             if u.eq_type == "nonstatio_PDE":
+             if u.eq_type == "PDENonStatio":
                  tangent_vec_0 = jnp.repeat(
                      jnp.array([0.0, 1.0, 0.0])[None], batch_array.shape[0], axis=0
                  )
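The Neumann branches above obtain directional derivatives with `jax.jvp` and one-hot tangent vectors, then dot them with the facet normal `n`. A small standalone illustration of the `jvp` trick (toy field, not jinns code):

```python
import jax
import jax.numpy as jnp

u = lambda x: jnp.sum(x**2)            # toy scalar field u(x0, x1) = x0^2 + x1^2
x = jnp.array([1.0, 2.0])

# jvp with a one-hot tangent returns the partial derivative along that axis
_, du_dx0 = jax.jvp(u, (x,), (jnp.array([1.0, 0.0]),))
_, du_dx1 = jax.jvp(u, (x,), (jnp.array([0.0, 1.0]),))
# du_dx0 == 2.0, du_dx1 == 4.0; a Neumann condition then combines these with
# the boundary normal, e.g. du_dx0 * n[0] + du_dx1 * n[1]
```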
jinns/loss/_loss_utils.py CHANGED
@@ -26,12 +26,12 @@ from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
  from jinns.parameters._params import Params

  if TYPE_CHECKING:
-     from jinns.utils._types import BoundaryConditionFun
+     from jinns.utils._types import BoundaryConditionFun, AnyBatch
      from jinns.nn._abstract_pinn import AbstractPINN


  def dynamic_loss_apply(
-     dyn_loss: Callable,
+     dyn_loss: Callable[[AnyBatch, AbstractPINN, Params[Array]], Array],
      u: AbstractPINN,
      batch: (
          Float[Array, " batch_size 1"]
@@ -13,6 +13,36 @@ if TYPE_CHECKING:
      from jinns.utils._types import AnyLossComponents, AnyLossWeights


+ def prior_loss(
+     loss_weights: AnyLossWeights,
+     iteration_nb: int,
+     stored_loss_terms: AnyLossComponents,
+ ) -> Array:
+     """
+     Simple adaptative weights according to the prior loss idea:
+     the ponderation in front of a loss term is given by the inverse of the
+     value of that loss term at the previous iteration
+     """
+
+     def do_nothing(loss_weights, _):
+         return jnp.array(
+             jax.tree.leaves(loss_weights, is_leaf=eqx.is_inexact_array), dtype=float
+         )
+
+     def _prior_loss(_, stored_loss_terms):
+         new_weights = jax.tree.map(
+             lambda slt: 1 / (slt[iteration_nb - 1] + 1e-6), stored_loss_terms
+         )
+         return jnp.array(jax.tree.leaves(new_weights), dtype=float)
+
+     return jax.lax.cond(
+         iteration_nb == 0,
+         lambda op: do_nothing(*op),
+         lambda op: _prior_loss(*op),
+         (loss_weights, stored_loss_terms),
+     )
+
+
  def soft_adapt(
      loss_weights: AnyLossWeights,
      iteration_nb: int,
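The docstring above sums up the rule: the weight of a loss term is the inverse of that term's value at the previous iteration. A toy sketch of the idea (simplified: it uses a single stored value per term, whereas the new function indexes a history with `iteration_nb - 1`):

```python
import jax
import jax.numpy as jnp

# hypothetical values of each loss term at the previous iteration
previous_terms = {"dyn_loss": jnp.array(4.0), "observations": jnp.array(0.01)}

new_weights = jax.tree.map(lambda term: 1 / (term + 1e-6), previous_terms)
# large terms get small weights and vice versa:
# {'dyn_loss': ~0.25, 'observations': ~100.0}
```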
@@ -18,6 +18,10 @@ from jinns.loss._loss_components import (
  def lw_converter(x: Array | None) -> Array | None:
      if x is None:
          return x
+     elif isinstance(x, tuple):
+         # user might input tuple of scalar loss weights to account for cases
+         # when dyn loss is also a tuple of (possibly 1D) dyn_loss
+         return tuple(jnp.asarray(x_) for x_ in x)
      else:
          return jnp.asarray(x)

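The new `elif` branch means a tuple of scalar loss weights can now be passed when `dyn_loss` is itself a tuple; each entry is converted to an array. A standalone reproduction of the same logic (not imported from jinns):

```python
import jax.numpy as jnp

def lw_converter(x):
    if x is None:
        return x
    elif isinstance(x, tuple):
        # one weight per dyn_loss term in the tuple
        return tuple(jnp.asarray(x_) for x_ in x)
    else:
        return jnp.asarray(x)

lw_converter((1.0, 0.1))  # a tuple of scalar Arrays, one per dyn_loss term
lw_converter(0.5)         # a single scalar Array, as before
```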
jinns/loss/_operators.py CHANGED
@@ -18,8 +18,8 @@ from jinns.nn._abstract_pinn import AbstractPINN

  def _get_eq_type(
      u: AbstractPINN | Callable[[Array, Params[Array]], Array],
-     eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None,
- ) -> Literal["nonstatio_PDE", "statio_PDE"]:
+     eq_type: Literal["PDENonStatio", "PDEStatio"] | None,
+ ) -> Literal["PDENonStatio", "PDEStatio"]:
      """
      But we filter out ODE from eq_type because we only have operators that does
      not work with ODEs so far
@@ -36,7 +36,7 @@ def divergence_rev(
      inputs: Float[Array, " dim"] | Float[Array, " 1+dim"],
      u: AbstractPINN | Callable[[Array, Params[Array]], Array],
      params: Params[Array],
-     eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+     eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
  ) -> Float[Array, " "]:
      r"""
      Compute the divergence of a vector field $\mathbf{u}$, i.e.,
@@ -64,7 +64,7 @@ def divergence_rev(
      eq_type = _get_eq_type(u, eq_type)

      def scan_fun(_, i):
-         if eq_type == "nonstatio_PDE":
+         if eq_type == "PDENonStatio":
              du_dxi = grad(lambda inputs, params: u(inputs, params)[1 + i])(
                  inputs, params
              )[1 + i]
@@ -74,9 +74,9 @@
              ]
          return _, du_dxi

-     if eq_type == "nonstatio_PDE":
+     if eq_type == "PDENonStatio":
          _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0] - 1))
-     elif eq_type == "statio_PDE":
+     elif eq_type == "PDEStatio":
          _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0]))
      else:
          raise ValueError("Unexpected u.eq_type!")
@@ -87,7 +87,7 @@ def divergence_fwd(
      inputs: Float[Array, " batch_size dim"] | Float[Array, " batch_size 1+dim"],
      u: AbstractPINN | Callable[[Array, Params[Array]], Array],
      params: Params[Array],
-     eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+     eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
  ) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
      r"""
      Compute the divergence of a **batched** vector field $\mathbf{u}$, i.e.,
@@ -120,7 +120,7 @@ def divergence_fwd(
      eq_type = _get_eq_type(u, eq_type)

      def scan_fun(_, i):
-         if eq_type == "nonstatio_PDE":
+         if eq_type == "PDENonStatio":
              tangent_vec = jnp.repeat(
                  jax.nn.one_hot(i + 1, inputs.shape[-1])[None],
                  inputs.shape[0],
@@ -140,9 +140,9 @@
              )
          return _, du_dxi

-     if eq_type == "nonstatio_PDE":
+     if eq_type == "PDENonStatio":
          _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[1] - 1))
-     elif eq_type == "statio_PDE":
+     elif eq_type == "PDEStatio":
          _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[1]))
      else:
          raise ValueError("Unexpected u.eq_type!")
@@ -154,7 +154,7 @@ def laplacian_rev(
      u: AbstractPINN | Callable[[Array, Params[Array]], Array],
      params: Params[Array],
      method: Literal["trace_hessian_x", "trace_hessian_t_x", "loop"] = "trace_hessian_x",
-     eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+     eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
  ) -> Float[Array, " "]:
      r"""
      Compute the Laplacian of a scalar field $u$ from $\mathbb{R}^d$
@@ -196,22 +196,22 @@
          # computation and then discarding elements but for higher order derivatives
          # it might not be worth it. See other options below for computating the
          # Laplacian
-         if eq_type == "nonstatio_PDE":
+         if eq_type == "PDENonStatio":
              u_ = lambda x: jnp.squeeze(
                  u(jnp.concatenate([inputs[:1], x], axis=0), params)
              )
              return jnp.sum(jnp.diag(jax.hessian(u_)(inputs[1:])))
-         if eq_type == "statio_PDE":
+         if eq_type == "PDEStatio":
              u_ = lambda inputs: jnp.squeeze(u(inputs, params))
              return jnp.sum(jnp.diag(jax.hessian(u_)(inputs)))
          raise ValueError("Unexpected eq_type!")
      if method == "trace_hessian_t_x":
          # NOTE that it is unclear whether it is better to vectorially compute the
          # Hessian (despite a useless time dimension) as below
-         if eq_type == "nonstatio_PDE":
+         if eq_type == "PDENonStatio":
              u_ = lambda inputs: jnp.squeeze(u(inputs, params))
              return jnp.sum(jnp.diag(jax.hessian(u_)(inputs))[1:])
-         if eq_type == "statio_PDE":
+         if eq_type == "PDEStatio":
              u_ = lambda inputs: jnp.squeeze(u(inputs, params))
              return jnp.sum(jnp.diag(jax.hessian(u_)(inputs)))
          raise ValueError("Unexpected eq_type!")
@@ -225,7 +225,7 @@ def laplacian_rev(
          u_ = lambda inputs: u(inputs, params).squeeze()

          def scan_fun(_, i):
-             if eq_type == "nonstatio_PDE":
+             if eq_type == "PDENonStatio":
                  d2u_dxi2 = grad(
                      lambda inputs: grad(u_)(inputs)[1 + i],
                  )(inputs)[1 + i]
@@ -236,11 +236,11 @@
                  )(inputs)[i]
              return _, d2u_dxi2

-         if eq_type == "nonstatio_PDE":
+         if eq_type == "PDENonStatio":
              _, trace_hessian = jax.lax.scan(
                  scan_fun, {}, jnp.arange(inputs.shape[0] - 1)
              )
-         elif eq_type == "statio_PDE":
+         elif eq_type == "PDEStatio":
              _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[0]))
          else:
              raise ValueError("Unexpected eq_type!")
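The `trace_hessian_*` branches above compute the Laplacian as the trace of the Hessian, dropping the time row in the non-stationary case. A tiny standalone example of that identity (toy function, not jinns code):

```python
import jax
import jax.numpy as jnp

u = lambda x: jnp.sum(x**3)                 # u(x0, x1) = x0^3 + x1^3
x = jnp.array([1.0, 2.0])

lap = jnp.sum(jnp.diag(jax.hessian(u)(x)))  # trace of the Hessian = 6*x0 + 6*x1
# lap == 18.0
# for a non-stationary input [t, x0, x1], the time entry is discarded,
# e.g. jnp.sum(jnp.diag(jax.hessian(u_)(inputs))[1:]) as in "trace_hessian_t_x"
```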
@@ -253,7 +253,7 @@ def laplacian_fwd(
      u: AbstractPINN | Callable[[Array, Params[Array]], Array],
      params: Params[Array],
      method: Literal["trace_hessian_t_x", "trace_hessian_x", "loop"] = "loop",
-     eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+     eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
  ) -> Float[Array, " batch_size * (1+dim) 1"] | Float[Array, " batch_size * (dim) 1"]:
      r"""
      Compute the Laplacian of a **batched** scalar field $u$
@@ -302,7 +302,7 @@ def laplacian_fwd(
      if method == "loop":

          def scan_fun(_, i):
-             if eq_type == "nonstatio_PDE":
+             if eq_type == "PDENonStatio":
                  tangent_vec = jnp.repeat(
                      jax.nn.one_hot(i + 1, inputs.shape[-1])[None],
                      inputs.shape[0],
@@ -323,17 +323,17 @@ def laplacian_fwd(
              __, d2u_dxi2 = jax.jvp(du_dxi_fun, (inputs,), (tangent_vec,))
              return _, d2u_dxi2

-         if eq_type == "nonstatio_PDE":
+         if eq_type == "PDENonStatio":
              _, trace_hessian = jax.lax.scan(
                  scan_fun, {}, jnp.arange(inputs.shape[-1] - 1)
              )
-         elif eq_type == "statio_PDE":
+         elif eq_type == "PDEStatio":
              _, trace_hessian = jax.lax.scan(scan_fun, {}, jnp.arange(inputs.shape[-1]))
          else:
              raise ValueError("Unexpected eq_type!")
          return jnp.sum(trace_hessian, axis=0)
      if method == "trace_hessian_t_x":
-         if eq_type == "nonstatio_PDE":
+         if eq_type == "PDENonStatio":
              # compute the Hessian including the batch dimension, get rid of the
              # (..,1,..) axis that is here because of the scalar output
              # if inputs.shape==(10,3) (1 for time, 2 for x_dim)
@@ -351,7 +351,7 @@ def laplacian_fwd(
              res_dims = "".join([f"{chr(97 + d)}" for d in range(inputs.shape[-1])])
              lap = jnp.einsum(res_dims + "ii->" + res_dims, r)
              return lap[..., None]
-         if eq_type == "statio_PDE":
+         if eq_type == "PDEStatio":
              # compute the Hessian including the batch dimension, get rid of the
              # (..,1,..) axis that is here because of the scalar output
              # if inputs.shape==(10,2), r.shape=(10,10,1,10,2,10,2)
@@ -369,7 +369,7 @@ def laplacian_fwd(
              return lap[..., None]
          raise ValueError("Unexpected eq_type!")
      if method == "trace_hessian_x":
-         if eq_type == "statio_PDE":
+         if eq_type == "PDEStatio":
              # compute the Hessian including the batch dimension, get rid of the
              # (..,1,..) axis that is here because of the scalar output
              # if inputs.shape==(10,2), r.shape=(10,10,1,10,2,10,2)
@@ -394,7 +394,7 @@ def vectorial_laplacian_rev(
      u: AbstractPINN | Callable[[Array, Params[Array]], Array],
      params: Params[Array],
      dim_out: int | None = None,
-     eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+     eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
  ) -> Float[Array, " dim_out"]:
      r"""
      Compute the vectorial Laplacian of a vector field $\mathbf{u}$ from
@@ -448,7 +448,7 @@ def vectorial_laplacian_fwd(
      u: AbstractPINN | Callable[[Array, Params[Array]], Array],
      params: Params[Array],
      dim_out: int | None = None,
-     eq_type: Literal["nonstatio_PDE", "statio_PDE"] | None = None,
+     eq_type: Literal["PDENonStatio", "PDEStatio"] | None = None,
  ) -> Float[Array, " batch_size * (1+dim) n"] | Float[Array, " batch_size * (dim) n"]:
      r"""
      Compute the vectorial Laplacian of a vector field $\mathbf{u}$ when
@@ -13,7 +13,7 @@ class AbstractPINN(eqx.Module):
      https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
      """

-     eq_type: eqx.AbstractVar[Literal["ODE", "statio_PDE", "nonstatio_PDE"]]
+     eq_type: eqx.AbstractVar[Literal["ODE", "PDEStatio", "PDENonStatio"]]

      @abc.abstractmethod
      def __call__(self, inputs: Any, params: Params[Array], *args, **kwargs) -> Any:
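Since the `eq_type` literals are renamed across the package, downstream code that constructs networks or branches on `u.eq_type` needs the new spellings. A hypothetical before/after sketch (identifiers other than the literals are illustrative):

```python
# jinns 1.6.x
# u, params = jinns.nn.PINN_MLP.create(key=key, eq_type="statio_PDE", eqx_list=eqx_list)
# if u.eq_type == "nonstatio_PDE":
#     ...

# jinns 1.7.x
# u, params = jinns.nn.PINN_MLP.create(key=key, eq_type="PDEStatio", eqx_list=eqx_list)
# if u.eq_type == "PDENonStatio":
#     ...
# "ODE" is unchanged.
```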
jinns/nn/_hyperpinn.py CHANGED
@@ -67,9 +67,9 @@ class HyperPINN(PINN):
      eq_type : str
          A string with three possibilities.
          "ODE": the HyperPINN is called with one input `t`.
-         "statio_PDE": the HyperPINN is called with one input `x`, `x`
+         "PDEStatio": the HyperPINN is called with one input `x`, `x`
          can be high dimensional.
-         "nonstatio_PDE": the HyperPINN is called with two inputs `t` and `x`, `x`
+         "PDENonStatio": the HyperPINN is called with two inputs `t` and `x`, `x`
          can be high dimensional.
          **Note**: the input dimension as given in eqx_list has to match the sum
          of the dimension of `t` + the dimension of `x` or the output dimension
@@ -192,7 +192,7 @@ class HyperPINN(PINN):
          hyper = eqx.combine(params.nn_params, self.static_hyper)

          eq_params_batch = jnp.concatenate(
-             [getattr(params.eq_params, k).flatten() for k in self.hyperparams],
+             [getattr(params.eq_params, k).flatten() for k in self.hyperparams], # pylint: disable=E1133
              axis=0,
          )

@@ -214,7 +214,7 @@ class HyperPINN(PINN):
      def create(
          cls,
          *,
-         eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+         eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"],
          hyperparams: list[str],
          hypernet_input_size: int,
          key: PRNGKeyArray | None = None,
@@ -257,9 +257,9 @@ class HyperPINN(PINN):
          eq_type
              A string with three possibilities.
              "ODE": the HyperPINN is called with one input `t`.
-             "statio_PDE": the HyperPINN is called with one input `x`, `x`
+             "PDEStatio": the HyperPINN is called with one input `x`, `x`
              can be high dimensional.
-             "nonstatio_PDE": the HyperPINN is called with two inputs `t` and `x`, `x`
+             "PDENonStatio": the HyperPINN is called with two inputs `t` and `x`, `x`
              can be high dimensional.
              **Note**: the input dimension as given in eqx_list has to match the sum
              of the dimension of `t` + the dimension of `x` or the output dimension
jinns/nn/_mlp.py CHANGED
@@ -95,7 +95,7 @@ class PINN_MLP(PINN):
      def create(
          cls,
          *,
-         eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+         eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"],
          key: PRNGKeyArray | None = None,
          eqx_network: eqx.nn.MLP | MLP | None = None,
          eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
@@ -130,9 +130,9 @@ class PINN_MLP(PINN):
          eq_type
              A string with three possibilities.
              "ODE": the MLP is called with one input `t`.
-             "statio_PDE": the MLP is called with one input `x`, `x`
+             "PDEStatio": the MLP is called with one input `x`, `x`
              can be high dimensional.
-             "nonstatio_PDE": the MLP is called with two inputs `t` and `x`, `x`
+             "PDENonStatio": the MLP is called with two inputs `t` and `x`, `x`
              can be high dimensional.
              **Note**: the input dimension as given in eqx_list has to match the sum
              of the dimension of `t` + the dimension of `x` or the output dimension
jinns/nn/_pinn.py CHANGED
@@ -50,12 +50,12 @@ class PINN(AbstractPINN):
          when the PINN is also used to output equation parameters for example
          Note that it must be a slice and not an integer (a preprocessing of the
          user provided argument takes care of it).
-     eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
+     eq_type : Literal["ODE", "PDEStatio", "PDENonStatio"]
          A string with three possibilities.
          "ODE": the PINN is called with one input `t`.
-         "statio_PDE": the PINN is called with one input `x`, `x`
+         "PDEStatio": the PINN is called with one input `x`, `x`
          can be high dimensional.
-         "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
+         "PDENonStatio": the PINN is called with two inputs `t` and `x`, `x`
          can be high dimensional.
          **Note**: the input dimension as given in eqx_list has to match the sum
          of the dimension of `t` + the dimension of `x` or the output dimension
@@ -83,11 +83,11 @@ class PINN(AbstractPINN):
      Raises
      ------
      RuntimeError
-         If the parameter value for eq_type is not in `["ODE", "statio_PDE",
-         "nonstatio_PDE"]`
+         If the parameter value for eq_type is not in `["ODE", "PDEStatio",
+         "PDENonStatio"]`
      """

-     eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
+     eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"] = eqx.field(
          static=True, kw_only=True
      )
      slice_solution: slice = eqx.field(static=True, kw_only=True, default=None)
@@ -108,7 +108,7 @@ class PINN(AbstractPINN):
      static: PINN = eqx.field(init=False, static=True)

      def __post_init__(self, eqx_network):
-         if self.eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
+         if self.eq_type not in ["ODE", "PDEStatio", "PDENonStatio"]:
              raise RuntimeError("Wrong parameter value for eq_type")
          # saving the static part of the model and initial parameters

jinns/nn/_ppinn.py CHANGED
@@ -31,12 +31,12 @@ class PPINN_MLP(PINN):
          when the PINN is also used to output equation parameters for example
          Note that it must be a slice and not an integer (a preprocessing of the
          user provided argument takes care of it).
-     eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
+     eq_type : Literal["ODE", "PDEStatio", "PDENonStatio"]
          A string with three possibilities.
          "ODE": the PPINN is called with one input `t`.
-         "statio_PDE": the PPINN is called with one input `x`, `x`
+         "PDEStatio": the PPINN is called with one input `x`, `x`
          can be high dimensional.
-         "nonstatio_PDE": the PPINN is called with two inputs `t` and `x`, `x`
+         "PDENonStatio": the PPINN is called with two inputs `t` and `x`, `x`
          can be high dimensional.
          **Note**: the input dimension as given in eqx_list has to match the sum
          of the dimension of `t` + the dimension of `x` or the output dimension
@@ -125,7 +125,7 @@ class PPINN_MLP(PINN):
          cls,
          *,
          key: PRNGKeyArray | None = None,
-         eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+         eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"],
          eqx_network_list: list[eqx.nn.MLP | MLP] | None = None,
          eqx_list_list: (
              list[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] | None
@@ -158,9 +158,9 @@ class PPINN_MLP(PINN):
          eq_type
              A string with three possibilities.
              "ODE": the PPINN MLP is called with one input `t`.
-             "statio_PDE": the PPINN MLP is called with one input `x`, `x`
+             "PDEStatio": the PPINN MLP is called with one input `x`, `x`
              can be high dimensional.
-             "nonstatio_PDE": the PPINN MLP is called with two inputs `t` and `x`, `x`
+             "PDENonStatio": the PPINN MLP is called with two inputs `t` and `x`, `x`
              can be high dimensional.
              **Note**: the input dimension as given in eqx_list has to match the sum
              of the dimension of `t` + the dimension of `x` or the output dimension