jinns 1.1.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,53 +21,22 @@ else:
21
21
  from equinox import AbstractClassVar
22
22
 
23
23
 
24
- def _decorator_heteregeneous_params(evaluate, eq_type):
24
+ def _decorator_heteregeneous_params(evaluate):
25
25
 
26
- def wrapper_ode(*args):
27
- self, t, u, params = args
26
+ def wrapper(*args):
27
+ self, inputs, u, params = args
28
28
  _params = eqx.tree_at(
29
29
  lambda p: p.eq_params,
30
30
  params,
31
31
  self._eval_heterogeneous_parameters(
32
- t, None, u, params, self.eq_params_heterogeneity
32
+ inputs, u, params, self.eq_params_heterogeneity
33
33
  ),
34
34
  )
35
35
  new_args = args[:-1] + (_params,)
36
36
  res = evaluate(*new_args)
37
37
  return res
38
38
 
39
- def wrapper_pde_statio(*args):
40
- self, x, u, params = args
41
- _params = eqx.tree_at(
42
- lambda p: p.eq_params,
43
- params,
44
- self._eval_heterogeneous_parameters(
45
- None, x, u, params, self.eq_params_heterogeneity
46
- ),
47
- )
48
- new_args = args[:-1] + (_params,)
49
- res = evaluate(*new_args)
50
- return res
51
-
52
- def wrapper_pde_non_statio(*args):
53
- self, t, x, u, params = args
54
- _params = eqx.tree_at(
55
- lambda p: p.eq_params,
56
- params,
57
- self._eval_heterogeneous_parameters(
58
- t, x, u, params, self.eq_params_heterogeneity
59
- ),
60
- )
61
- new_args = args[:-1] + (_params,)
62
- res = evaluate(*new_args)
63
- return res
64
-
65
- if eq_type == "ODE":
66
- return wrapper_ode
67
- elif eq_type == "Statio PDE":
68
- return wrapper_pde_statio
69
- elif eq_type == "Non-statio PDE":
70
- return wrapper_pde_non_statio
39
+ return wrapper
71
40
 
72
41
 
73
42
  class DynamicLoss(eqx.Module):
@@ -110,8 +79,7 @@ class DynamicLoss(eqx.Module):
110
79
 
111
80
  def _eval_heterogeneous_parameters(
112
81
  self,
113
- t: Float[Array, "1"],
114
- x: Float[Array, "dim"],
82
+ inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
115
83
  u: eqx.Module,
116
84
  params: Params | ParamsDict,
117
85
  eq_params_heterogeneity: Dict[str, Callable | None] = None,
@@ -124,14 +92,7 @@ class DynamicLoss(eqx.Module):
124
92
  if eq_params_heterogeneity[k] is None:
125
93
  eq_params_[k] = p
126
94
  else:
127
- # heterogeneity encoded through a function whose
128
- # signature will vary according to _eq_type
129
- if self._eq_type == "ODE":
130
- eq_params_[k] = eq_params_heterogeneity[k](t, u, params)
131
- elif self._eq_type == "Statio PDE":
132
- eq_params_[k] = eq_params_heterogeneity[k](x, u, params)
133
- elif self._eq_type == "Non-statio PDE":
134
- eq_params_[k] = eq_params_heterogeneity[k](t, x, u, params)
95
+ eq_params_[k] = eq_params_heterogeneity[k](inputs, u, params)
135
96
  except KeyError:
136
97
  # we authorize missing eq_params_heterogeneity key
137
98
  # if its heterogeneity is None anyway
@@ -140,22 +101,17 @@ class DynamicLoss(eqx.Module):
140
101
 
141
102
  def _evaluate(
142
103
  self,
143
- t: Float[Array, "1"],
144
- x: Float[Array, "dim"],
104
+ inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
145
105
  u: eqx.Module,
146
106
  params: Params | ParamsDict,
147
107
  ) -> float:
148
- # Here we handle the various possible signature
149
- if self._eq_type == "ODE":
150
- ans = self.equation(t, u, params)
151
- elif self._eq_type == "Statio PDE":
152
- ans = self.equation(x, u, params)
153
- elif self._eq_type == "Non-statio PDE":
154
- ans = self.equation(t, x, u, params)
155
- else:
156
- raise NotImplementedError("the equation type is not handled.")
157
-
158
- return ans
108
+ evaluation = self.equation(inputs, u, params)
109
+ if len(evaluation.shape) == 0:
110
+ raise ValueError(
111
+ "The output of dynamic loss must be vectorial, "
112
+ "i.e. of shape (d,) with d >= 1"
113
+ )
114
+ return evaluation
159
115
 
160
116
  @abc.abstractmethod
161
117
  def equation(self, *args, **kwargs):
@@ -191,7 +147,7 @@ class ODE(DynamicLoss):
191
147
 
192
148
  _eq_type: ClassVar[str] = "ODE"
193
149
 
194
- @partial(_decorator_heteregeneous_params, eq_type="ODE")
150
+ @partial(_decorator_heteregeneous_params)
195
151
  def evaluate(
196
152
  self,
197
153
  t: Float[Array, "1"],
@@ -199,7 +155,7 @@ class ODE(DynamicLoss):
199
155
  params: Params | ParamsDict,
200
156
  ) -> float:
201
157
  """Here we call DynamicLoss._evaluate with x=None"""
202
- return self._evaluate(t, None, u, params)
158
+ return self._evaluate(t, u, params)
203
159
 
204
160
  @abc.abstractmethod
205
161
  def equation(
@@ -260,12 +216,12 @@ class PDEStatio(DynamicLoss):
260
216
 
261
217
  _eq_type: ClassVar[str] = "Statio PDE"
262
218
 
263
- @partial(_decorator_heteregeneous_params, eq_type="Statio PDE")
219
+ @partial(_decorator_heteregeneous_params)
264
220
  def evaluate(
265
221
  self, x: Float[Array, "dimension"], u: eqx.Module, params: Params | ParamsDict
266
222
  ) -> float:
267
223
  """Here we call the DynamicLoss._evaluate with t=None"""
268
- return self._evaluate(None, x, u, params)
224
+ return self._evaluate(x, u, params)
269
225
 
270
226
  @abc.abstractmethod
271
227
  def equation(
@@ -325,23 +281,21 @@ class PDENonStatio(DynamicLoss):
325
281
 
326
282
  _eq_type: ClassVar[str] = "Non-statio PDE"
327
283
 
328
- @partial(_decorator_heteregeneous_params, eq_type="Non-statio PDE")
284
+ @partial(_decorator_heteregeneous_params)
329
285
  def evaluate(
330
286
  self,
331
- t: Float[Array, "1"],
332
- x: Float[Array, "dim"],
287
+ t_x: Float[Array, "1 + dim"],
333
288
  u: eqx.Module,
334
289
  params: Params | ParamsDict,
335
290
  ) -> float:
336
291
  """Here we call the DynamicLoss._evaluate with full arguments"""
337
- ans = self._evaluate(t, x, u, params)
292
+ ans = self._evaluate(t_x, u, params)
338
293
  return ans
339
294
 
340
295
  @abc.abstractmethod
341
296
  def equation(
342
297
  self,
343
- t: Float[Array, "1"],
344
- x: Float[Array, "dim"],
298
+ t_x: Float[Array, "1 + dim"],
345
299
  u: eqx.Module,
346
300
  params: Params | ParamsDict,
347
301
  ) -> float:
@@ -353,10 +307,8 @@ class PDENonStatio(DynamicLoss):
353
307
 
354
308
  Parameters
355
309
  ----------
356
- t : Float[Array, "1"]
357
- A 1-dimensional jnp.array representing the time point.
358
- x : Float[Array, "d"]
359
- A `d` dimensional jnp.array representing a point in the spatial domain $\Omega$.
310
+ t_x : Float[Array, "1 + dim"]
311
+ A jnp array containing the concatenation of a time point and a point in $\Omega$
360
312
  u : eqx.Module
361
313
  The neural network.
362
314
  params : Params | ParamsDict
jinns/loss/_LossODE.py CHANGED
@@ -28,7 +28,7 @@ from jinns.parameters._params import (
28
28
  from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
29
29
  from jinns.loss._loss_weights import LossWeightsODE, LossWeightsODEDict
30
30
  from jinns.loss._DynamicLossAbstract import ODE
31
- from jinns.utils._pinn import PINN
31
+ from jinns.nn._pinn import PINN
32
32
 
33
33
  if TYPE_CHECKING:
34
34
  from jinns.utils._types import *
@@ -209,7 +209,7 @@ class LossODE(_LossODEAbstract):
209
209
  mse_dyn_loss = dynamic_loss_apply(
210
210
  self.dynamic_loss.evaluate,
211
211
  self.u,
212
- (temporal_batch,),
212
+ temporal_batch,
213
213
  _set_derivatives(params, self.derivative_keys.dyn_loss),
214
214
  self.vmap_in_axes + vmap_in_axes_params,
215
215
  self.loss_weights.dyn_loss,
@@ -227,7 +227,7 @@ class LossODE(_LossODEAbstract):
227
227
  else:
228
228
  v_u = vmap(self.u, (None,) + vmap_in_axes_params)
229
229
  t0, u0 = self.initial_condition # pylint: disable=unpacking-non-sequence
230
- t0 = jnp.array(t0)
230
+ t0 = jnp.array([t0])
231
231
  u0 = jnp.array(u0)
232
232
  mse_initial_condition = jnp.mean(
233
233
  self.loss_weights.initial_condition
@@ -522,7 +522,7 @@ class SystemLossODE(eqx.Module):
522
522
  return dynamic_loss_apply(
523
523
  dyn_loss.evaluate,
524
524
  self.u_dict,
525
- (temporal_batch,),
525
+ temporal_batch,
526
526
  _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
527
527
  vmap_in_axes_t + vmap_in_axes_params,
528
528
  loss_weight,
jinns/loss/_LossPDE.py CHANGED
@@ -22,9 +22,7 @@ from jinns.loss._loss_utils import (
22
22
  initial_condition_apply,
23
23
  constraints_system_loss_apply,
24
24
  )
25
- from jinns.data._DataGenerators import (
26
- append_obs_batch,
27
- )
25
+ from jinns.data._DataGenerators import append_obs_batch
28
26
  from jinns.parameters._params import (
29
27
  _get_vmap_in_axes_params,
30
28
  _update_eq_params_dict,
@@ -40,8 +38,8 @@ from jinns.loss._loss_weights import (
40
38
  LossWeightsPDEDict,
41
39
  )
42
40
  from jinns.loss._DynamicLossAbstract import PDEStatio, PDENonStatio
43
- from jinns.utils._pinn import PINN
44
- from jinns.utils._spinn import SPINN
41
+ from jinns.nn._pinn import PINN
42
+ from jinns.nn._spinn import SPINN
45
43
  from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
46
44
 
47
45
 
@@ -94,12 +92,16 @@ class _LossPDEAbstract(eqx.Module):
94
92
  Note that it must be a slice and not an integer
95
93
  (but a preprocessing of the user provided argument takes care of it)
96
94
  norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
97
- Fixed sample point in the space over which to compute the
95
+ Monte-Carlo sample points for computing the
98
96
  normalization constant. Default is None.
99
- norm_int_length : float, default=None
100
- A float. Must be provided if `norm_samples` is provided. The domain area
101
- (or interval length in 1D) upon which we perform the numerical
102
- integration. Default None
97
+ norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
98
+ The importance sampling weights for Monte-Carlo integration of the
99
+ normalization constant. Must be provided if `norm_samples` is provided.
100
+ `norm_weights` should have the same leading dimension as
101
+ `norm_samples`.
102
+ Alternatively, the user can pass a float or an integer.
103
+ These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
104
+ $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
103
105
  obs_slice : slice, default=None
104
106
  slice object specifying the begininning/ending of the PINN output
105
107
  that is observed (this is then useful for multidim PINN). Default is None.
@@ -129,7 +131,9 @@ class _LossPDEAbstract(eqx.Module):
129
131
  norm_samples: Float[Array, "nb_norm_samples dimension"] | None = eqx.field(
130
132
  kw_only=True, default=None
131
133
  )
132
- norm_int_length: float | None = eqx.field(kw_only=True, default=None)
134
+ norm_weights: Float[Array, "nb_norm_samples"] | float | int | None = eqx.field(
135
+ kw_only=True, default=None
136
+ )
133
137
  obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
134
138
 
135
139
  params: InitVar[Params] = eqx.field(kw_only=True, default=None)
@@ -253,8 +257,25 @@ class _LossPDEAbstract(eqx.Module):
253
257
  if not isinstance(self.omega_boundary_dim, slice):
254
258
  raise ValueError("self.omega_boundary_dim must be a jnp.s_ object")
255
259
 
256
- if self.norm_samples is not None and self.norm_int_length is None:
257
- raise ValueError("self.norm_samples and norm_int_length must be provided")
260
+ if self.norm_samples is not None:
261
+ if self.norm_weights is None:
262
+ raise ValueError(
263
+ "`norm_weights` must be provided when `norm_samples` is used!"
264
+ )
265
+ try:
266
+ assert self.norm_weights.shape[0] == self.norm_samples.shape[0]
267
+ except (AssertionError, AttributeError):
268
+ if isinstance(self.norm_weights, (int, float)):
269
+ self.norm_weights = jnp.array(
270
+ [self.norm_weights], dtype=jax.dtypes.canonicalize_dtype(float)
271
+ )
272
+ else:
273
+ raise ValueError(
274
+ "`norm_weights` should have the same leading dimension"
275
+ " as `norm_samples`,"
276
+ f" got shape {self.norm_weights.shape} and"
277
+ f" shape {self.norm_samples.shape}."
278
+ )
258
279
 
259
280
  @abc.abstractmethod
260
281
  def evaluate(
@@ -326,12 +347,16 @@ class LossPDEStatio(_LossPDEAbstract):
326
347
  Note that it must be a slice and not an integer
327
348
  (but a preprocessing of the user provided argument takes care of it)
328
349
  norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
329
- Fixed sample point in the space over which to compute the
350
+ Monte-Carlo sample points for computing the
330
351
  normalization constant. Default is None.
331
- norm_int_length : float, default=None
332
- A float. Must be provided if `norm_samples` is provided. The domain area
333
- (or interval length in 1D) upon which we perform the numerical
334
- integration. Default None
352
+ norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
353
+ The importance sampling weights for Monte-Carlo integration of the
354
+ normalization constant. Must be provided if `norm_samples` is provided.
355
+ `norm_weights` should have the same leading dimension as
356
+ `norm_samples`.
357
+ Alternatively, the user can pass a float or an integer.
358
+ These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
359
+ $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
335
360
  obs_slice : slice, default=None
336
361
  slice object specifying the begininning/ending of the PINN output
337
362
  that is observed (this is then useful for multidim PINN). Default is None.
@@ -370,8 +395,8 @@ class LossPDEStatio(_LossPDEAbstract):
370
395
 
371
396
  def _get_dynamic_loss_batch(
372
397
  self, batch: PDEStatioBatch
373
- ) -> tuple[Float[Array, "batch_size dimension"]]:
374
- return (batch.inside_batch,)
398
+ ) -> Float[Array, "batch_size dimension"]:
399
+ return batch.domain_batch
375
400
 
376
401
  def _get_normalization_loss_batch(
377
402
  self, _
@@ -432,8 +457,8 @@ class LossPDEStatio(_LossPDEAbstract):
432
457
  self.u,
433
458
  self._get_normalization_loss_batch(batch),
434
459
  _set_derivatives(params, self.derivative_keys.norm_loss),
435
- self.vmap_in_axes + vmap_in_axes_params,
436
- self.norm_int_length,
460
+ vmap_in_axes_params,
461
+ self.norm_weights,
437
462
  self.loss_weights.norm_loss,
438
463
  )
439
464
  else:
@@ -551,12 +576,16 @@ class LossPDENonStatio(LossPDEStatio):
551
576
  Note that it must be a slice and not an integer
552
577
  (but a preprocessing of the user provided argument takes care of it)
553
578
  norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
554
- Fixed sample point in the space over which to compute the
579
+ Monte-Carlo sample points for computing the
555
580
  normalization constant. Default is None.
556
- norm_int_length : float, default=None
557
- A float. Must be provided if `norm_samples` is provided. The domain area
558
- (or interval length in 1D) upon which we perform the numerical
559
- integration. Default None
581
+ norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
582
+ The importance sampling weights for Monte-Carlo integration of the
583
+ normalization constant. Must be provided if `norm_samples` is provided.
584
+ `norm_weights` should have the same leading dimension as
585
+ `norm_samples`.
586
+ Alternatively, the user can pass a float or an integer.
587
+ These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
588
+ $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
560
589
  obs_slice : slice, default=None
561
590
  slice object specifying the begininning/ending of the PINN output
562
591
  that is observed (this is then useful for multidim PINN). Default is None.
@@ -575,6 +604,9 @@ class LossPDENonStatio(LossPDEStatio):
575
604
  kw_only=True, default=None, static=True
576
605
  )
577
606
 
607
+ _max_norm_samples_omega: Int = eqx.field(init=False, static=True)
608
+ _max_norm_time_slices: Int = eqx.field(init=False, static=True)
609
+
578
610
  def __post_init__(self, params=None):
579
611
  """
580
612
  Note that neither __init__ or __post_init__ are called when udating a
@@ -585,7 +617,7 @@ class LossPDENonStatio(LossPDEStatio):
585
617
  ) # because __init__ or __post_init__ of Base
586
618
  # class is not automatically called
587
619
 
588
- self.vmap_in_axes = (0, 0) # for t and x
620
+ self.vmap_in_axes = (0,) # for t_x
589
621
 
590
622
  if self.initial_condition_fun is None:
591
623
  warnings.warn(
@@ -593,28 +625,28 @@ class LossPDENonStatio(LossPDEStatio):
593
625
  "case (e.g by. hardcoding it into the PINN output)."
594
626
  )
595
627
 
628
+ # witht the variables below we avoid memory overflow since a cartesian
629
+ # product is taken
630
+ self._max_norm_time_slices = 100
631
+ self._max_norm_samples_omega = 1000
632
+
596
633
  def _get_dynamic_loss_batch(
597
634
  self, batch: PDENonStatioBatch
598
- ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
599
- times_batch = batch.times_x_inside_batch[:, 0:1]
600
- omega_batch = batch.times_x_inside_batch[:, 1:]
601
- return (times_batch, omega_batch)
635
+ ) -> Float[Array, "batch_size 1+dimension"]:
636
+ return batch.domain_batch
602
637
 
603
638
  def _get_normalization_loss_batch(
604
639
  self, batch: PDENonStatioBatch
605
- ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "nb_norm_samples dimension"]]:
640
+ ) -> Float[Array, "nb_norm_time_slices nb_norm_samples dimension"]:
606
641
  return (
607
- batch.times_x_inside_batch[:, 0:1],
608
- self.norm_samples,
642
+ batch.domain_batch[: self._max_norm_time_slices, 0:1],
643
+ self.norm_samples[: self._max_norm_samples_omega],
609
644
  )
610
645
 
611
646
  def _get_observations_loss_batch(
612
647
  self, batch: PDENonStatioBatch
613
648
  ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
614
- return (
615
- batch.obs_batch_dict["pinn_in"][:, 0:1],
616
- batch.obs_batch_dict["pinn_in"][:, 1:],
617
- )
649
+ return (batch.obs_batch_dict["pinn_in"],)
618
650
 
619
651
  def __call__(self, *args, **kwargs):
620
652
  return self.evaluate(*args, **kwargs)
@@ -637,7 +669,7 @@ class LossPDENonStatio(LossPDEStatio):
637
669
  of parameters (eg. for metamodeling) and an optional additional batch of observed
638
670
  inputs/outputs/parameters
639
671
  """
640
- omega_batch = batch.times_x_inside_batch[:, 1:]
672
+ omega_batch = batch.initial_batch
641
673
 
642
674
  # Retrieve the optional eq_params_batch
643
675
  # and update eq_params with the latter
@@ -660,7 +692,6 @@ class LossPDENonStatio(LossPDEStatio):
660
692
  _set_derivatives(params, self.derivative_keys.initial_condition),
661
693
  (0,) + vmap_in_axes_params,
662
694
  self.initial_condition_fun,
663
- omega_batch.shape[0],
664
695
  self.loss_weights.initial_condition,
665
696
  )
666
697
  else:
@@ -730,15 +761,21 @@ class SystemLossPDE(eqx.Module):
730
761
  (default) then no temporal boundary condition is applied
731
762
  Must share the keys of `u_dict`
732
763
  norm_samples_dict : Dict[str, Float[Array, "nb_norm_samples dimension"] | None, default=None
733
- A dict of fixed sample point in the space over which to compute the
734
- normalization constant. Default is None
735
- Must share the keys of `u_dict`
736
- norm_int_length_dict : Dict[str, float | None] | None, default=None
737
- A dict of Float. The domain area
738
- (or interval length in 1D) upon which we perform the numerical
739
- integration for each element of u_dict.
740
- Default is None
764
+ A dict of Monte-Carlo sample points for computing the
765
+ normalization constant. Default is None.
741
766
  Must share the keys of `u_dict`
767
+ norm_weights_dict : Dict[str, Array[Float, "nb_norm_samples"] | float | int | None] | None, default=None
768
+ A dict of jnp.array with the same keys as `u_dict`. The importance
769
+ sampling weights for Monte-Carlo integration of the
770
+ normalization constant for each element of u_dict. Must be provided if
771
+ `norm_samples_dict` is provided.
772
+ `norm_weights_dict[key]` should have the same leading dimension as
773
+ `norm_samples_dict[key]` for each `key`.
774
+ Alternatively, the user can pass a float or an integer.
775
+ For each key, an array of similar shape to `norm_samples_dict[key]`
776
+ or shape `(1,)` is expected. These corresponds to the weights $w_k =
777
+ \frac{1}{q(x_k)}$ where $q(\cdot)$ is the proposal p.d.f. and $x_k$ are
778
+ the Monte-Carlo samples. Default is None
742
779
  obs_slice_dict : Dict[str, slice | None] | None, default=None
743
780
  dict of obs_slice, with keys from `u_dict` to designate the
744
781
  output(s) channels that are forced to observed values, for each
@@ -774,9 +811,9 @@ class SystemLossPDE(eqx.Module):
774
811
  norm_samples_dict: Dict[str, Float[Array, "nb_norm_samples dimension"]] | None = (
775
812
  eqx.field(kw_only=True, default=None)
776
813
  )
777
- norm_int_length_dict: Dict[str, float | None] | None = eqx.field(
778
- kw_only=True, default=None
779
- )
814
+ norm_weights_dict: (
815
+ Dict[str, Float[Array, "nb_norm_samples dimension"] | float | int | None] | None
816
+ ) = eqx.field(kw_only=True, default=None)
780
817
  obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
781
818
  kw_only=True, default=None, static=True
782
819
  )
@@ -819,8 +856,8 @@ class SystemLossPDE(eqx.Module):
819
856
  self.initial_condition_fun_dict = self.u_dict_with_none
820
857
  if self.norm_samples_dict is None:
821
858
  self.norm_samples_dict = self.u_dict_with_none
822
- if self.norm_int_length_dict is None:
823
- self.norm_int_length_dict = self.u_dict_with_none
859
+ if self.norm_weights_dict is None:
860
+ self.norm_weights_dict = self.u_dict_with_none
824
861
  if self.obs_slice_dict is None:
825
862
  self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
826
863
  if self.u_dict.keys() != self.obs_slice_dict.keys():
@@ -861,7 +898,7 @@ class SystemLossPDE(eqx.Module):
861
898
  or self.u_dict.keys() != self.omega_boundary_dim_dict.keys()
862
899
  or self.u_dict.keys() != self.initial_condition_fun_dict.keys()
863
900
  or self.u_dict.keys() != self.norm_samples_dict.keys()
864
- or self.u_dict.keys() != self.norm_int_length_dict.keys()
901
+ or self.u_dict.keys() != self.norm_weights_dict.keys()
865
902
  ):
866
903
  raise ValueError("All the dicts concerning the PINNs should have same keys")
867
904
 
@@ -890,7 +927,7 @@ class SystemLossPDE(eqx.Module):
890
927
  omega_boundary_condition=self.omega_boundary_condition_dict[i],
891
928
  omega_boundary_dim=self.omega_boundary_dim_dict[i],
892
929
  norm_samples=self.norm_samples_dict[i],
893
- norm_int_length=self.norm_int_length_dict[i],
930
+ norm_weights=self.norm_weights_dict[i],
894
931
  obs_slice=self.obs_slice_dict[i],
895
932
  )
896
933
  elif self.u_dict[i].eq_type == "nonstatio_PDE":
@@ -911,7 +948,7 @@ class SystemLossPDE(eqx.Module):
911
948
  omega_boundary_dim=self.omega_boundary_dim_dict[i],
912
949
  initial_condition_fun=self.initial_condition_fun_dict[i],
913
950
  norm_samples=self.norm_samples_dict[i],
914
- norm_int_length=self.norm_int_length_dict[i],
951
+ norm_weights=self.norm_weights_dict[i],
915
952
  obs_slice=self.obs_slice_dict[i],
916
953
  )
917
954
  else:
@@ -1016,19 +1053,7 @@ class SystemLossPDE(eqx.Module):
1016
1053
  if self.u_dict.keys() != params_dict.nn_params.keys():
1017
1054
  raise ValueError("u_dict and params_dict[nn_params] should have same keys ")
1018
1055
 
1019
- if isinstance(batch, PDEStatioBatch):
1020
- omega_batch, _ = batch.inside_batch, batch.border_batch
1021
- vmap_in_axes_x_or_x_t = (0,)
1022
-
1023
- batches = (omega_batch,)
1024
- elif isinstance(batch, PDENonStatioBatch):
1025
- times_batch = batch.times_x_inside_batch[:, 0:1]
1026
- omega_batch = batch.times_x_inside_batch[:, 1:]
1027
-
1028
- batches = (omega_batch, times_batch)
1029
- vmap_in_axes_x_or_x_t = (0, 0)
1030
- else:
1031
- raise ValueError("Wrong type of batch")
1056
+ vmap_in_axes = (0,)
1032
1057
 
1033
1058
  # Retrieve the optional eq_params_batch
1034
1059
  # and update eq_params with the latter
@@ -1036,7 +1061,6 @@ class SystemLossPDE(eqx.Module):
1036
1061
  if batch.param_batch_dict is not None:
1037
1062
  eq_params_batch_dict = batch.param_batch_dict
1038
1063
 
1039
- # TODO
1040
1064
  # feed the eq_params with the batch
1041
1065
  for k in eq_params_batch_dict.keys():
1042
1066
  params_dict.eq_params[k] = eq_params_batch_dict[k]
@@ -1050,11 +1074,15 @@ class SystemLossPDE(eqx.Module):
1050
1074
  return dynamic_loss_apply(
1051
1075
  dyn_loss.evaluate,
1052
1076
  self.u_dict,
1053
- batches,
1077
+ (
1078
+ batch.domain_batch
1079
+ if isinstance(batch, PDEStatioBatch)
1080
+ else batch.domain_batch
1081
+ ),
1054
1082
  _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
1055
- vmap_in_axes_x_or_x_t + vmap_in_axes_params,
1083
+ vmap_in_axes + vmap_in_axes_params,
1056
1084
  loss_weight,
1057
- u_type=type(list(self.u_dict.values())[0]),
1085
+ u_type=list(self.u_dict.values())[0].__class__.__base__,
1058
1086
  )
1059
1087
 
1060
1088
  dyn_loss_mse_dict = jax.tree_util.tree_map(
jinns/loss/__init__.py CHANGED
@@ -3,7 +3,7 @@ from ._LossODE import LossODE, SystemLossODE
3
3
  from ._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
4
4
  from ._DynamicLoss import (
5
5
  GeneralizedLotkaVolterra,
6
- BurgerEquation,
6
+ BurgersEquation,
7
7
  FPENonStatioLoss2D,
8
8
  OU_FPENonStatioLoss2D,
9
9
  FisherKPP,
@@ -19,9 +19,10 @@ from ._loss_weights import (
19
19
  )
20
20
 
21
21
  from ._operators import (
22
- _div_fwd,
23
- _div_rev,
24
- _laplacian_fwd,
25
- _laplacian_rev,
26
- _vectorial_laplacian,
22
+ divergence_fwd,
23
+ divergence_rev,
24
+ laplacian_fwd,
25
+ laplacian_rev,
26
+ vectorial_laplacian_fwd,
27
+ vectorial_laplacian_rev,
27
28
  )