jinns 1.5.1-py3-none-any.whl → 1.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. jinns/data/_AbstractDataGenerator.py +1 -1
  2. jinns/data/_Batchs.py +47 -13
  3. jinns/data/_CubicMeshPDENonStatio.py +55 -34
  4. jinns/data/_CubicMeshPDEStatio.py +63 -35
  5. jinns/data/_DataGeneratorODE.py +48 -22
  6. jinns/data/_DataGeneratorObservations.py +75 -32
  7. jinns/data/_DataGeneratorParameter.py +152 -101
  8. jinns/data/__init__.py +2 -1
  9. jinns/data/_utils.py +22 -10
  10. jinns/loss/_DynamicLoss.py +21 -20
  11. jinns/loss/_DynamicLossAbstract.py +51 -36
  12. jinns/loss/_LossODE.py +139 -184
  13. jinns/loss/_LossPDE.py +440 -358
  14. jinns/loss/_abstract_loss.py +60 -25
  15. jinns/loss/_loss_components.py +4 -25
  16. jinns/loss/_loss_weight_updates.py +6 -7
  17. jinns/loss/_loss_weights.py +34 -35
  18. jinns/nn/_abstract_pinn.py +0 -2
  19. jinns/nn/_hyperpinn.py +34 -23
  20. jinns/nn/_mlp.py +5 -4
  21. jinns/nn/_pinn.py +1 -16
  22. jinns/nn/_ppinn.py +5 -16
  23. jinns/nn/_save_load.py +11 -4
  24. jinns/nn/_spinn.py +1 -16
  25. jinns/nn/_spinn_mlp.py +5 -5
  26. jinns/nn/_utils.py +33 -38
  27. jinns/parameters/__init__.py +3 -1
  28. jinns/parameters/_derivative_keys.py +99 -41
  29. jinns/parameters/_params.py +50 -25
  30. jinns/solver/_solve.py +3 -3
  31. jinns/utils/_DictToModuleMeta.py +66 -0
  32. jinns/utils/_ItemizableModule.py +19 -0
  33. jinns/utils/__init__.py +2 -1
  34. jinns/utils/_types.py +25 -15
  35. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
  36. jinns-1.6.0.dist-info/RECORD +57 -0
  37. jinns-1.5.1.dist-info/RECORD +0 -55
  38. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
  39. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
  40. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
  41. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
jinns/data/_DataGeneratorODE.py CHANGED
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
  import equinox as eqx
  import jax
  import jax.numpy as jnp
- from jaxtyping import Key, Array, Float
+ from jaxtyping import PRNGKeyArray, Array, Float
  from jinns.data._Batchs import ODEBatch
  from jinns.data._utils import _check_and_set_rar_parameters, _reset_or_increment
  from jinns.data._AbstractDataGenerator import AbstractDataGenerator
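
The recurring type-annotation change in this release replaces jaxtyping's `Key` with the more explicit `PRNGKeyArray` alias in every signature that threads a JAX random key. A minimal sketch of the convention (the `split_once` helper is illustrative, not jinns API):

```python
import jax
from jaxtyping import PRNGKeyArray

def split_once(key: PRNGKeyArray) -> tuple[PRNGKeyArray, PRNGKeyArray]:
    # PRNGKeyArray annotates JAX PRNG keys, e.g. jax.random.PRNGKey(0)
    new_key, subkey = jax.random.split(key)
    return new_key, subkey

key = jax.random.PRNGKey(0)
key, subkey = split_once(key)
```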
@@ -24,7 +24,7 @@ class DataGeneratorODE(AbstractDataGenerator):
 
      Parameters
      ----------
-     key : Key
+     key : PRNGKeyArray
          Jax random key to sample new time points and to shuffle batches
      nt : int
          The number of total time points that will be divided in
@@ -42,10 +42,10 @@ class DataGeneratorODE(AbstractDataGenerator):
          The method that generates the `nt` time points. `grid` means
          regularly spaced points over the domain. `uniform` means uniformly
          sampled points over the domain
-     rar_parameters : RarParameterDict, default=None
+     rar_parameters : None | RarParameterDict, default=None
          A TypedDict to specify the Residual Adaptative Resampling procedure. See
          the docstring from RarParameterDict
-     n_start : int, default=None
+     n_start : None | int, default=None
          Defaults to None. The effective size of nt used at start time.
          This value must be
          provided when rar_parameters is not None. Otherwise we set internally
@@ -54,25 +54,43 @@ class DataGeneratorODE(AbstractDataGenerator):
          then corresponds to the initial number of points we train the PINN.
      """
 
-     key: Key = eqx.field(kw_only=True)
-     nt: int = eqx.field(kw_only=True, static=True)
-     tmin: Float = eqx.field(kw_only=True)
-     tmax: Float = eqx.field(kw_only=True)
-     temporal_batch_size: int | None = eqx.field(static=True, default=None, kw_only=True)
-     method: str = eqx.field(
-         static=True, kw_only=True, default_factory=lambda: "uniform"
-     )
-     rar_parameters: dict[str, int] = eqx.field(default=None, kw_only=True)
-     n_start: int = eqx.field(static=True, default=None, kw_only=True)
-
-     # all the init=False fields are set in __post_init__
+     key: PRNGKeyArray
+     nt: int = eqx.field(static=True)
+     tmin: float
+     tmax: float
+     temporal_batch_size: int | None = eqx.field(static=True)
+     method: str = eqx.field(static=True)
+     rar_parameters: None | dict[str, int]
+     n_start: None | int
+
+     # --- Below fields are not passed as arguments to __init__
      p: Float[Array, " nt 1"] | None = eqx.field(init=False)
      rar_iter_from_last_sampling: int | None = eqx.field(init=False)
      rar_iter_nb: int | None = eqx.field(init=False)
      curr_time_idx: int = eqx.field(init=False)
      times: Float[Array, " nt 1"] = eqx.field(init=False)
 
-     def __post_init__(self):
+     def __init__(
+         self,
+         *,
+         key: PRNGKeyArray,
+         nt: int,
+         tmin: float,
+         tmax: float,
+         temporal_batch_size: int | None,
+         method: str = "uniform",
+         rar_parameters: None | dict[str, int] = None,
+         n_start: None | int = None,
+     ):
+         self.key = key
+         self.nt = nt
+         self.tmin = tmin
+         self.tmax = tmax
+         self.temporal_batch_size = temporal_batch_size
+         self.method = method
+         self.n_start = n_start
+         self.rar_parameters = rar_parameters
+
          (
              self.n_start,
              self.p,
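
The same migration from dataclass-style `__post_init__` to an explicit keyword-only `__init__` recurs across all the data generators in this release. A minimal sketch of the pattern on a toy `eqx.Module` (the `ToyGenerator` class is illustrative, not jinns API):

```python
import equinox as eqx
import jax
from jaxtyping import Array, PRNGKeyArray

class ToyGenerator(eqx.Module):
    key: PRNGKeyArray
    n: int = eqx.field(static=True)
    # derived field: computed inside __init__, not passed by the caller
    points: Array = eqx.field(init=False)

    def __init__(self, *, key: PRNGKeyArray, n: int):
        # explicit keyword-only constructor replaces __post_init__
        self.key = key
        self.n = n
        self.points = jax.random.uniform(key, (n, 1))

gen = ToyGenerator(key=jax.random.PRNGKey(0), n=8)
```

An explicit `__init__` keeps the call signature visible (keyword-only arguments, explicit defaults) and lets derived fields be set in one place instead of splitting construction between a generated `__init__` and `__post_init__`.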
@@ -97,7 +115,7 @@ class DataGeneratorODE(AbstractDataGenerator):
          # above way for the key.
 
      def sample_in_time_domain(
-         self, key: Key, sample_size: int | None = None
+         self, key: PRNGKeyArray, sample_size: int | None = None
      ) -> Float[Array, " nt 1"]:
          return jax.random.uniform(
              key,
@@ -106,7 +124,9 @@
              maxval=self.tmax,
          )
 
-     def generate_time_data(self, key: Key) -> tuple[Key, Float[Array, " nt"]]:
+     def generate_time_data(
+         self, key: PRNGKeyArray
+     ) -> tuple[PRNGKeyArray, Float[Array, " nt"]]:
          """
          Construct a complete set of `self.nt` time points according to the
          specified `self.method`
@@ -125,7 +145,11 @@
      def _get_time_operands(
          self,
      ) -> tuple[
-         Key, Float[Array, " nt 1"], int, int | None, Float[Array, " nt 1"] | None
+         PRNGKeyArray,
+         Float[Array, " nt 1"],
+         int,
+         int | None,
+         Float[Array, " nt 1"] | None,
      ]:
          return (
              self.key,
@@ -150,7 +174,7 @@
          bend = bstart + self.temporal_batch_size
 
          # Compute the effective number of used collocation points
-         if self.rar_parameters is not None:
+         if self.rar_parameters is not None and self.n_start is not None:
              nt_eff = (
                  self.n_start
                  + self.rar_iter_nb  # type: ignore
@@ -167,7 +191,9 @@
              # handled above
          )
          new = eqx.tree_at(
-             lambda m: (m.key, m.times, m.curr_time_idx), self, new_attributes
+             lambda m: (m.key, m.times, m.curr_time_idx),  # type: ignore
+             self,
+             new_attributes,
          )
 
          # commands below are equivalent to
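
Since `eqx.Module` instances are immutable pytrees, the batching methods return a fresh generator built with `eqx.tree_at` instead of mutating `self`. A minimal sketch of that update pattern on a toy module (not jinns code):

```python
import equinox as eqx
import jax.numpy as jnp

class Counter(eqx.Module):
    idx: jnp.ndarray

c0 = Counter(idx=jnp.array(0))
# eqx.tree_at returns a copy of c0 with the selected leaf replaced
c1 = eqx.tree_at(lambda m: m.idx, c0, c0.idx + 1)
assert int(c1.idx) == 1 and int(c0.idx) == 0
```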
jinns/data/_DataGeneratorObservations.py CHANGED
@@ -8,10 +8,23 @@ from __future__ import (
  import equinox as eqx
  import jax
  import jax.numpy as jnp
- from jaxtyping import Key, Int, Array, Float
+ from typing import TYPE_CHECKING, Self
+ from jaxtyping import PRNGKeyArray, Int, Array, Float
  from jinns.data._Batchs import ObsBatchDict
  from jinns.data._utils import _reset_or_increment
  from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+ from jinns.parameters._params import EqParams
+
+ if TYPE_CHECKING:
+     # imports only used in type hints
+     InputEqParams = (
+         dict[str, Float[Array, " n_obs"]] | dict[str, Float[Array, " n_obs 1"]]
+     ) | None
+
+ # Note that the lambda functions used below are with type: ignore just
+ # because the lambda are not type annotated, but there is no proper way
+ # to do this and we should assign the lambda to a type hinted variable
+ # before hand: this is not practical, let us not get mad at this
 
 
  class DataGeneratorObservations(AbstractDataGenerator):
@@ -21,7 +34,7 @@ class DataGeneratorObservations(AbstractDataGenerator):
 
      Parameters
      ----------
-     key : Key
+     key : PRNGKeyArray
          Jax random key to shuffle batches
      obs_batch_size : int | None
          The size of the batch of randomly selected points among
@@ -39,7 +52,8 @@
      observed_eq_params : dict[str, Float[Array, " n_obs 1"]], default={}
          A dict with keys corresponding to
          the parameter name. The keys must match the keys in
-         `params["eq_params"]`. The values are jnp.array with 2 dimensions
+         `params["eq_params"]`, ie., if only some parameters are observed, other
+         keys **must still appear with None as value**. The values are jnp.array with 2 dimensions
          with values corresponding to the parameter value for which we also
          have observed_pinn_in and observed_values. Hence the first
          dimension must be aligned with observed_pinn_in and observed_values.
@@ -54,30 +68,37 @@
          arguments of `jinns.solve()`. Read `jinns.solve()` doc for more info.
      """
 
-     key: Key
+     key: PRNGKeyArray
      obs_batch_size: int | None = eqx.field(static=True)
      observed_pinn_in: Float[Array, " n_obs nb_pinn_in"]
      observed_values: Float[Array, " n_obs nb_pinn_out"]
-     observed_eq_params: dict[str, Float[Array, " n_obs 1"]] = eqx.field(
-         static=True, default_factory=lambda: {}
-     )
-     sharding_device: jax.sharding.Sharding = eqx.field(static=True, default=None)
+     observed_eq_params: eqx.Module | None
+     sharding_device: jax.sharding.Sharding | None  # = eqx.field(static=True)
 
      n: int = eqx.field(init=False, static=True)
      curr_idx: int = eqx.field(init=False)
      indices: Array = eqx.field(init=False)
 
-     def __post_init__(self):
+     def __init__(
+         self,
+         *,
+         key: PRNGKeyArray,
+         obs_batch_size: int | None = None,
+         observed_pinn_in: Float[Array, " n_obs nb_pinn_in"],
+         observed_values: Float[Array, " n_obs nb_pinn_out"],
+         observed_eq_params: InputEqParams | None = None,
+         sharding_device: jax.sharding.Sharding | None = None,
+     ) -> None:
+         super().__init__()
+         self.key = key
+         self.obs_batch_size = obs_batch_size
+         self.observed_pinn_in = observed_pinn_in
+         self.observed_values = observed_values
+
          if self.observed_pinn_in.shape[0] != self.observed_values.shape[0]:
              raise ValueError(
                  "self.observed_pinn_in and self.observed_values must have same first axis"
              )
-         for _, v in self.observed_eq_params.items():
-             if v.shape[0] != self.observed_pinn_in.shape[0]:
-                 raise ValueError(
-                     "self.observed_pinn_in and the values of"
-                     " self.observed_eq_params must have the same first axis"
-                 )
          if len(self.observed_pinn_in.shape) == 1:
              self.observed_pinn_in = self.observed_pinn_in[:, None]
          if self.observed_pinn_in.ndim > 2:
@@ -86,16 +107,32 @@
              self.observed_values = self.observed_values[:, None]
          if self.observed_values.ndim > 2:
              raise ValueError("self.observed_values must have 2 dimensions")
-         for k, v in self.observed_eq_params.items():
-             if len(v.shape) == 1:
-                 self.observed_eq_params[k] = v[:, None]
-             if len(v.shape) > 2:
-                 raise ValueError(
-                     "Each value of observed_eq_params must have 2 dimensions"
-                 )
+
+         if observed_eq_params is not None:
+             for _, v in observed_eq_params.items():
+                 if v.shape[0] != self.observed_pinn_in.shape[0]:
+                     raise ValueError(
+                         "self.observed_pinn_in and the values of"
+                         " self.observed_eq_params must have the same first axis"
+                     )
+             for k, v in observed_eq_params.items():
+                 if len(v.shape) == 1:
+                     # Reshape to add an axis for 1-d Array
+                     observed_eq_params[k] = v[:, None]
+                 if len(v.shape) > 2:
+                     raise ValueError(
+                         f"Each key of observed_eq_params must have 2"
+                         f"dimensions, key {k} had shape {v.shape}."
+                     )
+             # Convert the dict of observed parameters to the internal `EqParams`
+             # class used by Jinns.
+             self.observed_eq_params = EqParams(observed_eq_params, "EqParams")
+         else:
+             self.observed_eq_params = observed_eq_params
 
          self.n = self.observed_pinn_in.shape[0]
 
+         self.sharding_device = sharding_device
          if self.sharding_device is not None:
              self.observed_pinn_in = jax.lax.with_sharding_constraint(
                  self.observed_pinn_in, self.sharding_device
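
A hedged construction sketch matching the new keyword-only signature above (the shapes and the `"nu"` parameter name are illustrative; when `observed_eq_params` is given, it is converted internally to the `EqParams` module):

```python
import jax
import jax.numpy as jnp
from jinns.data import DataGeneratorObservations

n_obs = 100
dg_obs = DataGeneratorObservations(
    key=jax.random.PRNGKey(0),
    obs_batch_size=32,
    observed_pinn_in=jnp.linspace(0.0, 1.0, n_obs)[:, None],
    observed_values=jnp.ones((n_obs, 1)),
    # optional per-observation equation parameters, keyed like eq_params
    observed_eq_params={"nu": jnp.full((n_obs, 1), 0.5)},
)
```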
@@ -126,7 +163,9 @@ class DataGeneratorObservations(AbstractDataGenerator):
          self.key, _ = jax.random.split(self.key, 2)  # to make it equivalent to
          # the call to _reset_batch_idx_and_permute in legacy DG
 
-     def _get_operands(self) -> tuple[Key, Int[Array, " n"], int, int | None, None]:
+     def _get_operands(
+         self,
+     ) -> tuple[PRNGKeyArray, Int[Array, " n"], int, int | None, None]:
          return (
              self.key,
              self.indices,
@@ -137,17 +176,19 @@
 
      def obs_batch(
          self,
-     ) -> tuple[DataGeneratorObservations, ObsBatchDict]:
+     ) -> tuple[Self, ObsBatchDict]:
          """
          Return an update DataGeneratorObservations instance and an ObsBatchDict
          """
          if self.obs_batch_size is None or self.obs_batch_size == self.n:
              # Avoid unnecessary reshuffling
-             return self, {
-                 "pinn_in": self.observed_pinn_in,
-                 "val": self.observed_values,
-                 "eq_params": self.observed_eq_params,
-             }
+             return self, ObsBatchDict(
+                 {
+                     "pinn_in": self.observed_pinn_in,
+                     "val": self.observed_values,
+                     "eq_params": self.observed_eq_params,
+                 }
+             )
 
          new_attributes = _reset_or_increment(
              self.curr_idx + self.obs_batch_size,
@@ -157,7 +198,9 @@
              # handled above
          )
          new = eqx.tree_at(
-             lambda m: (m.key, m.indices, m.curr_idx), self, new_attributes
+             lambda m: (m.key, m.indices, m.curr_idx),  # type: ignore
+             self,
+             new_attributes,
          )
 
          minib_indices = jax.lax.dynamic_slice(
@@ -174,7 +217,7 @@
                  new.observed_values, minib_indices, unique_indices=True, axis=0
              ),
              "eq_params": jax.tree_util.tree_map(
-                 lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),
+                 lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),  # type: ignore
                  new.observed_eq_params,
              ),
          }
@@ -182,7 +225,7 @@
 
      def get_batch(
          self,
-     ) -> tuple[DataGeneratorObservations, ObsBatchDict]:
+     ) -> tuple[Self, ObsBatchDict]:
          """
          Generic method to return a batch
          """
jinns/data/_DataGeneratorParameter.py CHANGED
@@ -5,12 +5,23 @@ Define the DataGenerators modules
  from __future__ import (
      annotations,
  )  # https://docs.python.org/3/library/typing.html#constant
+ from typing import Self
  import equinox as eqx
  import jax
  import jax.numpy as jnp
- from jaxtyping import Key, Array, Float
+ from jaxtyping import PRNGKeyArray, Array, Float
  from jinns.data._utils import _reset_or_increment
  from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+ from jinns.utils._DictToModuleMeta import DictToModuleMeta
+
+
+ class DGParams(metaclass=DictToModuleMeta):
+     """
+     However, static type checkers cannot know that DGParams inherit from
+     eqx.Module and explicit casting to the latter class will be needed
+     """
+
+     pass
 
 
  class DataGeneratorParameter(AbstractDataGenerator):
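
`jinns/utils/_DictToModuleMeta.py` itself is not shown in this excerpt, but from the call sites (`DGParams(some_dict, "DGParams")`) the metaclass evidently turns a plain dict into an `eqx.Module` instance whose fields are the dict keys. A rough, hypothetical reimplementation of the idea, purely for intuition (not the actual jinns code):

```python
import equinox as eqx
import jax.numpy as jnp

class DictToModuleMetaSketch(type):
    """Hypothetical stand-in for jinns' DictToModuleMeta."""

    def __call__(cls, data: dict, name: str = "DynParams"):
        # build an eqx.Module subclass with one field per dict key,
        # then instantiate it with the dict values as its leaves
        namespace = {"__annotations__": {k: type(v) for k, v in data.items()}}
        module_cls = type(name, (eqx.Module,), namespace)
        return module_cls(**data)

class DGParamsSketch(metaclass=DictToModuleMetaSketch):
    pass

params = DGParamsSketch({"theta": jnp.ones((4, 1))})
print(params.theta.shape)  # (4, 1)
```

This also explains the docstring caveat above: the static type of `DGParams(...)` is opaque to checkers, so occasional casts to `eqx.Module` are needed.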
@@ -21,9 +32,8 @@ class DataGeneratorParameter(AbstractDataGenerator):
 
      Parameters
      ----------
-     keys : Key | dict[str, Key]
-         Jax random key to sample new time points and to shuffle batches
-         or a dict of Jax random keys with key entries from param_ranges
+     key : PRNGKeyArray
+         Jax random key to sample new time points and to shuffle batches.
      n : int
          The number of total points that will be divided in
          batches. Batches are made so that each data point is seen only
@@ -58,71 +68,86 @@ class DataGeneratorParameter(AbstractDataGenerator):
          Defaults to None.
      """
 
-     keys: Key | dict[str, Key]
+     key: PRNGKeyArray
      n: int = eqx.field(static=True)
-     param_batch_size: int | None = eqx.field(static=True, default=None)
-     param_ranges: dict[str, tuple[Float, Float]] = eqx.field(
-         static=True, default_factory=lambda: {}
-     )
-     method: str = eqx.field(static=True, default="uniform")
-     user_data: dict[str, Float[Array, " n"]] | None = eqx.field(
-         default_factory=lambda: {}
-     )
+     param_batch_size: int | None = eqx.field(static=True)
+     param_ranges: dict[str, tuple[Float, Float]] = eqx.field(static=True)
+     method: str = eqx.field(static=True)
+     user_data: dict[str, Float[Array, " n"]]
+
+     # --- Below fields are not passed as arguments to __init__
+     _all_params_keys: set[str] = eqx.field(init=False, static=True)
+     curr_param_idx: eqx.Module | None = eqx.field(init=False)
+     param_n_samples: eqx.Module = eqx.field(init=False)
 
-     curr_param_idx: dict[str, int] = eqx.field(init=False)
-     param_n_samples: dict[str, Array] = eqx.field(init=False)
+     def __init__(
+         self,
+         *,
+         key: PRNGKeyArray,
+         n: int,
+         param_batch_size: int | None,
+         param_ranges: dict[str, tuple[Float, Float]] = {},
+         method: str = "uniform",
+         user_data: dict[str, Float[Array, " n"]] = {},
+     ):
+         self.key = key
+         self.n = n
+         self.param_batch_size = param_batch_size
+         self.param_ranges = param_ranges
+         self.method = method
+         self.user_data = user_data
+
+         _all_keys = set().union(self.param_ranges, self.user_data)
+         self._all_params_keys = _all_keys
 
-     def __post_init__(self):
-         if self.user_data is None:
-             self.user_data = {}
-         if self.param_ranges is None:
-             self.param_ranges = {}
          if self.param_batch_size is not None and self.n < self.param_batch_size:
              raise ValueError(
                  f"Number of data points ({self.n}) is smaller than the"
                  f"number of batch points ({self.param_batch_size})."
              )
-         if not isinstance(self.keys, dict):
-             all_keys = set().union(self.param_ranges, self.user_data)
-             self.keys = dict(zip(all_keys, jax.random.split(self.keys, len(all_keys))))
+
+         # NOTE from jinns > v1.5.1 we work with eqx.Module
+         # because eq_params is not a dict anymore.
+         # We have to use a different class from the publicly exposed EqParams
+         # because fields(EqParams) are not necessarily all present in the
+         # datagenerator, which would cause eqx.Module to error.
+
+         # 1) Call self.generate_data() to generate a dictionnary that merges the scattered data between `user_data` and `param_ranges`
+         self.key, _param_n_samples = self.generate_data(self.key)
+
+         # 2) Use the dictionnary to populate the field of the eqx.Module.
+         self.param_n_samples = DGParams(_param_n_samples, "DGParams")
 
          if self.param_batch_size is None:
-             self.curr_param_idx = None  # type: ignore
+             self.curr_param_idx = None
          else:
-             self.curr_param_idx = {}
-             for k in self.keys.keys():
-                 self.curr_param_idx[k] = self.n + self.param_batch_size
-                 # to be sure there is a shuffling at first get_batch()
+             curr_idx = self.n + self.param_batch_size
+             param_keys_and_curr_idx = {k: curr_idx for k in self._all_params_keys}
 
-         # The call to self.generate_data() creates
-         # the dict self.param_n_samples and then we will only use this one
-         # because it merges the scattered data between `user_data` and
-         # `param_ranges`
-         self.keys, self.param_n_samples = self.generate_data(self.keys)
+             self.curr_param_idx = DGParams(param_keys_and_curr_idx)
 
      def generate_data(
-         self, keys: dict[str, Key]
-     ) -> tuple[dict[str, Key], dict[str, Float[Array, " n"]]]:
+         self, key: PRNGKeyArray
+     ) -> tuple[PRNGKeyArray, dict[str, Float[Array, " n 1"]]]:
          """
          Generate parameter samples, either through generation
          or using user-provided data.
          """
          param_n_samples = {}
 
-         all_keys = set().union(
-             self.param_ranges,
-             self.user_data,  # type: ignore this has been handled in post_init
-         )
-         for k in all_keys:
+         # Some of the subkeys might not be used cause of user-provided data.
+         # This is not a big deal and simpler like that.
+         key, *subkeys = jax.random.split(key, len(self._all_params_keys) + 1)
+         for i, k in enumerate(self._all_params_keys):
              if self.user_data and k in self.user_data.keys():
-                 if self.user_data[k].shape == (self.n, 1):
-                     param_n_samples[k] = self.user_data[k]
-                 if self.user_data[k].shape == (self.n,):
-                     param_n_samples[k] = self.user_data[k][:, None]
-                 else:
-                     raise ValueError(
+                 try:
+                     param_n_samples[k] = self.user_data[k].reshape((self.n, 1))
+                 except TypeError:
+                     shape = self.user_data[k].shape
+                     raise TypeError(
                          "Wrong shape for user provided parameters"
-                         f" in user_data dictionary at key='{k}'"
+                         f" in user_data dictionary at key='{k}' got {shape} "
+                         f"and expected {(self.n, 1)}."
                      )
              else:
                  if self.method == "grid":
@@ -132,75 +157,101 @@
                      param_n_samples[k] = jnp.arange(xmin, xmax, partial)[:, None]
                  elif self.method == "uniform":
                      xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
-                     keys[k], subkey = jax.random.split(keys[k], 2)
                      param_n_samples[k] = jax.random.uniform(
-                         subkey, shape=(self.n, 1), minval=xmin, maxval=xmax
+                         subkeys[i], shape=(self.n, 1), minval=xmin, maxval=xmax
                      )
                  else:
                      raise ValueError("Method " + self.method + " is not implemented.")
 
-         return keys, param_n_samples
-
-     def _get_param_operands(
-         self, k: str
-     ) -> tuple[Key, Float[Array, " n"], int, int | None, None]:
-         return (
-             self.keys[k],
-             self.param_n_samples[k],
-             self.curr_param_idx[k],
-             self.param_batch_size,
-             None,
-         )
+         return key, param_n_samples
 
-     def param_batch(self):
+     def param_batch(self) -> tuple[Self, eqx.Module]:
          """
-         Return a dictionary with batches of parameters
-         If all the batches have been seen, we reshuffle them,
-         otherwise we just return the next unseen batch.
+         Return an `eqx.Module` with batches of parameters at its leafs.
+         If all the batches have been seen, we reshuffle them (or rather
+         their indices), otherwise we just return the next unseen batch.
          """
 
          if self.param_batch_size is None or self.param_batch_size == self.n:
+             # Full batch mode: nothing to do.
              return self, self.param_n_samples
+         else:
+
+             def _reset_or_increment_wrapper(
+                 param_k: Array, idx_k: int, key_k: PRNGKeyArray
+             ):
+                 everything_but_key = _reset_or_increment(
+                     idx_k + self.param_batch_size,  # type: ignore
+                     self.n,
+                     (key_k, param_k, idx_k, self.param_batch_size, None),  # type: ignore
+                 )[1:]
+                 return everything_but_key
 
-         def _reset_or_increment_wrapper(param_k, idx_k, key_k):
-             return _reset_or_increment(
-                 idx_k + self.param_batch_size,
-                 self.n,
-                 (key_k, param_k, idx_k, self.param_batch_size, None),  # type: ignore
-                 # ignore since the case self.param_batch_size is None has been
-                 # handled above
+             new_key, *subkeys = jax.random.split(
+                 self.key, len(self._all_params_keys) + 1
              )
 
-         res = jax.tree_util.tree_map(
-             _reset_or_increment_wrapper,
-             self.param_n_samples,
-             self.curr_param_idx,
-             self.keys,
-         )
-         # we must transpose the pytrees because keys are merged in res
-         # https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#transposing-trees
-         new_attributes = jax.tree_util.tree_transpose(
-             jax.tree_util.tree_structure(self.keys),
-             jax.tree_util.tree_structure([0, 0, 0]),
-             res,
-         )
-
-         new = eqx.tree_at(
-             lambda m: (m.keys, m.param_n_samples, m.curr_param_idx),
-             self,
-             new_attributes,
-         )
-
-         return new, jax.tree_util.tree_map(
-             lambda p, q: jax.lax.dynamic_slice(
-                 p, start_indices=(q, 0), slice_sizes=(new.param_batch_size, 1)
-             ),
-             new.param_n_samples,
-             new.curr_param_idx,
-         )
-
-     def get_batch(self):
+             # From PRNGKeyArray to a pytree of keys with adequate structure
+             subkeys = jax.tree.unflatten(
+                 jax.tree.structure(self.param_n_samples), subkeys
+             )
+
+             res = jax.tree.map(
+                 _reset_or_increment_wrapper,
+                 self.param_n_samples,
+                 self.curr_param_idx,
+                 subkeys,
+             )
+             # we must transpose the pytrees because both params and curr_idx
+             # are merged in res
+             # https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#transposing-trees
+
+             new_attributes = jax.tree.transpose(
+                 jax.tree.structure(self.param_n_samples),
+                 jax.tree.structure([0, 0]),
+                 res,
+             )
+
+             new = eqx.tree_at(
+                 lambda m: (m.param_n_samples, m.curr_param_idx),
+                 self,
+                 new_attributes,
+             )
+
+             new = eqx.tree_at(lambda m: m.key, new, new_key)
+
+             return new, jax.tree_util.tree_map(
+                 lambda p, q: jax.lax.dynamic_slice(
+                     p, start_indices=(q, 0), slice_sizes=(new.param_batch_size, 1)
+                 ),
+                 new.param_n_samples,
+                 new.curr_param_idx,
+             )
+
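
For intuition on the `jax.tree.transpose` step above: `res` maps each parameter name to a `[new_samples, new_idx]` pair, and transposing swaps the nesting so the result is one `[all_new_samples, all_new_idx]` pair that can feed `eqx.tree_at` directly. A minimal sketch:

```python
import jax

outer = jax.tree.structure({"nu": 0, "theta": 0})
inner = jax.tree.structure([0, 0])
res = {"nu": [1, 10], "theta": [2, 20]}
print(jax.tree.transpose(outer, inner, res))
# -> [{'nu': 1, 'theta': 2}, {'nu': 10, 'theta': 20}]
```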
+     def get_batch(self) -> tuple[Self, eqx.Module]:
          """
          Generic method to return a batch
          """
          return self.param_batch()
+
+
+ if __name__ == "__main__":
+     key = jax.random.PRNGKey(2)
+     key, subkey = jax.random.split(key)
+
+     n = 64
+     param_batch_size = 32
+     method = "uniform"
+     param_ranges = {"theta": (10.0, 11.0)}
+     user_data = {"nu": jnp.ones((n, 1))}
+
+     x = DataGeneratorParameter(
+         key=subkey,
+         n=n,
+         param_batch_size=param_batch_size,
+         param_ranges=param_ranges,
+         method=method,
+         user_data=user_data,
+     )
+     print(x.key)
+     x, batch = x.get_batch()
+     print(x.key)
jinns/data/__init__.py CHANGED
@@ -2,7 +2,7 @@ from ._DataGeneratorODE import DataGeneratorODE
  from ._CubicMeshPDEStatio import CubicMeshPDEStatio
  from ._CubicMeshPDENonStatio import CubicMeshPDENonStatio
  from ._DataGeneratorObservations import DataGeneratorObservations
- from ._DataGeneratorParameter import DataGeneratorParameter
+ from ._DataGeneratorParameter import DataGeneratorParameter, DGParams
  from ._Batchs import ODEBatch, PDEStatioBatch, PDENonStatioBatch
 
  from ._utils import append_obs_batch, append_param_batch
@@ -12,6 +12,7 @@ __all__ = [
      "CubicMeshPDEStatio",
      "CubicMeshPDENonStatio",
      "DataGeneratorParameter",
+     "DGParams",
      "DataGeneratorObservations",
      "ODEBatch",
      "PDEStatioBatch",