jinns 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/loss/_LossPDE.py CHANGED
@@ -22,9 +22,7 @@ from jinns.loss._loss_utils import (
22
22
  initial_condition_apply,
23
23
  constraints_system_loss_apply,
24
24
  )
25
- from jinns.data._DataGenerators import (
26
- append_obs_batch,
27
- )
25
+ from jinns.data._DataGenerators import append_obs_batch
28
26
  from jinns.parameters._params import (
29
27
  _get_vmap_in_axes_params,
30
28
  _update_eq_params_dict,
@@ -103,6 +101,9 @@ class _LossPDEAbstract(eqx.Module):
103
101
  obs_slice : slice, default=None
104
102
  slice object specifying the begininning/ending of the PINN output
105
103
  that is observed (this is then useful for multidim PINN). Default is None.
104
+ params : InitVar[Params], default=None
105
+ The main Params object of the problem needed to instanciate the
106
+ DerivativeKeysODE if the latter is not specified.
106
107
  """
107
108
 
108
109
  # NOTE static=True only for leaf attributes that are not valid JAX types
@@ -129,18 +130,26 @@ class _LossPDEAbstract(eqx.Module):
129
130
  norm_int_length: float | None = eqx.field(kw_only=True, default=None)
130
131
  obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
131
132
 
132
- def __post_init__(self):
133
+ params: InitVar[Params] = eqx.field(kw_only=True, default=None)
134
+
135
+ def __post_init__(self, params=None):
133
136
  """
134
137
  Note that neither __init__ or __post_init__ are called when udating a
135
138
  Module with eqx.tree_at
136
139
  """
137
140
  if self.derivative_keys is None:
138
141
  # be default we only take gradient wrt nn_params
139
- self.derivative_keys = (
140
- DerivativeKeysPDENonStatio()
141
- if isinstance(self, LossPDENonStatio)
142
- else DerivativeKeysPDEStatio()
143
- )
142
+ try:
143
+ self.derivative_keys = (
144
+ DerivativeKeysPDENonStatio(params=params)
145
+ if isinstance(self, LossPDENonStatio)
146
+ else DerivativeKeysPDEStatio(params=params)
147
+ )
148
+ except ValueError as exc:
149
+ raise ValueError(
150
+ "Problem at self.derivative_keys initialization "
151
+ f"received {self.derivative_keys=} and {params=}"
152
+ ) from exc
144
153
 
145
154
  if self.loss_weights is None:
146
155
  self.loss_weights = (
@@ -324,6 +333,9 @@ class LossPDEStatio(_LossPDEAbstract):
324
333
  obs_slice : slice, default=None
325
334
  slice object specifying the begininning/ending of the PINN output
326
335
  that is observed (this is then useful for multidim PINN). Default is None.
336
+ params : InitVar[Params], default=None
337
+ The main Params object of the problem needed to instanciate the
338
+ DerivativeKeysODE if the latter is not specified.
327
339
 
328
340
 
329
341
  Raises
@@ -342,20 +354,22 @@ class LossPDEStatio(_LossPDEAbstract):
342
354
 
343
355
  vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
344
356
 
345
- def __post_init__(self):
357
+ def __post_init__(self, params=None):
346
358
  """
347
359
  Note that neither __init__ or __post_init__ are called when udating a
348
360
  Module with eqx.tree_at!
349
361
  """
350
- super().__post_init__() # because __init__ or __post_init__ of Base
362
+ super().__post_init__(
363
+ params=params
364
+ ) # because __init__ or __post_init__ of Base
351
365
  # class is not automatically called
352
366
 
353
367
  self.vmap_in_axes = (0,) # for x only here
354
368
 
355
369
  def _get_dynamic_loss_batch(
356
370
  self, batch: PDEStatioBatch
357
- ) -> tuple[Float[Array, "batch_size dimension"]]:
358
- return (batch.inside_batch,)
371
+ ) -> Float[Array, "batch_size dimension"]:
372
+ return batch.domain_batch
359
373
 
360
374
  def _get_normalization_loss_batch(
361
375
  self, _
@@ -416,7 +430,7 @@ class LossPDEStatio(_LossPDEAbstract):
416
430
  self.u,
417
431
  self._get_normalization_loss_batch(batch),
418
432
  _set_derivatives(params, self.derivative_keys.norm_loss),
419
- self.vmap_in_axes + vmap_in_axes_params,
433
+ vmap_in_axes_params,
420
434
  self.norm_int_length,
421
435
  self.loss_weights.norm_loss,
422
436
  )
@@ -547,6 +561,9 @@ class LossPDENonStatio(LossPDEStatio):
547
561
  initial_condition_fun : Callable, default=None
548
562
  A function representing the temporal initial condition. If None
549
563
  (default) then no initial condition is applied
564
+ params : InitVar[Params], default=None
565
+ The main Params object of the problem needed to instanciate the
566
+ DerivativeKeysODE if the latter is not specified.
550
567
 
551
568
  """
552
569
 
@@ -556,15 +573,20 @@ class LossPDENonStatio(LossPDEStatio):
556
573
  kw_only=True, default=None, static=True
557
574
  )
558
575
 
559
- def __post_init__(self):
576
+ _max_norm_samples_omega: Int = eqx.field(init=False, static=True)
577
+ _max_norm_time_slices: Int = eqx.field(init=False, static=True)
578
+
579
+ def __post_init__(self, params=None):
560
580
  """
561
581
  Note that neither __init__ or __post_init__ are called when udating a
562
582
  Module with eqx.tree_at!
563
583
  """
564
- super().__post_init__() # because __init__ or __post_init__ of Base
584
+ super().__post_init__(
585
+ params=params
586
+ ) # because __init__ or __post_init__ of Base
565
587
  # class is not automatically called
566
588
 
567
- self.vmap_in_axes = (0, 0) # for t and x
589
+ self.vmap_in_axes = (0,) # for t_x
568
590
 
569
591
  if self.initial_condition_fun is None:
570
592
  warnings.warn(
@@ -572,28 +594,28 @@ class LossPDENonStatio(LossPDEStatio):
572
594
  "case (e.g by. hardcoding it into the PINN output)."
573
595
  )
574
596
 
597
+ # witht the variables below we avoid memory overflow since a cartesian
598
+ # product is taken
599
+ self._max_norm_time_slices = 100
600
+ self._max_norm_samples_omega = 1000
601
+
575
602
  def _get_dynamic_loss_batch(
576
603
  self, batch: PDENonStatioBatch
577
- ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
578
- times_batch = batch.times_x_inside_batch[:, 0:1]
579
- omega_batch = batch.times_x_inside_batch[:, 1:]
580
- return (times_batch, omega_batch)
604
+ ) -> Float[Array, "batch_size 1+dimension"]:
605
+ return batch.domain_batch
581
606
 
582
607
  def _get_normalization_loss_batch(
583
608
  self, batch: PDENonStatioBatch
584
- ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "nb_norm_samples dimension"]]:
609
+ ) -> Float[Array, "nb_norm_time_slices nb_norm_samples dimension"]:
585
610
  return (
586
- batch.times_x_inside_batch[:, 0:1],
587
- self.norm_samples,
611
+ batch.domain_batch[: self._max_norm_time_slices, 0:1],
612
+ self.norm_samples[: self._max_norm_samples_omega],
588
613
  )
589
614
 
590
615
  def _get_observations_loss_batch(
591
616
  self, batch: PDENonStatioBatch
592
617
  ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
593
- return (
594
- batch.obs_batch_dict["pinn_in"][:, 0:1],
595
- batch.obs_batch_dict["pinn_in"][:, 1:],
596
- )
618
+ return (batch.obs_batch_dict["pinn_in"],)
597
619
 
598
620
  def __call__(self, *args, **kwargs):
599
621
  return self.evaluate(*args, **kwargs)
@@ -616,8 +638,7 @@ class LossPDENonStatio(LossPDEStatio):
616
638
  of parameters (eg. for metamodeling) and an optional additional batch of observed
617
639
  inputs/outputs/parameters
618
640
  """
619
-
620
- omega_batch = batch.times_x_inside_batch[:, 1:]
641
+ omega_batch = batch.initial_batch
621
642
 
622
643
  # Retrieve the optional eq_params_batch
623
644
  # and update eq_params with the latter
@@ -640,7 +661,6 @@ class LossPDENonStatio(LossPDEStatio):
640
661
  _set_derivatives(params, self.derivative_keys.initial_condition),
641
662
  (0,) + vmap_in_axes_params,
642
663
  self.initial_condition_fun,
643
- omega_batch.shape[0],
644
664
  self.loss_weights.initial_condition,
645
665
  )
646
666
  else:
@@ -725,6 +745,9 @@ class SystemLossPDE(eqx.Module):
725
745
  PINNs. Default is None. But if a value is given, all the entries of
726
746
  `u_dict` must be represented here with default value `jnp.s_[...]`
727
747
  if no particular slice is to be given
748
+ params : InitVar[ParamsDict], default=None
749
+ The main Params object of the problem needed to instanciate the
750
+ DerivativeKeysODE if the latter is not specified.
728
751
 
729
752
  """
730
753
 
@@ -763,22 +786,20 @@ class SystemLossPDE(eqx.Module):
763
786
  loss_weights: InitVar[LossWeightsPDEDict | None] = eqx.field(
764
787
  kw_only=True, default=None
765
788
  )
789
+ params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
766
790
 
767
791
  # following have init=False and are set in the __post_init__
768
792
  u_constraints_dict: Dict[str, LossPDEStatio | LossPDENonStatio] = eqx.field(
769
793
  init=False
770
794
  )
771
- derivative_keys_u_dict: Dict[
772
- str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio
773
- ] = eqx.field(init=False)
774
- derivative_keys_dyn_loss_dict: Dict[
775
- str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio
776
- ] = eqx.field(init=False)
795
+ derivative_keys_dyn_loss: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio = (
796
+ eqx.field(init=False)
797
+ )
777
798
  u_dict_with_none: Dict[str, None] = eqx.field(init=False)
778
799
  # internally the loss weights are handled with a dictionary
779
800
  _loss_weights: Dict[str, dict] = eqx.field(init=False)
780
801
 
781
- def __post_init__(self, loss_weights):
802
+ def __post_init__(self, loss_weights=None, params_dict=None):
782
803
  # a dictionary that will be useful at different places
783
804
  self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
784
805
  # First, for all the optional dict,
@@ -818,24 +839,19 @@ class SystemLossPDE(eqx.Module):
818
839
  # iterating on dynamic_loss_dict. So each time we will require dome
819
840
  # derivative_keys_dict
820
841
 
821
- # but then if the user did not provide anything, we must at least have
822
- # a default value for the dynamic_loss_dict keys entries in
823
- # self.derivative_keys_dict since the computation of dynamic losses is
824
- # made without create a loss object that would provide the
825
- # default values
826
- for k in self.dynamic_loss_dict.keys():
842
+ # derivative keys for the u_constraints. Note that we create missing
843
+ # DerivativeKeysODE around a Params object and not ParamsDict
844
+ # this works because u_dict.keys == params_dict.nn_params.keys()
845
+ for k in self.u_dict.keys():
827
846
  if self.derivative_keys_dict[k] is None:
828
- try:
829
- if self.u_dict[k].eq_type == "statio_PDE":
830
- self.derivative_keys_dict[k] = DerivativeKeysPDEStatio()
831
- else:
832
- self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio()
833
- except KeyError: # We are in a key that is not in u_dict but in
834
- # dynamic_loss_dict
835
- if isinstance(self.dynamic_loss_dict[k], PDEStatio):
836
- self.derivative_keys_dict[k] = DerivativeKeysPDEStatio()
837
- else:
838
- self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio()
847
+ if self.u_dict[k].eq_type == "statio_PDE":
848
+ self.derivative_keys_dict[k] = DerivativeKeysPDEStatio(
849
+ params=params_dict.extract_params(k)
850
+ )
851
+ else:
852
+ self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio(
853
+ params=params_dict.extract_params(k)
854
+ )
839
855
 
840
856
  # Second we make sure that all the dicts (except dynamic_loss_dict) have the same keys
841
857
  if (
@@ -904,16 +920,11 @@ class SystemLossPDE(eqx.Module):
904
920
  f"got {self.u_dict[i].eq_type[i]}"
905
921
  )
906
922
 
907
- # for convenience in the tree_map of evaluate,
908
- # we separate the two derivative keys dict
909
- self.derivative_keys_dyn_loss_dict = {
910
- k: self.derivative_keys_dict[k]
911
- for k in self.dynamic_loss_dict.keys() & self.derivative_keys_dict.keys()
912
- }
913
- self.derivative_keys_u_dict = {
914
- k: self.derivative_keys_dict[k]
915
- for k in self.u_dict.keys() & self.derivative_keys_dict.keys()
916
- }
923
+ # derivative keys for the dynamic loss. Note that we create a
924
+ # DerivativeKeysODE around a ParamsDict object because a whole
925
+ # params_dict is feed to DynamicLoss.evaluate functions (extract_params
926
+ # happen inside it)
927
+ self.derivative_keys_dyn_loss = DerivativeKeysPDENonStatio(params=params_dict)
917
928
 
918
929
  # also make sure we only have PINNs or SPINNs
919
930
  if not (
@@ -1005,19 +1016,7 @@ class SystemLossPDE(eqx.Module):
1005
1016
  if self.u_dict.keys() != params_dict.nn_params.keys():
1006
1017
  raise ValueError("u_dict and params_dict[nn_params] should have same keys ")
1007
1018
 
1008
- if isinstance(batch, PDEStatioBatch):
1009
- omega_batch, _ = batch.inside_batch, batch.border_batch
1010
- vmap_in_axes_x_or_x_t = (0,)
1011
-
1012
- batches = (omega_batch,)
1013
- elif isinstance(batch, PDENonStatioBatch):
1014
- times_batch = batch.times_x_inside_batch[:, 0:1]
1015
- omega_batch = batch.times_x_inside_batch[:, 1:]
1016
-
1017
- batches = (omega_batch, times_batch)
1018
- vmap_in_axes_x_or_x_t = (0, 0)
1019
- else:
1020
- raise ValueError("Wrong type of batch")
1019
+ vmap_in_axes = (0,)
1021
1020
 
1022
1021
  # Retrieve the optional eq_params_batch
1023
1022
  # and update eq_params with the latter
@@ -1025,7 +1024,6 @@ class SystemLossPDE(eqx.Module):
1025
1024
  if batch.param_batch_dict is not None:
1026
1025
  eq_params_batch_dict = batch.param_batch_dict
1027
1026
 
1028
- # TODO
1029
1027
  # feed the eq_params with the batch
1030
1028
  for k in eq_params_batch_dict.keys():
1031
1029
  params_dict.eq_params[k] = eq_params_batch_dict[k]
@@ -1034,14 +1032,18 @@ class SystemLossPDE(eqx.Module):
1034
1032
  batch.param_batch_dict, params_dict
1035
1033
  )
1036
1034
 
1037
- def dyn_loss_for_one_key(dyn_loss, derivative_key, loss_weight):
1035
+ def dyn_loss_for_one_key(dyn_loss, loss_weight):
1038
1036
  """The function used in tree_map"""
1039
1037
  return dynamic_loss_apply(
1040
1038
  dyn_loss.evaluate,
1041
1039
  self.u_dict,
1042
- batches,
1043
- _set_derivatives(params_dict, derivative_key.dyn_loss),
1044
- vmap_in_axes_x_or_x_t + vmap_in_axes_params,
1040
+ (
1041
+ batch.domain_batch
1042
+ if isinstance(batch, PDEStatioBatch)
1043
+ else batch.domain_batch
1044
+ ),
1045
+ _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
1046
+ vmap_in_axes + vmap_in_axes_params,
1045
1047
  loss_weight,
1046
1048
  u_type=type(list(self.u_dict.values())[0]),
1047
1049
  )
@@ -1049,7 +1051,6 @@ class SystemLossPDE(eqx.Module):
1049
1051
  dyn_loss_mse_dict = jax.tree_util.tree_map(
1050
1052
  dyn_loss_for_one_key,
1051
1053
  self.dynamic_loss_dict,
1052
- self.derivative_keys_dyn_loss_dict,
1053
1054
  self._loss_weights["dyn_loss"],
1054
1055
  is_leaf=lambda x: isinstance(
1055
1056
  x, (PDEStatio, PDENonStatio)
jinns/loss/__init__.py CHANGED
@@ -3,7 +3,7 @@ from ._LossODE import LossODE, SystemLossODE
3
3
  from ._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
4
4
  from ._DynamicLoss import (
5
5
  GeneralizedLotkaVolterra,
6
- BurgerEquation,
6
+ BurgersEquation,
7
7
  FPENonStatioLoss2D,
8
8
  OU_FPENonStatioLoss2D,
9
9
  FisherKPP,
@@ -19,9 +19,10 @@ from ._loss_weights import (
19
19
  )
20
20
 
21
21
  from ._operators import (
22
- _div_fwd,
23
- _div_rev,
24
- _laplacian_fwd,
25
- _laplacian_rev,
26
- _vectorial_laplacian,
22
+ divergence_fwd,
23
+ divergence_rev,
24
+ laplacian_fwd,
25
+ laplacian_rev,
26
+ vectorial_laplacian_fwd,
27
+ vectorial_laplacian_rev,
27
28
  )