jinns 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,53 +21,22 @@ else:
     from equinox import AbstractClassVar


-def _decorator_heteregeneous_params(evaluate, eq_type):
+def _decorator_heteregeneous_params(evaluate):

-    def wrapper_ode(*args):
-        self, t, u, params = args
+    def wrapper(*args):
+        self, inputs, u, params = args
         _params = eqx.tree_at(
             lambda p: p.eq_params,
             params,
             self._eval_heterogeneous_parameters(
-                t, None, u, params, self.eq_params_heterogeneity
+                inputs, u, params, self.eq_params_heterogeneity
             ),
         )
         new_args = args[:-1] + (_params,)
         res = evaluate(*new_args)
         return res

-    def wrapper_pde_statio(*args):
-        self, x, u, params = args
-        _params = eqx.tree_at(
-            lambda p: p.eq_params,
-            params,
-            self._eval_heterogeneous_parameters(
-                None, x, u, params, self.eq_params_heterogeneity
-            ),
-        )
-        new_args = args[:-1] + (_params,)
-        res = evaluate(*new_args)
-        return res
-
-    def wrapper_pde_non_statio(*args):
-        self, t, x, u, params = args
-        _params = eqx.tree_at(
-            lambda p: p.eq_params,
-            params,
-            self._eval_heterogeneous_parameters(
-                t, x, u, params, self.eq_params_heterogeneity
-            ),
-        )
-        new_args = args[:-1] + (_params,)
-        res = evaluate(*new_args)
-        return res
-
-    if eq_type == "ODE":
-        return wrapper_ode
-    elif eq_type == "Statio PDE":
-        return wrapper_pde_statio
-    elif eq_type == "Non-statio PDE":
-        return wrapper_pde_non_statio
+    return wrapper


 class DynamicLoss(eqx.Module):
@@ -110,8 +79,7 @@ class DynamicLoss(eqx.Module):

     def _eval_heterogeneous_parameters(
         self,
-        t: Float[Array, "1"],
-        x: Float[Array, "dim"],
+        inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
         u: eqx.Module,
         params: Params | ParamsDict,
         eq_params_heterogeneity: Dict[str, Callable | None] = None,
@@ -124,14 +92,7 @@ class DynamicLoss(eqx.Module):
                 if eq_params_heterogeneity[k] is None:
                     eq_params_[k] = p
                 else:
-                    # heterogeneity encoded through a function whose
-                    # signature will vary according to _eq_type
-                    if self._eq_type == "ODE":
-                        eq_params_[k] = eq_params_heterogeneity[k](t, u, params)
-                    elif self._eq_type == "Statio PDE":
-                        eq_params_[k] = eq_params_heterogeneity[k](x, u, params)
-                    elif self._eq_type == "Non-statio PDE":
-                        eq_params_[k] = eq_params_heterogeneity[k](t, x, u, params)
+                    eq_params_[k] = eq_params_heterogeneity[k](inputs, u, params)
             except KeyError:
                 # we authorize missing eq_params_heterogeneity key
                 # if its heterogeneity is None anyway
@@ -140,22 +101,17 @@ class DynamicLoss(eqx.Module):

     def _evaluate(
         self,
-        t: Float[Array, "1"],
-        x: Float[Array, "dim"],
+        inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
         u: eqx.Module,
         params: Params | ParamsDict,
     ) -> float:
-        # Here we handle the various possible signature
-        if self._eq_type == "ODE":
-            ans = self.equation(t, u, params)
-        elif self._eq_type == "Statio PDE":
-            ans = self.equation(x, u, params)
-        elif self._eq_type == "Non-statio PDE":
-            ans = self.equation(t, x, u, params)
-        else:
-            raise NotImplementedError("the equation type is not handled.")
-
-        return ans
+        evaluation = self.equation(inputs, u, params)
+        if len(evaluation.shape) == 0:
+            raise ValueError(
+                "The output of dynamic loss must be vectorial, "
+                "i.e. of shape (d,) with d >= 1"
+            )
+        return evaluation

     @abc.abstractmethod
     def equation(self, *args, **kwargs):
@@ -191,7 +147,7 @@ class ODE(DynamicLoss):

     _eq_type: ClassVar[str] = "ODE"

-    @partial(_decorator_heteregeneous_params, eq_type="ODE")
+    @partial(_decorator_heteregeneous_params)
     def evaluate(
         self,
         t: Float[Array, "1"],
@@ -199,7 +155,7 @@ class ODE(DynamicLoss):
         params: Params | ParamsDict,
     ) -> float:
         """Here we call DynamicLoss._evaluate with x=None"""
-        return self._evaluate(t, None, u, params)
+        return self._evaluate(t, u, params)

     @abc.abstractmethod
     def equation(
@@ -260,12 +216,12 @@ class PDEStatio(DynamicLoss):

     _eq_type: ClassVar[str] = "Statio PDE"

-    @partial(_decorator_heteregeneous_params, eq_type="Statio PDE")
+    @partial(_decorator_heteregeneous_params)
     def evaluate(
         self, x: Float[Array, "dimension"], u: eqx.Module, params: Params | ParamsDict
     ) -> float:
         """Here we call the DynamicLoss._evaluate with t=None"""
-        return self._evaluate(None, x, u, params)
+        return self._evaluate(x, u, params)

     @abc.abstractmethod
     def equation(
@@ -325,23 +281,21 @@ class PDENonStatio(DynamicLoss):

     _eq_type: ClassVar[str] = "Non-statio PDE"

-    @partial(_decorator_heteregeneous_params, eq_type="Non-statio PDE")
+    @partial(_decorator_heteregeneous_params)
     def evaluate(
         self,
-        t: Float[Array, "1"],
-        x: Float[Array, "dim"],
+        t_x: Float[Array, "1 + dim"],
         u: eqx.Module,
         params: Params | ParamsDict,
     ) -> float:
         """Here we call the DynamicLoss._evaluate with full arguments"""
-        ans = self._evaluate(t, x, u, params)
+        ans = self._evaluate(t_x, u, params)
         return ans

     @abc.abstractmethod
     def equation(
         self,
-        t: Float[Array, "1"],
-        x: Float[Array, "dim"],
+        t_x: Float[Array, "1 + dim"],
         u: eqx.Module,
         params: Params | ParamsDict,
     ) -> float:
@@ -353,10 +307,8 @@ class PDENonStatio(DynamicLoss):

         Parameters
         ----------
-        t : Float[Array, "1"]
-            A 1-dimensional jnp.array representing the time point.
-        x : Float[Array, "d"]
-            A `d` dimensional jnp.array representing a point in the spatial domain $\Omega$.
+        t_x : Float[Array, "1 + dim"]
+            A jnp array containing the concatenation of a time point and a point in $\Omega$
         u : eqx.Module
             The neural network.
         params : Params | ParamsDict
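The net effect of this file is that every dynamic loss now receives a single input array: `ODE.evaluate` takes `t`, `PDEStatio.evaluate` takes `x`, and `PDENonStatio.evaluate` takes the concatenation `t_x`, while `_evaluate` rejects scalar residuals. Below is a minimal sketch of a user-defined non-stationary dynamic loss written against the new signature; the class name, the `"c"` parameter key, and the `u(t_x, params)` call convention are illustrative assumptions, not taken from this diff.

```python
import jax
from jinns.loss import PDENonStatio

# Hypothetical 1D transport equation u_t + c * u_x = 0 written against the
# new single-argument convention: t_x = (t, x), shape (1 + dim,).
class TransportLoss(PDENonStatio):
    def equation(self, t_x, u, params):
        c = params.eq_params["c"]  # hypothetical equation-parameter key

        # Assumed call convention u(t_x, params); check the PINN API of your
        # jinns version. jacrev over the concatenated input gives a
        # Jacobian of shape (output_dim, 1 + dim).
        jac = jax.jacrev(lambda inp: u(inp, params))(t_x)
        du_dt = jac[:, 0]
        du_dx = jac[:, 1]

        # The residual must be vectorial: DynamicLoss._evaluate now raises a
        # ValueError when the equation output has shape ().
        return du_dt + c * du_dx
```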
jinns/loss/_LossODE.py CHANGED
@@ -209,7 +209,7 @@ class LossODE(_LossODEAbstract):
         mse_dyn_loss = dynamic_loss_apply(
             self.dynamic_loss.evaluate,
             self.u,
-            (temporal_batch,),
+            temporal_batch,
             _set_derivatives(params, self.derivative_keys.dyn_loss),
             self.vmap_in_axes + vmap_in_axes_params,
             self.loss_weights.dyn_loss,
@@ -227,7 +227,7 @@ class LossODE(_LossODEAbstract):
         else:
             v_u = vmap(self.u, (None,) + vmap_in_axes_params)
             t0, u0 = self.initial_condition  # pylint: disable=unpacking-non-sequence
-            t0 = jnp.array(t0)
+            t0 = jnp.array([t0])
             u0 = jnp.array(u0)
             mse_initial_condition = jnp.mean(
                 self.loss_weights.initial_condition
@@ -522,7 +522,7 @@ class SystemLossODE(eqx.Module):
            return dynamic_loss_apply(
                 dyn_loss.evaluate,
                 self.u_dict,
-                (temporal_batch,),
+                temporal_batch,
                 _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
                 vmap_in_axes_t + vmap_in_axes_params,
                 loss_weight,
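In `_LossODE.py` the temporal batch is now passed directly to `dynamic_loss_apply` (no single-element tuple), and the initial time is promoted to a 1-D array. A minimal sketch of the shape difference behind the `t0` change, with an illustrative value:

```python
import jax.numpy as jnp

t0 = 0.0
jnp.array(t0).shape    # () : 0-dimensional scalar
jnp.array([t0]).shape  # (1,) : matches the Float[Array, "1"] time convention
```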
jinns/loss/_LossPDE.py CHANGED
@@ -22,9 +22,7 @@ from jinns.loss._loss_utils import (
     initial_condition_apply,
     constraints_system_loss_apply,
 )
-from jinns.data._DataGenerators import (
-    append_obs_batch,
-)
+from jinns.data._DataGenerators import append_obs_batch
 from jinns.parameters._params import (
     _get_vmap_in_axes_params,
     _update_eq_params_dict,
@@ -370,8 +368,8 @@ class LossPDEStatio(_LossPDEAbstract):

     def _get_dynamic_loss_batch(
         self, batch: PDEStatioBatch
-    ) -> tuple[Float[Array, "batch_size dimension"]]:
-        return (batch.inside_batch,)
+    ) -> Float[Array, "batch_size dimension"]:
+        return batch.domain_batch

     def _get_normalization_loss_batch(
         self, _
@@ -432,7 +430,7 @@ class LossPDEStatio(_LossPDEAbstract):
                 self.u,
                 self._get_normalization_loss_batch(batch),
                 _set_derivatives(params, self.derivative_keys.norm_loss),
-                self.vmap_in_axes + vmap_in_axes_params,
+                vmap_in_axes_params,
                 self.norm_int_length,
                 self.loss_weights.norm_loss,
             )
@@ -575,6 +573,9 @@ class LossPDENonStatio(LossPDEStatio):
         kw_only=True, default=None, static=True
     )

+    _max_norm_samples_omega: Int = eqx.field(init=False, static=True)
+    _max_norm_time_slices: Int = eqx.field(init=False, static=True)
+
     def __post_init__(self, params=None):
         """
         Note that neither __init__ or __post_init__ are called when udating a
@@ -585,7 +586,7 @@ class LossPDENonStatio(LossPDEStatio):
         )  # because __init__ or __post_init__ of Base
         # class is not automatically called

-        self.vmap_in_axes = (0, 0)  # for t and x
+        self.vmap_in_axes = (0,)  # for t_x

         if self.initial_condition_fun is None:
             warnings.warn(
@@ -593,28 +594,28 @@ class LossPDENonStatio(LossPDEStatio):
                 "case (e.g by. hardcoding it into the PINN output)."
             )

+        # witht the variables below we avoid memory overflow since a cartesian
+        # product is taken
+        self._max_norm_time_slices = 100
+        self._max_norm_samples_omega = 1000
+
     def _get_dynamic_loss_batch(
         self, batch: PDENonStatioBatch
-    ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
-        times_batch = batch.times_x_inside_batch[:, 0:1]
-        omega_batch = batch.times_x_inside_batch[:, 1:]
-        return (times_batch, omega_batch)
+    ) -> Float[Array, "batch_size 1+dimension"]:
+        return batch.domain_batch

     def _get_normalization_loss_batch(
         self, batch: PDENonStatioBatch
-    ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "nb_norm_samples dimension"]]:
+    ) -> Float[Array, "nb_norm_time_slices nb_norm_samples dimension"]:
         return (
-            batch.times_x_inside_batch[:, 0:1],
-            self.norm_samples,
+            batch.domain_batch[: self._max_norm_time_slices, 0:1],
+            self.norm_samples[: self._max_norm_samples_omega],
         )

     def _get_observations_loss_batch(
         self, batch: PDENonStatioBatch
     ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
-        return (
-            batch.obs_batch_dict["pinn_in"][:, 0:1],
-            batch.obs_batch_dict["pinn_in"][:, 1:],
-        )
+        return (batch.obs_batch_dict["pinn_in"],)

     def __call__(self, *args, **kwargs):
         return self.evaluate(*args, **kwargs)
@@ -637,7 +638,7 @@ class LossPDENonStatio(LossPDEStatio):
         of parameters (eg. for metamodeling) and an optional additional batch of observed
         inputs/outputs/parameters
         """
-        omega_batch = batch.times_x_inside_batch[:, 1:]
+        omega_batch = batch.initial_batch

         # Retrieve the optional eq_params_batch
         # and update eq_params with the latter
@@ -660,7 +661,6 @@ class LossPDENonStatio(LossPDEStatio):
                 _set_derivatives(params, self.derivative_keys.initial_condition),
                 (0,) + vmap_in_axes_params,
                 self.initial_condition_fun,
-                omega_batch.shape[0],
                 self.loss_weights.initial_condition,
             )
         else:
@@ -1016,19 +1016,7 @@ class SystemLossPDE(eqx.Module):
         if self.u_dict.keys() != params_dict.nn_params.keys():
             raise ValueError("u_dict and params_dict[nn_params] should have same keys ")

-        if isinstance(batch, PDEStatioBatch):
-            omega_batch, _ = batch.inside_batch, batch.border_batch
-            vmap_in_axes_x_or_x_t = (0,)
-
-            batches = (omega_batch,)
-        elif isinstance(batch, PDENonStatioBatch):
-            times_batch = batch.times_x_inside_batch[:, 0:1]
-            omega_batch = batch.times_x_inside_batch[:, 1:]
-
-            batches = (omega_batch, times_batch)
-            vmap_in_axes_x_or_x_t = (0, 0)
-        else:
-            raise ValueError("Wrong type of batch")
+        vmap_in_axes = (0,)

         # Retrieve the optional eq_params_batch
         # and update eq_params with the latter
@@ -1036,7 +1024,6 @@ class SystemLossPDE(eqx.Module):
         if batch.param_batch_dict is not None:
             eq_params_batch_dict = batch.param_batch_dict

-            # TODO
             # feed the eq_params with the batch
             for k in eq_params_batch_dict.keys():
                 params_dict.eq_params[k] = eq_params_batch_dict[k]
@@ -1050,9 +1037,13 @@ class SystemLossPDE(eqx.Module):
             return dynamic_loss_apply(
                 dyn_loss.evaluate,
                 self.u_dict,
-                batches,
+                (
+                    batch.domain_batch
+                    if isinstance(batch, PDEStatioBatch)
+                    else batch.domain_batch
+                ),
                 _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
-                vmap_in_axes_x_or_x_t + vmap_in_axes_params,
+                vmap_in_axes + vmap_in_axes_params,
                 loss_weight,
                 u_type=type(list(self.u_dict.values())[0]),
             )
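Throughout `_LossPDE.py`, the data-generator batches are now exposed as a single concatenated array under `domain_batch` (replacing `inside_batch` and `times_x_inside_batch`), with column 0 holding time in the non-stationary case. Where time and space still need to be separated, as in the normalization loss above, the columns are sliced; a small sketch with illustrative array sizes:

```python
import jax.numpy as jnp

# Hypothetical non-stationary domain batch: column 0 is t, the rest is x.
domain_batch = jnp.ones((128, 1 + 2))  # batch_size=128, spatial dim=2

t = domain_batch[:, 0:1]  # shape (128, 1)
x = domain_batch[:, 1:]   # shape (128, 2)
```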
jinns/loss/__init__.py CHANGED
@@ -3,7 +3,7 @@ from ._LossODE import LossODE, SystemLossODE
 from ._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
 from ._DynamicLoss import (
     GeneralizedLotkaVolterra,
-    BurgerEquation,
+    BurgersEquation,
     FPENonStatioLoss2D,
     OU_FPENonStatioLoss2D,
     FisherKPP,
@@ -19,9 +19,10 @@ from ._loss_weights import (
 )

 from ._operators import (
-    _div_fwd,
-    _div_rev,
-    _laplacian_fwd,
-    _laplacian_rev,
-    _vectorial_laplacian,
+    divergence_fwd,
+    divergence_rev,
+    laplacian_fwd,
+    laplacian_rev,
+    vectorial_laplacian_fwd,
+    vectorial_laplacian_rev,
 )
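The export renames in `jinns/loss/__init__.py` mean downstream imports need updating: `BurgerEquation` becomes `BurgersEquation`, and the differential operators lose their leading underscore, with the vectorial Laplacian split into forward and reverse variants. A sketch of the corresponding import change:

```python
# jinns 1.1.0
from jinns.loss import BurgerEquation, _laplacian_rev

# jinns 1.2.0
from jinns.loss import BurgersEquation, laplacian_rev
```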