jinns 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -54,7 +54,7 @@ def make_cartesian_product(
 
 
 def _reset_batch_idx_and_permute(
-    operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]]
+    operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]],
 ) -> tuple[Key, Float[Array, "n dimension"], Int]:
     key, domain, curr_idx, _, p = operands
     # resetting counter
@@ -78,7 +78,7 @@ def _reset_batch_idx_and_permute(
 
 
 def _increment_batch_idx(
-    operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]]
+    operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]],
 ) -> tuple[Key, Float[Array, "n dimension"], Int]:
     key, domain, curr_idx, batch_size, _ = operands
     # simply increases counter and get the batch
@@ -13,8 +13,8 @@ from jax import grad
 import jax.numpy as jnp
 import equinox as eqx
 
-from jinns.utils._pinn import PINN
-from jinns.utils._spinn import SPINN
+from jinns.nn._pinn import PINN
+from jinns.nn._spinn_mlp import SPINN
 
 from jinns.utils._utils import get_grid
 from jinns.loss._DynamicLossAbstract import ODE, PDEStatio, PDENonStatio
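The recurring change across these files is that the network classes moved from private `jinns.utils` modules to the new `jinns.nn` subpackage, with `HYPERPINN` renamed to `HyperPINN` along the way. A minimal sketch of how downstream imports change, assuming the package-level re-exports defined in jinns/nn/__init__.py at the end of this diff:

    # jinns 1.2.0: network classes lived in private jinns.utils modules
    # from jinns.utils._pinn import PINN
    # from jinns.utils._spinn import SPINN
    # from jinns.utils._hyperpinn import HYPERPINN

    # jinns 1.3.0: they live under jinns.nn, which re-exports them
    from jinns.nn import PINN, SPINN, HyperPINN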
jinns/loss/_LossODE.py CHANGED
@@ -28,7 +28,7 @@ from jinns.parameters._params import (
 from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
 from jinns.loss._loss_weights import LossWeightsODE, LossWeightsODEDict
 from jinns.loss._DynamicLossAbstract import ODE
-from jinns.utils._pinn import PINN
+from jinns.nn._pinn import PINN
 
 if TYPE_CHECKING:
     from jinns.utils._types import *
jinns/loss/_LossPDE.py CHANGED
@@ -38,8 +38,8 @@ from jinns.loss._loss_weights import (
     LossWeightsPDEDict,
 )
 from jinns.loss._DynamicLossAbstract import PDEStatio, PDENonStatio
-from jinns.utils._pinn import PINN
-from jinns.utils._spinn import SPINN
+from jinns.nn._pinn import PINN
+from jinns.nn._spinn import SPINN
 from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
 
 
@@ -92,12 +92,16 @@ class _LossPDEAbstract(eqx.Module):
         Note that it must be a slice and not an integer
         (but a preprocessing of the user provided argument takes care of it)
     norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
-        Fixed sample point in the space over which to compute the
+        Monte-Carlo sample points for computing the
         normalization constant. Default is None.
-    norm_int_length : float, default=None
-        A float. Must be provided if `norm_samples` is provided. The domain area
-        (or interval length in 1D) upon which we perform the numerical
-        integration. Default None
+    norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
+        The importance sampling weights for Monte-Carlo integration of the
+        normalization constant. Must be provided if `norm_samples` is provided.
+        `norm_weights` should have the same leading dimension as
+        `norm_samples`.
+        Alternatively, the user can pass a float or an integer.
+        These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
+        $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
     obs_slice : slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
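The old `norm_int_length` corresponded to a plain Monte-Carlo estimate, mean(u(x_k)) * |Omega|, over uniformly drawn samples; the new `norm_weights` generalizes this to importance sampling, mean(u(x_k) * w_k) with w_k = 1/q(x_k). With a uniform proposal q(x) = 1/|Omega| the two coincide, so migrating code can pass the old domain area as the weight. A hedged migration sketch (the domain bounds are placeholders):

    import jax
    import jax.numpy as jnp

    # Uniform proposal over Omega = [0, 1] x [0, 2], so |Omega| = 2 and
    # w_k = 1 / q(x_k) = |Omega| for every sample
    key = jax.random.PRNGKey(0)
    norm_samples = jax.random.uniform(
        key, (500, 2), minval=jnp.array([0.0, 0.0]), maxval=jnp.array([1.0, 2.0])
    )

    # jinns 1.2.0:  norm_int_length=2.0
    # jinns 1.3.0:  a bare scalar is accepted, or one weight per sample:
    norm_weights = 2.0 * jnp.ones(norm_samples.shape[0])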
@@ -127,7 +131,9 @@ class _LossPDEAbstract(eqx.Module):
     norm_samples: Float[Array, "nb_norm_samples dimension"] | None = eqx.field(
         kw_only=True, default=None
     )
-    norm_int_length: float | None = eqx.field(kw_only=True, default=None)
+    norm_weights: Float[Array, "nb_norm_samples"] | float | int | None = eqx.field(
+        kw_only=True, default=None
+    )
     obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
 
     params: InitVar[Params] = eqx.field(kw_only=True, default=None)
@@ -251,8 +257,25 @@ class _LossPDEAbstract(eqx.Module):
         if not isinstance(self.omega_boundary_dim, slice):
             raise ValueError("self.omega_boundary_dim must be a jnp.s_ object")
 
-        if self.norm_samples is not None and self.norm_int_length is None:
-            raise ValueError("self.norm_samples and norm_int_length must be provided")
+        if self.norm_samples is not None:
+            if self.norm_weights is None:
+                raise ValueError(
+                    "`norm_weights` must be provided when `norm_samples` is used!"
+                )
+            try:
+                assert self.norm_weights.shape[0] == self.norm_samples.shape[0]
+            except (AssertionError, AttributeError):
+                if isinstance(self.norm_weights, (int, float)):
+                    self.norm_weights = jnp.array(
+                        [self.norm_weights], dtype=jax.dtypes.canonicalize_dtype(float)
+                    )
+                else:
+                    raise ValueError(
+                        "`norm_weights` should have the same leading dimension"
+                        " as `norm_samples`,"
+                        f" got shape {self.norm_weights.shape} and"
+                        f" shape {self.norm_samples.shape}."
+                    )
 
     @abc.abstractmethod
     def evaluate(
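The check above is lenient on purpose: a bare int or float is promoted to a shape (1,) array (which later broadcasts against the per-sample network outputs), while an array whose leading dimension does not match `norm_samples` raises. A standalone paraphrase of that logic, for illustration only:

    import jax
    import jax.numpy as jnp

    def validate_norm_weights(norm_weights, norm_samples):
        """Paraphrase of the __post_init__ validation above."""
        if norm_weights is None:
            raise ValueError("`norm_weights` must be provided when `norm_samples` is used!")
        try:
            assert norm_weights.shape[0] == norm_samples.shape[0]
        except (AssertionError, AttributeError):
            if isinstance(norm_weights, (int, float)):
                # scalar: promote to a one-element array
                return jnp.array([norm_weights], dtype=jax.dtypes.canonicalize_dtype(float))
            raise ValueError("leading dimensions of `norm_weights` and `norm_samples` differ")
        return norm_weights

    samples = jnp.ones((100, 2))
    print(validate_norm_weights(2.0, samples))                  # [2.]
    print(validate_norm_weights(jnp.ones(100), samples).shape)  # (100,)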
@@ -324,12 +347,16 @@ class LossPDEStatio(_LossPDEAbstract):
         Note that it must be a slice and not an integer
         (but a preprocessing of the user provided argument takes care of it)
     norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
-        Fixed sample point in the space over which to compute the
+        Monte-Carlo sample points for computing the
         normalization constant. Default is None.
-    norm_int_length : float, default=None
-        A float. Must be provided if `norm_samples` is provided. The domain area
-        (or interval length in 1D) upon which we perform the numerical
-        integration. Default None
+    norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
+        The importance sampling weights for Monte-Carlo integration of the
+        normalization constant. Must be provided if `norm_samples` is provided.
+        `norm_weights` should have the same leading dimension as
+        `norm_samples`.
+        Alternatively, the user can pass a float or an integer.
+        These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
+        $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
     obs_slice : slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
@@ -431,7 +458,7 @@ class LossPDEStatio(_LossPDEAbstract):
                 self._get_normalization_loss_batch(batch),
                 _set_derivatives(params, self.derivative_keys.norm_loss),
                 vmap_in_axes_params,
-                self.norm_int_length,
+                self.norm_weights,
                 self.loss_weights.norm_loss,
             )
         else:
@@ -549,12 +576,16 @@ class LossPDENonStatio(LossPDEStatio):
         Note that it must be a slice and not an integer
         (but a preprocessing of the user provided argument takes care of it)
     norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
-        Fixed sample point in the space over which to compute the
+        Monte-Carlo sample points for computing the
        normalization constant. Default is None.
-    norm_int_length : float, default=None
-        A float. Must be provided if `norm_samples` is provided. The domain area
-        (or interval length in 1D) upon which we perform the numerical
-        integration. Default None
+    norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
+        The importance sampling weights for Monte-Carlo integration of the
+        normalization constant. Must be provided if `norm_samples` is provided.
+        `norm_weights` should have the same leading dimension as
+        `norm_samples`.
+        Alternatively, the user can pass a float or an integer.
+        These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
+        $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
     obs_slice : slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
@@ -730,15 +761,21 @@ class SystemLossPDE(eqx.Module):
         (default) then no temporal boundary condition is applied
         Must share the keys of `u_dict`
     norm_samples_dict : Dict[str, Float[Array, "nb_norm_samples dimension"] | None, default=None
-        A dict of fixed sample point in the space over which to compute the
-        normalization constant. Default is None
-        Must share the keys of `u_dict`
-    norm_int_length_dict : Dict[str, float | None] | None, default=None
-        A dict of Float. The domain area
-        (or interval length in 1D) upon which we perform the numerical
-        integration for each element of u_dict.
-        Default is None
+        A dict of Monte-Carlo sample points for computing the
+        normalization constant. Default is None.
         Must share the keys of `u_dict`
+    norm_weights_dict : Dict[str, Array[Float, "nb_norm_samples"] | float | int | None] | None, default=None
+        A dict of jnp.array with the same keys as `u_dict`. The importance
+        sampling weights for Monte-Carlo integration of the
+        normalization constant for each element of u_dict. Must be provided if
+        `norm_samples_dict` is provided.
+        `norm_weights_dict[key]` should have the same leading dimension as
+        `norm_samples_dict[key]` for each `key`.
+        Alternatively, the user can pass a float or an integer.
+        For each key, an array of similar shape to `norm_samples_dict[key]`
+        or shape `(1,)` is expected. These corresponds to the weights $w_k =
+        \frac{1}{q(x_k)}$ where $q(\cdot)$ is the proposal p.d.f. and $x_k$ are
+        the Monte-Carlo samples. Default is None
     obs_slice_dict : Dict[str, slice | None] | None, default=None
         dict of obs_slice, with keys from `u_dict` to designate the
         output(s) channels that are forced to observed values, for each
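The system version takes the same Monte-Carlo weights as a dict: one entry per unknown function, keys matching `u_dict`, and scalars or per-sample arrays may be mixed across keys. A sketch with hypothetical keys "u" and "v":

    import jax.numpy as jnp

    # hypothetical two-equation system; keys must match u_dict
    norm_samples_dict = {
        "u": jnp.ones((200, 2)),      # placeholder Monte-Carlo samples
        "v": jnp.ones((200, 2)),
    }
    norm_weights_dict = {
        "u": 4.0,                     # scalar: uniform proposal, w_k = |Omega|
        "v": jnp.full((200,), 4.0),   # or one weight per sample
    }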
@@ -774,9 +811,9 @@ class SystemLossPDE(eqx.Module):
     norm_samples_dict: Dict[str, Float[Array, "nb_norm_samples dimension"]] | None = (
         eqx.field(kw_only=True, default=None)
     )
-    norm_int_length_dict: Dict[str, float | None] | None = eqx.field(
-        kw_only=True, default=None
-    )
+    norm_weights_dict: (
+        Dict[str, Float[Array, "nb_norm_samples dimension"] | float | int | None] | None
+    ) = eqx.field(kw_only=True, default=None)
     obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
         kw_only=True, default=None, static=True
     )
@@ -819,8 +856,8 @@ class SystemLossPDE(eqx.Module):
             self.initial_condition_fun_dict = self.u_dict_with_none
         if self.norm_samples_dict is None:
             self.norm_samples_dict = self.u_dict_with_none
-        if self.norm_int_length_dict is None:
-            self.norm_int_length_dict = self.u_dict_with_none
+        if self.norm_weights_dict is None:
+            self.norm_weights_dict = self.u_dict_with_none
         if self.obs_slice_dict is None:
             self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
         if self.u_dict.keys() != self.obs_slice_dict.keys():
@@ -861,7 +898,7 @@ class SystemLossPDE(eqx.Module):
             or self.u_dict.keys() != self.omega_boundary_dim_dict.keys()
             or self.u_dict.keys() != self.initial_condition_fun_dict.keys()
             or self.u_dict.keys() != self.norm_samples_dict.keys()
-            or self.u_dict.keys() != self.norm_int_length_dict.keys()
+            or self.u_dict.keys() != self.norm_weights_dict.keys()
         ):
             raise ValueError("All the dicts concerning the PINNs should have same keys")
 
@@ -890,7 +927,7 @@ class SystemLossPDE(eqx.Module):
                     omega_boundary_condition=self.omega_boundary_condition_dict[i],
                     omega_boundary_dim=self.omega_boundary_dim_dict[i],
                     norm_samples=self.norm_samples_dict[i],
-                    norm_int_length=self.norm_int_length_dict[i],
+                    norm_weights=self.norm_weights_dict[i],
                     obs_slice=self.obs_slice_dict[i],
                 )
             elif self.u_dict[i].eq_type == "nonstatio_PDE":
@@ -911,7 +948,7 @@ class SystemLossPDE(eqx.Module):
                     omega_boundary_dim=self.omega_boundary_dim_dict[i],
                     initial_condition_fun=self.initial_condition_fun_dict[i],
                     norm_samples=self.norm_samples_dict[i],
-                    norm_int_length=self.norm_int_length_dict[i],
+                    norm_weights=self.norm_weights_dict[i],
                     obs_slice=self.obs_slice_dict[i],
                 )
             else:
@@ -1045,7 +1082,7 @@ class SystemLossPDE(eqx.Module):
             _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
             vmap_in_axes + vmap_in_axes_params,
             loss_weight,
-            u_type=type(list(self.u_dict.values())[0]),
+            u_type=list(self.u_dict.values())[0].__class__.__base__,
         )
 
         dyn_loss_mse_dict = jax.tree_util.tree_map(
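This change presumably reflects the new class hierarchy: concrete networks such as PINN_MLP (exported by jinns.nn below) now subclass the abstract PINN, so `type(u)` would be the concrete class and no longer compare equal to PINN inside `dynamic_loss_apply`; taking `__class__.__base__` recovers the abstract parent. A minimal illustration of the mechanism with stand-in classes, not the real ones:

    class PINN:            # stand-in for the abstract jinns.nn PINN
        pass

    class PINN_MLP(PINN):  # stand-in for a concrete implementation
        pass

    u = PINN_MLP()
    print(type(u) == PINN)               # False: the old lookup would now miss
    print(u.__class__.__base__ == PINN)  # True: the new lookup matches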
@@ -13,8 +13,8 @@ from jax import vmap, grad
 import equinox as eqx
 from jinns.utils._utils import get_grid, _subtract_with_check
 from jinns.data._Batchs import *
-from jinns.utils._pinn import PINN
-from jinns.utils._spinn import SPINN
+from jinns.nn._pinn import PINN
+from jinns.nn._spinn import SPINN
 
 if TYPE_CHECKING:
     from jinns.utils._types import *
jinns/loss/_loss_utils.py CHANGED
@@ -19,9 +19,9 @@ from jinns.loss._boundary_conditions import (
 from jinns.utils._utils import _subtract_with_check, get_grid
 from jinns.data._DataGenerators import append_obs_batch, make_cartesian_product
 from jinns.parameters._params import _get_vmap_in_axes_params
-from jinns.utils._pinn import PINN
-from jinns.utils._spinn import SPINN
-from jinns.utils._hyperpinn import HYPERPINN
+from jinns.nn._pinn import PINN
+from jinns.nn._spinn import SPINN
+from jinns.nn._hyperpinn import HyperPINN
 from jinns.data._Batchs import *
 from jinns.parameters._params import Params, ParamsDict
@@ -40,13 +40,13 @@ def dynamic_loss_apply(
     params: Params | ParamsDict,
     vmap_axes: tuple[int | None, ...],
     loss_weight: float | Float[Array, "dyn_loss_dimension"],
-    u_type: PINN | HYPERPINN | None = None,
+    u_type: PINN | HyperPINN | None = None,
 ) -> float:
     """
     Sometimes when u is a lambda function a or dict we do not have access to
     its type here, hence the last argument
     """
-    if u_type == PINN or u_type == HYPERPINN or isinstance(u, (PINN, HYPERPINN)):
+    if u_type == PINN or u_type == HyperPINN or isinstance(u, (PINN, HyperPINN)):
         v_dyn_loss = vmap(
             lambda batch, params: dyn_loss(
                 batch, u, params  # we must place the params at the end
@@ -75,23 +75,25 @@ def normalization_loss_apply(
     ),
     params: Params | ParamsDict,
     vmap_axes_params: tuple[int | None, ...],
-    int_length: int,
+    norm_weights: Float[Array, "nb_norm_samples"],
     loss_weight: float,
 ) -> float:
     """
     Note the squeezing on each result. We expect unidimensional *PINN since
     they represent probability distributions
     """
-    if isinstance(u, (PINN, HYPERPINN)):
+    if isinstance(u, (PINN, HyperPINN)):
         if len(batches) == 1:
             v_u = vmap(
-                lambda b: u(b)[u.slice_solution],
+                lambda *b: u(*b)[u.slice_solution],
                 (0,) + vmap_axes_params,
                 0,
             )
             res = v_u(*batches, params)
+            assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
+            # Monte-Carlo integration using importance sampling
             mse_norm_loss = loss_weight * (
-                jnp.abs(jnp.mean(res.squeeze()) * int_length - 1) ** 2
+                jnp.abs(jnp.mean(res.squeeze() * norm_weights) - 1) ** 2
             )
         else:
             # NOTE this cartesian product is costly
@@ -107,20 +109,23 @@ def normalization_loss_apply(
                 in_axes=(0,) + vmap_axes_params,
             )
             res = v_u(batches, params)
-            # Over all the times t, we perform a integration
+            assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
+            # For all times t, we perform an integration. Then we average the
+            # losses over times.
             mse_norm_loss = loss_weight * jnp.mean(
-                jnp.abs(jnp.mean(res.squeeze(), axis=-1) * int_length - 1) ** 2
+                jnp.abs(jnp.mean(res.squeeze() * norm_weights, axis=-1) - 1) ** 2
             )
     elif isinstance(u, SPINN):
         if len(batches) == 1:
             res = u(*batches, params)
+            assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
             mse_norm_loss = (
                 loss_weight
                 * jnp.abs(
                     jnp.mean(
                         res.squeeze(),
                     )
-                    * int_length
+                    * norm_weights
                     - 1
                 )
                 ** 2
@@ -134,6 +139,7 @@ def normalization_loss_apply(
                 ),
                 params,
             )
+            assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
             # the outer mean() below is for the times stamps
             mse_norm_loss = loss_weight * jnp.mean(
                 jnp.abs(
  jnp.abs(
@@ -141,7 +147,7 @@ def normalization_loss_apply(
141
147
  res.squeeze(),
142
148
  axis=(d + 1 for d in range(res.ndim - 2)),
143
149
  )
144
- * int_length
150
+ * norm_weights
145
151
  - 1
146
152
  )
147
153
  ** 2
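Across these hunks the estimator changes from mean(u(x_k)) * int_length, a plain Monte-Carlo estimate over a uniform design, to mean(u(x_k) * w_k) with importance weights w_k = 1/q(x_k); the loss then penalizes the squared deviation of the estimate from 1. A self-contained numerical sketch of the new estimator, with all names local to the example:

    import jax
    import jax.numpy as jnp

    # target: standard Gaussian pdf, which integrates to ~1 over [-6, 6]
    def u(x):
        return jnp.exp(-0.5 * x**2) / jnp.sqrt(2 * jnp.pi)

    # proposal q uniform on [-6, 6]: q(x) = 1/12, hence w_k = 12
    key = jax.random.PRNGKey(0)
    xs = jax.random.uniform(key, (10_000,), minval=-6.0, maxval=6.0)
    norm_weights = 12.0 * jnp.ones_like(xs)

    estimate = jnp.mean(u(xs) * norm_weights)  # ~1.0
    norm_loss = jnp.abs(estimate - 1.0) ** 2   # the quantity penalized above
    print(estimate, norm_loss)

For a uniform proposal this reproduces the old mean-times-length estimator exactly; a non-uniform proposal simply changes the weights.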
@@ -230,7 +236,7 @@ def observations_loss_apply(
     obs_slice: slice,
 ) -> float:
     # TODO implement for SPINN
-    if isinstance(u, (PINN, HYPERPINN)):
+    if isinstance(u, (PINN, HyperPINN)):
         v_u = vmap(
             lambda *args: u(*args)[u.slice_solution],
             vmap_axes,
@@ -264,7 +270,7 @@ def initial_condition_apply(
 ) -> float:
     n = omega_batch.shape[0]
     t0_omega_batch = jnp.concatenate([jnp.zeros((n, 1)), omega_batch], axis=1)
-    if isinstance(u, (PINN, HYPERPINN)):
+    if isinstance(u, (PINN, HyperPINN)):
         v_u_t0 = vmap(
             lambda t0_x, params: _subtract_with_check(
                 initial_condition_fun(t0_x[1:]),
jinns/loss/_operators.py CHANGED
@@ -9,8 +9,6 @@ import jax.numpy as jnp
 from jax import grad
 import equinox as eqx
 from jaxtyping import Float, Array
-from jinns.utils._pinn import PINN
-from jinns.utils._spinn import SPINN
 from jinns.parameters._params import Params
 
 
jinns/nn/__init__.py ADDED
@@ -0,0 +1,7 @@
+from ._save_load import save_pinn, load_pinn
+from ._pinn import PINN
+from ._spinn import SPINN
+from ._mlp import PINN_MLP, MLP
+from ._spinn_mlp import SPINN_MLP, SMLP
+from ._hyperpinn import HyperPINN
+from ._ppinn import PPINN_MLP
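The new subpackage gathers the network classes and the save/load helpers behind one public namespace. An import sketch of the names it exposes (the role comments are inferred from the class names and the base-class change above, not from documentation):

    from jinns.nn import (
        PINN,        # abstract PINN interface
        PINN_MLP,    # concrete MLP-based PINN
        MLP,
        SPINN,       # abstract separable PINN
        SPINN_MLP,
        SMLP,
        HyperPINN,   # renamed from HYPERPINN in 1.2.0
        PPINN_MLP,
        save_pinn,
        load_pinn,
    )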