jinns 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,61 +21,32 @@ else:
21
21
  from equinox import AbstractClassVar
22
22
 
23
23
 
24
- def _decorator_heteregeneous_params(evaluate, eq_type):
24
+ def _decorator_heteregeneous_params(evaluate):
25
25
 
26
- def wrapper_ode(*args):
27
- self, t, u, params = args
26
+ def wrapper(*args):
27
+ self, inputs, u, params = args
28
28
  _params = eqx.tree_at(
29
29
  lambda p: p.eq_params,
30
30
  params,
31
31
  self._eval_heterogeneous_parameters(
32
- t, None, u, params, self.eq_params_heterogeneity
32
+ inputs, u, params, self.eq_params_heterogeneity
33
33
  ),
34
34
  )
35
35
  new_args = args[:-1] + (_params,)
36
36
  res = evaluate(*new_args)
37
37
  return res
38
38
 
39
- def wrapper_pde_statio(*args):
40
- self, x, u, params = args
41
- _params = eqx.tree_at(
42
- lambda p: p.eq_params,
43
- params,
44
- self._eval_heterogeneous_parameters(
45
- None, x, u, params, self.eq_params_heterogeneity
46
- ),
47
- )
48
- new_args = args[:-1] + (_params,)
49
- res = evaluate(*new_args)
50
- return res
51
-
52
- def wrapper_pde_non_statio(*args):
53
- self, t, x, u, params = args
54
- _params = eqx.tree_at(
55
- lambda p: p.eq_params,
56
- params,
57
- self._eval_heterogeneous_parameters(
58
- t, x, u, params, self.eq_params_heterogeneity
59
- ),
60
- )
61
- new_args = args[:-1] + (_params,)
62
- res = evaluate(*new_args)
63
- return res
64
-
65
- if eq_type == "ODE":
66
- return wrapper_ode
67
- elif eq_type == "Statio PDE":
68
- return wrapper_pde_statio
69
- elif eq_type == "Non-statio PDE":
70
- return wrapper_pde_non_statio
39
+ return wrapper
71
40
 
72
41
 
73
42
  class DynamicLoss(eqx.Module):
74
43
  r"""
75
44
  Abstract base class for dynamic losses. Implements the physical term:
45
+
76
46
  $$
77
47
  \mathcal{N}[u](t, x) = 0
78
48
  $$
49
+
79
50
  for **one** point $t$, $x$ or $(t, x)$, depending on the context.
80
51
 
81
52
  Parameters
@@ -108,8 +79,7 @@ class DynamicLoss(eqx.Module):
108
79
 
109
80
  def _eval_heterogeneous_parameters(
110
81
  self,
111
- t: Float[Array, "1"],
112
- x: Float[Array, "dim"],
82
+ inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
113
83
  u: eqx.Module,
114
84
  params: Params | ParamsDict,
115
85
  eq_params_heterogeneity: Dict[str, Callable | None] = None,
@@ -122,14 +92,7 @@ class DynamicLoss(eqx.Module):
122
92
  if eq_params_heterogeneity[k] is None:
123
93
  eq_params_[k] = p
124
94
  else:
125
- # heterogeneity encoded through a function whose
126
- # signature will vary according to _eq_type
127
- if self._eq_type == "ODE":
128
- eq_params_[k] = eq_params_heterogeneity[k](t, u, params)
129
- elif self._eq_type == "Statio PDE":
130
- eq_params_[k] = eq_params_heterogeneity[k](x, u, params)
131
- elif self._eq_type == "Non-statio PDE":
132
- eq_params_[k] = eq_params_heterogeneity[k](t, x, u, params)
95
+ eq_params_[k] = eq_params_heterogeneity[k](inputs, u, params)
133
96
  except KeyError:
134
97
  # we authorize missing eq_params_heterogeneity key
135
98
  # if its heterogeneity is None anyway
@@ -138,22 +101,17 @@ class DynamicLoss(eqx.Module):
138
101
 
139
102
  def _evaluate(
140
103
  self,
141
- t: Float[Array, "1"],
142
- x: Float[Array, "dim"],
104
+ inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
143
105
  u: eqx.Module,
144
106
  params: Params | ParamsDict,
145
107
  ) -> float:
146
- # Here we handle the various possible signature
147
- if self._eq_type == "ODE":
148
- ans = self.equation(t, u, params)
149
- elif self._eq_type == "Statio PDE":
150
- ans = self.equation(x, u, params)
151
- elif self._eq_type == "Non-statio PDE":
152
- ans = self.equation(t, x, u, params)
153
- else:
154
- raise NotImplementedError("the equation type is not handled.")
155
-
156
- return ans
108
+ evaluation = self.equation(inputs, u, params)
109
+ if len(evaluation.shape) == 0:
110
+ raise ValueError(
111
+ "The output of dynamic loss must be vectorial, "
112
+ "i.e. of shape (d,) with d >= 1"
113
+ )
114
+ return evaluation
157
115
 
158
116
  @abc.abstractmethod
159
117
  def equation(self, *args, **kwargs):
@@ -189,7 +147,7 @@ class ODE(DynamicLoss):
189
147
 
190
148
  _eq_type: ClassVar[str] = "ODE"
191
149
 
192
- @partial(_decorator_heteregeneous_params, eq_type="ODE")
150
+ @partial(_decorator_heteregeneous_params)
193
151
  def evaluate(
194
152
  self,
195
153
  t: Float[Array, "1"],
@@ -197,7 +155,7 @@ class ODE(DynamicLoss):
197
155
  params: Params | ParamsDict,
198
156
  ) -> float:
199
157
  """Here we call DynamicLoss._evaluate with x=None"""
200
- return self._evaluate(t, None, u, params)
158
+ return self._evaluate(t, u, params)
201
159
 
202
160
  @abc.abstractmethod
203
161
  def equation(
@@ -258,12 +216,12 @@ class PDEStatio(DynamicLoss):
258
216
 
259
217
  _eq_type: ClassVar[str] = "Statio PDE"
260
218
 
261
- @partial(_decorator_heteregeneous_params, eq_type="Statio PDE")
219
+ @partial(_decorator_heteregeneous_params)
262
220
  def evaluate(
263
221
  self, x: Float[Array, "dimension"], u: eqx.Module, params: Params | ParamsDict
264
222
  ) -> float:
265
223
  """Here we call the DynamicLoss._evaluate with t=None"""
266
- return self._evaluate(None, x, u, params)
224
+ return self._evaluate(x, u, params)
267
225
 
268
226
  @abc.abstractmethod
269
227
  def equation(
@@ -323,23 +281,21 @@ class PDENonStatio(DynamicLoss):
323
281
 
324
282
  _eq_type: ClassVar[str] = "Non-statio PDE"
325
283
 
326
- @partial(_decorator_heteregeneous_params, eq_type="Non-statio PDE")
284
+ @partial(_decorator_heteregeneous_params)
327
285
  def evaluate(
328
286
  self,
329
- t: Float[Array, "1"],
330
- x: Float[Array, "dim"],
287
+ t_x: Float[Array, "1 + dim"],
331
288
  u: eqx.Module,
332
289
  params: Params | ParamsDict,
333
290
  ) -> float:
334
291
  """Here we call the DynamicLoss._evaluate with full arguments"""
335
- ans = self._evaluate(t, x, u, params)
292
+ ans = self._evaluate(t_x, u, params)
336
293
  return ans
337
294
 
338
295
  @abc.abstractmethod
339
296
  def equation(
340
297
  self,
341
- t: Float[Array, "1"],
342
- x: Float[Array, "dim"],
298
+ t_x: Float[Array, "1 + dim"],
343
299
  u: eqx.Module,
344
300
  params: Params | ParamsDict,
345
301
  ) -> float:
@@ -351,10 +307,8 @@ class PDENonStatio(DynamicLoss):
351
307
 
352
308
  Parameters
353
309
  ----------
354
- t : Float[Array, "1"]
355
- A 1-dimensional jnp.array representing the time point.
356
- x : Float[Array, "d"]
357
- A `d` dimensional jnp.array representing a point in the spatial domain $\Omega$.
310
+ t_x : Float[Array, "1 + dim"]
311
+ A jnp array containing the concatenation of a time point and a point in $\Omega$
358
312
  u : eqx.Module
359
313
  The neural network.
360
314
  params : Params | ParamsDict
jinns/loss/_LossODE.py CHANGED
@@ -56,6 +56,9 @@ class _LossODEAbstract(eqx.Module):
56
56
  slice of u output(s) that is observed. This is useful for
57
57
  multidimensional PINN, with partially observed outputs.
58
58
  Default is None (whole output is observed).
59
+ params : InitVar[Params], default=None
60
+ The main Params object of the problem needed to instanciate the
61
+ DerivativeKeysODE if the latter is not specified.
59
62
  """
60
63
 
61
64
  # NOTE static=True only for leaf attributes that are not valid JAX types
@@ -66,13 +69,21 @@ class _LossODEAbstract(eqx.Module):
66
69
  initial_condition: tuple | None = eqx.field(kw_only=True, default=None)
67
70
  obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
68
71
 
69
- def __post_init__(self):
72
+ params: InitVar[Params] = eqx.field(default=None, kw_only=True)
73
+
74
+ def __post_init__(self, params=None):
70
75
  if self.loss_weights is None:
71
76
  self.loss_weights = LossWeightsODE()
72
77
 
73
78
  if self.derivative_keys is None:
74
- # be default we only take gradient wrt nn_params
75
- self.derivative_keys = DerivativeKeysODE()
79
+ try:
80
+ # be default we only take gradient wrt nn_params
81
+ self.derivative_keys = DerivativeKeysODE(params=params)
82
+ except ValueError as exc:
83
+ raise ValueError(
84
+ "Problem at self.derivative_keys initialization "
85
+ f"received {self.derivative_keys=} and {params=}"
86
+ ) from exc
76
87
  if self.initial_condition is None:
77
88
  warnings.warn(
78
89
  "Initial condition wasn't provided. Be sure to cover for that"
@@ -131,6 +142,9 @@ class LossODE(_LossODEAbstract):
131
142
  slice of u output(s) that is observed. This is useful for
132
143
  multidimensional PINN, with partially observed outputs.
133
144
  Default is None (whole output is observed).
145
+ params : InitVar[Params], default=None
146
+ The main Params object of the problem needed to instanciate the
147
+ DerivativeKeysODE if the latter is not specified.
134
148
  u : eqx.Module
135
149
  the PINN
136
150
  dynamic_loss : DynamicLoss
@@ -152,8 +166,10 @@ class LossODE(_LossODEAbstract):
152
166
 
153
167
  vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
154
168
 
155
- def __post_init__(self):
156
- super().__post_init__() # because __init__ or __post_init__ of Base
169
+ def __post_init__(self, params=None):
170
+ super().__post_init__(
171
+ params=params
172
+ ) # because __init__ or __post_init__ of Base
157
173
  # class is not automatically called
158
174
 
159
175
  self.vmap_in_axes = (0,)
@@ -193,7 +209,7 @@ class LossODE(_LossODEAbstract):
193
209
  mse_dyn_loss = dynamic_loss_apply(
194
210
  self.dynamic_loss.evaluate,
195
211
  self.u,
196
- (temporal_batch,),
212
+ temporal_batch,
197
213
  _set_derivatives(params, self.derivative_keys.dyn_loss),
198
214
  self.vmap_in_axes + vmap_in_axes_params,
199
215
  self.loss_weights.dyn_loss,
@@ -211,7 +227,7 @@ class LossODE(_LossODEAbstract):
211
227
  else:
212
228
  v_u = vmap(self.u, (None,) + vmap_in_axes_params)
213
229
  t0, u0 = self.initial_condition # pylint: disable=unpacking-non-sequence
214
- t0 = jnp.array(t0)
230
+ t0 = jnp.array([t0])
215
231
  u0 = jnp.array(u0)
216
232
  mse_initial_condition = jnp.mean(
217
233
  self.loss_weights.initial_condition
@@ -300,6 +316,9 @@ class SystemLossODE(eqx.Module):
300
316
  PINNs. Default is None. But if a value is given, all the entries of
301
317
  `u_dict` must be represented here with default value `jnp.s_[...]`
302
318
  if no particular slice is to be given.
319
+ params_dict : InitVar[ParamsDict], default=None
320
+ The main Params object of the problem needed to instanciate the
321
+ DerivativeKeysODE if the latter is not specified.
303
322
 
304
323
  Raises
305
324
  ------
@@ -332,14 +351,16 @@ class SystemLossODE(eqx.Module):
332
351
  loss_weights: InitVar[LossWeightsODEDict | None] = eqx.field(
333
352
  kw_only=True, default=None
334
353
  )
354
+ params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
355
+
335
356
  u_constraints_dict: Dict[str, LossODE] = eqx.field(init=False)
336
- derivative_keys_dyn_loss_dict: Dict[str, DerivativeKeysODE] = eqx.field(init=False)
357
+ derivative_keys_dyn_loss: DerivativeKeysODE = eqx.field(init=False)
337
358
 
338
359
  u_dict_with_none: Dict[str, None] = eqx.field(init=False)
339
360
  # internally the loss weights are handled with a dictionary
340
361
  _loss_weights: Dict[str, dict] = eqx.field(init=False)
341
362
 
342
- def __post_init__(self, loss_weights):
363
+ def __post_init__(self, loss_weights=None, params_dict=None):
343
364
  # a dictionary that will be useful at different places
344
365
  self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
345
366
  if self.initial_condition_dict is None:
@@ -369,14 +390,14 @@ class SystemLossODE(eqx.Module):
369
390
  # iterating on dynamic_loss_dict. So each time we will require dome
370
391
  # derivative_keys_dict
371
392
 
372
- # but then if the user did not provide anything, we must at least have
373
- # a default value for the dynamic_loss_dict keys entries in
374
- # self.derivative_keys_dict since the computation of dynamic losses is
375
- # made without create a lossODE object that would provide the
376
- # default values
377
- for k in self.dynamic_loss_dict.keys():
393
+ # derivative keys for the u_constraints. Note that we create missing
394
+ # DerivativeKeysODE around a Params object and not ParamsDict
395
+ # this works because u_dict.keys == params_dict.nn_params.keys()
396
+ for k in self.u_dict.keys():
378
397
  if self.derivative_keys_dict[k] is None:
379
- self.derivative_keys_dict[k] = DerivativeKeysODE()
398
+ self.derivative_keys_dict[k] = DerivativeKeysODE(
399
+ params=params_dict.extract_params(k)
400
+ )
380
401
 
381
402
  self._loss_weights = self.set_loss_weights(loss_weights)
382
403
 
@@ -397,12 +418,11 @@ class SystemLossODE(eqx.Module):
397
418
  obs_slice=self.obs_slice_dict[i],
398
419
  )
399
420
 
400
- # for convenience in the tree_map of evaluate
401
- self.derivative_keys_dyn_loss_dict = {
402
- k: self.derivative_keys_dict[k]
403
- for k in self.dynamic_loss_dict.keys() # & self.derivative_keys_dict.keys()
404
- # comment because intersection is neceserily fulfilled right?
405
- }
421
+ # derivative keys for the dynamic loss. Note that we create a
422
+ # DerivativeKeysODE around a ParamsDict object because a whole
423
+ # params_dict is feed to DynamicLoss.evaluate functions (extract_params
424
+ # happen inside it)
425
+ self.derivative_keys_dyn_loss = DerivativeKeysODE(params=params_dict)
406
426
 
407
427
  def set_loss_weights(self, loss_weights_init):
408
428
  """
@@ -497,13 +517,13 @@ class SystemLossODE(eqx.Module):
497
517
  batch.param_batch_dict, params_dict
498
518
  )
499
519
 
500
- def dyn_loss_for_one_key(dyn_loss, derivative_key, loss_weight):
520
+ def dyn_loss_for_one_key(dyn_loss, loss_weight):
501
521
  """This function is used in tree_map"""
502
522
  return dynamic_loss_apply(
503
523
  dyn_loss.evaluate,
504
524
  self.u_dict,
505
- (temporal_batch,),
506
- _set_derivatives(params_dict, derivative_key.dyn_loss),
525
+ temporal_batch,
526
+ _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
507
527
  vmap_in_axes_t + vmap_in_axes_params,
508
528
  loss_weight,
509
529
  u_type=PINN,
@@ -512,7 +532,6 @@ class SystemLossODE(eqx.Module):
512
532
  dyn_loss_mse_dict = jax.tree_util.tree_map(
513
533
  dyn_loss_for_one_key,
514
534
  self.dynamic_loss_dict,
515
- self.derivative_keys_dyn_loss_dict,
516
535
  self._loss_weights["dyn_loss"],
517
536
  is_leaf=lambda x: isinstance(x, ODE), # before when dynamic losses
518
537
  # where plain (unregister pytree) node classes, we could not traverse