jinns 1.5.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. jinns/__init__.py +7 -7
  2. jinns/data/_AbstractDataGenerator.py +1 -1
  3. jinns/data/_Batchs.py +47 -13
  4. jinns/data/_CubicMeshPDENonStatio.py +203 -54
  5. jinns/data/_CubicMeshPDEStatio.py +190 -54
  6. jinns/data/_DataGeneratorODE.py +48 -22
  7. jinns/data/_DataGeneratorObservations.py +75 -32
  8. jinns/data/_DataGeneratorParameter.py +152 -101
  9. jinns/data/__init__.py +2 -1
  10. jinns/data/_utils.py +22 -10
  11. jinns/loss/_DynamicLoss.py +21 -20
  12. jinns/loss/_DynamicLossAbstract.py +51 -36
  13. jinns/loss/_LossODE.py +210 -191
  14. jinns/loss/_LossPDE.py +441 -368
  15. jinns/loss/_abstract_loss.py +60 -25
  16. jinns/loss/_loss_components.py +4 -25
  17. jinns/loss/_loss_utils.py +23 -0
  18. jinns/loss/_loss_weight_updates.py +6 -7
  19. jinns/loss/_loss_weights.py +34 -35
  20. jinns/nn/_abstract_pinn.py +0 -2
  21. jinns/nn/_hyperpinn.py +34 -23
  22. jinns/nn/_mlp.py +5 -4
  23. jinns/nn/_pinn.py +1 -16
  24. jinns/nn/_ppinn.py +5 -16
  25. jinns/nn/_save_load.py +11 -4
  26. jinns/nn/_spinn.py +1 -16
  27. jinns/nn/_spinn_mlp.py +5 -5
  28. jinns/nn/_utils.py +33 -38
  29. jinns/parameters/__init__.py +3 -1
  30. jinns/parameters/_derivative_keys.py +99 -41
  31. jinns/parameters/_params.py +58 -25
  32. jinns/solver/_solve.py +14 -8
  33. jinns/utils/_DictToModuleMeta.py +66 -0
  34. jinns/utils/_ItemizableModule.py +19 -0
  35. jinns/utils/__init__.py +2 -1
  36. jinns/utils/_types.py +25 -15
  37. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
  38. jinns-1.6.0.dist-info/RECORD +57 -0
  39. jinns-1.5.0.dist-info/RECORD +0 -55
  40. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
  41. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
  42. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
  43. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
@@ -8,10 +8,23 @@ from __future__ import (
8
8
  import equinox as eqx
9
9
  import jax
10
10
  import jax.numpy as jnp
11
- from jaxtyping import Key, Int, Array, Float
11
+ from typing import TYPE_CHECKING, Self
12
+ from jaxtyping import PRNGKeyArray, Int, Array, Float
12
13
  from jinns.data._Batchs import ObsBatchDict
13
14
  from jinns.data._utils import _reset_or_increment
14
15
  from jinns.data._AbstractDataGenerator import AbstractDataGenerator
16
+ from jinns.parameters._params import EqParams
17
+
18
+ if TYPE_CHECKING:
19
+ # imports only used in type hints
20
+ InputEqParams = (
21
+ dict[str, Float[Array, " n_obs"]] | dict[str, Float[Array, " n_obs 1"]]
22
+ ) | None
23
+
24
+ # Note that the lambda functions used below are with type: ignore just
25
+ # because the lambda are not type annotated, but there is no proper way
26
+ # to do this and we should assign the lambda to a type hinted variable
27
+ # before hand: this is not practical, let us not get mad at this
15
28
 
16
29
 
17
30
  class DataGeneratorObservations(AbstractDataGenerator):
@@ -21,7 +34,7 @@ class DataGeneratorObservations(AbstractDataGenerator):
21
34
 
22
35
  Parameters
23
36
  ----------
24
- key : Key
37
+ key : PRNGKeyArray
25
38
  Jax random key to shuffle batches
26
39
  obs_batch_size : int | None
27
40
  The size of the batch of randomly selected points among
@@ -39,7 +52,8 @@ class DataGeneratorObservations(AbstractDataGenerator):
39
52
  observed_eq_params : dict[str, Float[Array, " n_obs 1"]], default={}
40
53
  A dict with keys corresponding to
41
54
  the parameter name. The keys must match the keys in
42
- `params["eq_params"]`. The values are jnp.array with 2 dimensions
55
+ `params["eq_params"]`, ie., if only some parameters are observed, other
56
+ keys **must still appear with None as value**. The values are jnp.array with 2 dimensions
43
57
  with values corresponding to the parameter value for which we also
44
58
  have observed_pinn_in and observed_values. Hence the first
45
59
  dimension must be aligned with observed_pinn_in and observed_values.
@@ -54,30 +68,37 @@ class DataGeneratorObservations(AbstractDataGenerator):
54
68
  arguments of `jinns.solve()`. Read `jinns.solve()` doc for more info.
55
69
  """
56
70
 
57
- key: Key
71
+ key: PRNGKeyArray
58
72
  obs_batch_size: int | None = eqx.field(static=True)
59
73
  observed_pinn_in: Float[Array, " n_obs nb_pinn_in"]
60
74
  observed_values: Float[Array, " n_obs nb_pinn_out"]
61
- observed_eq_params: dict[str, Float[Array, " n_obs 1"]] = eqx.field(
62
- static=True, default_factory=lambda: {}
63
- )
64
- sharding_device: jax.sharding.Sharding = eqx.field(static=True, default=None)
75
+ observed_eq_params: eqx.Module | None
76
+ sharding_device: jax.sharding.Sharding | None # = eqx.field(static=True)
65
77
 
66
78
  n: int = eqx.field(init=False, static=True)
67
79
  curr_idx: int = eqx.field(init=False)
68
80
  indices: Array = eqx.field(init=False)
69
81
 
70
- def __post_init__(self):
82
+ def __init__(
83
+ self,
84
+ *,
85
+ key: PRNGKeyArray,
86
+ obs_batch_size: int | None = None,
87
+ observed_pinn_in: Float[Array, " n_obs nb_pinn_in"],
88
+ observed_values: Float[Array, " n_obs nb_pinn_out"],
89
+ observed_eq_params: InputEqParams | None = None,
90
+ sharding_device: jax.sharding.Sharding | None = None,
91
+ ) -> None:
92
+ super().__init__()
93
+ self.key = key
94
+ self.obs_batch_size = obs_batch_size
95
+ self.observed_pinn_in = observed_pinn_in
96
+ self.observed_values = observed_values
97
+
71
98
  if self.observed_pinn_in.shape[0] != self.observed_values.shape[0]:
72
99
  raise ValueError(
73
100
  "self.observed_pinn_in and self.observed_values must have same first axis"
74
101
  )
75
- for _, v in self.observed_eq_params.items():
76
- if v.shape[0] != self.observed_pinn_in.shape[0]:
77
- raise ValueError(
78
- "self.observed_pinn_in and the values of"
79
- " self.observed_eq_params must have the same first axis"
80
- )
81
102
  if len(self.observed_pinn_in.shape) == 1:
82
103
  self.observed_pinn_in = self.observed_pinn_in[:, None]
83
104
  if self.observed_pinn_in.ndim > 2:
@@ -86,16 +107,32 @@ class DataGeneratorObservations(AbstractDataGenerator):
86
107
  self.observed_values = self.observed_values[:, None]
87
108
  if self.observed_values.ndim > 2:
88
109
  raise ValueError("self.observed_values must have 2 dimensions")
89
- for k, v in self.observed_eq_params.items():
90
- if len(v.shape) == 1:
91
- self.observed_eq_params[k] = v[:, None]
92
- if len(v.shape) > 2:
93
- raise ValueError(
94
- "Each value of observed_eq_params must have 2 dimensions"
95
- )
110
+
111
+ if observed_eq_params is not None:
112
+ for _, v in observed_eq_params.items():
113
+ if v.shape[0] != self.observed_pinn_in.shape[0]:
114
+ raise ValueError(
115
+ "self.observed_pinn_in and the values of"
116
+ " self.observed_eq_params must have the same first axis"
117
+ )
118
+ for k, v in observed_eq_params.items():
119
+ if len(v.shape) == 1:
120
+ # Reshape to add an axis for 1-d Array
121
+ observed_eq_params[k] = v[:, None]
122
+ if len(v.shape) > 2:
123
+ raise ValueError(
124
+ f"Each key of observed_eq_params must have 2"
125
+ f"dimensions, key {k} had shape {v.shape}."
126
+ )
127
+ # Convert the dict of observed parameters to the internal `EqParams`
128
+ # class used by Jinns.
129
+ self.observed_eq_params = EqParams(observed_eq_params, "EqParams")
130
+ else:
131
+ self.observed_eq_params = observed_eq_params
96
132
 
97
133
  self.n = self.observed_pinn_in.shape[0]
98
134
 
135
+ self.sharding_device = sharding_device
99
136
  if self.sharding_device is not None:
100
137
  self.observed_pinn_in = jax.lax.with_sharding_constraint(
101
138
  self.observed_pinn_in, self.sharding_device
@@ -126,7 +163,9 @@ class DataGeneratorObservations(AbstractDataGenerator):
126
163
  self.key, _ = jax.random.split(self.key, 2) # to make it equivalent to
127
164
  # the call to _reset_batch_idx_and_permute in legacy DG
128
165
 
129
- def _get_operands(self) -> tuple[Key, Int[Array, " n"], int, int | None, None]:
166
+ def _get_operands(
167
+ self,
168
+ ) -> tuple[PRNGKeyArray, Int[Array, " n"], int, int | None, None]:
130
169
  return (
131
170
  self.key,
132
171
  self.indices,
@@ -137,17 +176,19 @@ class DataGeneratorObservations(AbstractDataGenerator):
137
176
 
138
177
  def obs_batch(
139
178
  self,
140
- ) -> tuple[DataGeneratorObservations, ObsBatchDict]:
179
+ ) -> tuple[Self, ObsBatchDict]:
141
180
  """
142
181
  Return an update DataGeneratorObservations instance and an ObsBatchDict
143
182
  """
144
183
  if self.obs_batch_size is None or self.obs_batch_size == self.n:
145
184
  # Avoid unnecessary reshuffling
146
- return self, {
147
- "pinn_in": self.observed_pinn_in,
148
- "val": self.observed_values,
149
- "eq_params": self.observed_eq_params,
150
- }
185
+ return self, ObsBatchDict(
186
+ {
187
+ "pinn_in": self.observed_pinn_in,
188
+ "val": self.observed_values,
189
+ "eq_params": self.observed_eq_params,
190
+ }
191
+ )
151
192
 
152
193
  new_attributes = _reset_or_increment(
153
194
  self.curr_idx + self.obs_batch_size,
@@ -157,7 +198,9 @@ class DataGeneratorObservations(AbstractDataGenerator):
157
198
  # handled above
158
199
  )
159
200
  new = eqx.tree_at(
160
- lambda m: (m.key, m.indices, m.curr_idx), self, new_attributes
201
+ lambda m: (m.key, m.indices, m.curr_idx), # type: ignore
202
+ self,
203
+ new_attributes,
161
204
  )
162
205
 
163
206
  minib_indices = jax.lax.dynamic_slice(
@@ -174,7 +217,7 @@ class DataGeneratorObservations(AbstractDataGenerator):
174
217
  new.observed_values, minib_indices, unique_indices=True, axis=0
175
218
  ),
176
219
  "eq_params": jax.tree_util.tree_map(
177
- lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),
220
+ lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0), # type: ignore
178
221
  new.observed_eq_params,
179
222
  ),
180
223
  }
@@ -182,7 +225,7 @@ class DataGeneratorObservations(AbstractDataGenerator):
182
225
 
183
226
  def get_batch(
184
227
  self,
185
- ) -> tuple[DataGeneratorObservations, ObsBatchDict]:
228
+ ) -> tuple[Self, ObsBatchDict]:
186
229
  """
187
230
  Generic method to return a batch
188
231
  """
@@ -5,12 +5,23 @@ Define the DataGenerators modules
5
5
  from __future__ import (
6
6
  annotations,
7
7
  ) # https://docs.python.org/3/library/typing.html#constant
8
+ from typing import Self
8
9
  import equinox as eqx
9
10
  import jax
10
11
  import jax.numpy as jnp
11
- from jaxtyping import Key, Array, Float
12
+ from jaxtyping import PRNGKeyArray, Array, Float
12
13
  from jinns.data._utils import _reset_or_increment
13
14
  from jinns.data._AbstractDataGenerator import AbstractDataGenerator
15
+ from jinns.utils._DictToModuleMeta import DictToModuleMeta
16
+
17
+
18
+ class DGParams(metaclass=DictToModuleMeta):
19
+ """
20
+ However, static type checkers cannot know that DGParams inherit from
21
+ eqx.Module and explicit casting to the latter class will be needed
22
+ """
23
+
24
+ pass
14
25
 
15
26
 
16
27
  class DataGeneratorParameter(AbstractDataGenerator):
@@ -21,9 +32,8 @@ class DataGeneratorParameter(AbstractDataGenerator):
21
32
 
22
33
  Parameters
23
34
  ----------
24
- keys : Key | dict[str, Key]
25
- Jax random key to sample new time points and to shuffle batches
26
- or a dict of Jax random keys with key entries from param_ranges
35
+ key : PRNGKeyArray
36
+ Jax random key to sample new time points and to shuffle batches.
27
37
  n : int
28
38
  The number of total points that will be divided in
29
39
  batches. Batches are made so that each data point is seen only
@@ -58,71 +68,86 @@ class DataGeneratorParameter(AbstractDataGenerator):
58
68
  Defaults to None.
59
69
  """
60
70
 
61
- keys: Key | dict[str, Key]
71
+ key: PRNGKeyArray
62
72
  n: int = eqx.field(static=True)
63
- param_batch_size: int | None = eqx.field(static=True, default=None)
64
- param_ranges: dict[str, tuple[Float, Float]] = eqx.field(
65
- static=True, default_factory=lambda: {}
66
- )
67
- method: str = eqx.field(static=True, default="uniform")
68
- user_data: dict[str, Float[Array, " n"]] | None = eqx.field(
69
- default_factory=lambda: {}
70
- )
73
+ param_batch_size: int | None = eqx.field(static=True)
74
+ param_ranges: dict[str, tuple[Float, Float]] = eqx.field(static=True)
75
+ method: str = eqx.field(static=True)
76
+ user_data: dict[str, Float[Array, " n"]]
77
+
78
+ # --- Below fields are not passed as arguments to __init__
79
+ _all_params_keys: set[str] = eqx.field(init=False, static=True)
80
+ curr_param_idx: eqx.Module | None = eqx.field(init=False)
81
+ param_n_samples: eqx.Module = eqx.field(init=False)
71
82
 
72
- curr_param_idx: dict[str, int] = eqx.field(init=False)
73
- param_n_samples: dict[str, Array] = eqx.field(init=False)
83
+ def __init__(
84
+ self,
85
+ *,
86
+ key: PRNGKeyArray,
87
+ n: int,
88
+ param_batch_size: int | None,
89
+ param_ranges: dict[str, tuple[Float, Float]] = {},
90
+ method: str = "uniform",
91
+ user_data: dict[str, Float[Array, " n"]] = {},
92
+ ):
93
+ self.key = key
94
+ self.n = n
95
+ self.param_batch_size = param_batch_size
96
+ self.param_ranges = param_ranges
97
+ self.method = method
98
+ self.user_data = user_data
99
+
100
+ _all_keys = set().union(self.param_ranges, self.user_data)
101
+ self._all_params_keys = _all_keys
74
102
 
75
- def __post_init__(self):
76
- if self.user_data is None:
77
- self.user_data = {}
78
- if self.param_ranges is None:
79
- self.param_ranges = {}
80
103
  if self.param_batch_size is not None and self.n < self.param_batch_size:
81
104
  raise ValueError(
82
105
  f"Number of data points ({self.n}) is smaller than the"
83
106
  f"number of batch points ({self.param_batch_size})."
84
107
  )
85
- if not isinstance(self.keys, dict):
86
- all_keys = set().union(self.param_ranges, self.user_data)
87
- self.keys = dict(zip(all_keys, jax.random.split(self.keys, len(all_keys))))
108
+
109
+ # NOTE from jinns > v1.5.1 we work with eqx.Module
110
+ # because eq_params is not a dict anymore.
111
+ # We have to use a different class from the publicly exposed EqParams
112
+ # because fields(EqParams) are not necessarily all present in the
113
+ # datagenerator, which would cause eqx.Module to error.
114
+
115
+ # 1) Call self.generate_data() to generate a dictionnary that merges the scattered data between `user_data` and `param_ranges`
116
+ self.key, _param_n_samples = self.generate_data(self.key)
117
+
118
+ # 2) Use the dictionnary to populate the field of the eqx.Module.
119
+ self.param_n_samples = DGParams(_param_n_samples, "DGParams")
88
120
 
89
121
  if self.param_batch_size is None:
90
- self.curr_param_idx = None # type: ignore
122
+ self.curr_param_idx = None
91
123
  else:
92
- self.curr_param_idx = {}
93
- for k in self.keys.keys():
94
- self.curr_param_idx[k] = self.n + self.param_batch_size
95
- # to be sure there is a shuffling at first get_batch()
124
+ curr_idx = self.n + self.param_batch_size
125
+ param_keys_and_curr_idx = {k: curr_idx for k in self._all_params_keys}
96
126
 
97
- # The call to self.generate_data() creates
98
- # the dict self.param_n_samples and then we will only use this one
99
- # because it merges the scattered data between `user_data` and
100
- # `param_ranges`
101
- self.keys, self.param_n_samples = self.generate_data(self.keys)
127
+ self.curr_param_idx = DGParams(param_keys_and_curr_idx)
102
128
 
103
129
  def generate_data(
104
- self, keys: dict[str, Key]
105
- ) -> tuple[dict[str, Key], dict[str, Float[Array, " n"]]]:
130
+ self, key: PRNGKeyArray
131
+ ) -> tuple[PRNGKeyArray, dict[str, Float[Array, " n 1"]]]:
106
132
  """
107
133
  Generate parameter samples, either through generation
108
134
  or using user-provided data.
109
135
  """
110
136
  param_n_samples = {}
111
137
 
112
- all_keys = set().union(
113
- self.param_ranges,
114
- self.user_data, # type: ignore this has been handled in post_init
115
- )
116
- for k in all_keys:
138
+ # Some of the subkeys might not be used cause of user-provided data.
139
+ # This is not a big deal and simpler like that.
140
+ key, *subkeys = jax.random.split(key, len(self._all_params_keys) + 1)
141
+ for i, k in enumerate(self._all_params_keys):
117
142
  if self.user_data and k in self.user_data.keys():
118
- if self.user_data[k].shape == (self.n, 1):
119
- param_n_samples[k] = self.user_data[k]
120
- if self.user_data[k].shape == (self.n,):
121
- param_n_samples[k] = self.user_data[k][:, None]
122
- else:
123
- raise ValueError(
143
+ try:
144
+ param_n_samples[k] = self.user_data[k].reshape((self.n, 1))
145
+ except TypeError:
146
+ shape = self.user_data[k].shape
147
+ raise TypeError(
124
148
  "Wrong shape for user provided parameters"
125
- f" in user_data dictionary at key='{k}'"
149
+ f" in user_data dictionary at key='{k}' got {shape} "
150
+ f"and expected {(self.n, 1)}."
126
151
  )
127
152
  else:
128
153
  if self.method == "grid":
@@ -132,75 +157,101 @@ class DataGeneratorParameter(AbstractDataGenerator):
132
157
  param_n_samples[k] = jnp.arange(xmin, xmax, partial)[:, None]
133
158
  elif self.method == "uniform":
134
159
  xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
135
- keys[k], subkey = jax.random.split(keys[k], 2)
136
160
  param_n_samples[k] = jax.random.uniform(
137
- subkey, shape=(self.n, 1), minval=xmin, maxval=xmax
161
+ subkeys[i], shape=(self.n, 1), minval=xmin, maxval=xmax
138
162
  )
139
163
  else:
140
164
  raise ValueError("Method " + self.method + " is not implemented.")
141
165
 
142
- return keys, param_n_samples
143
-
144
- def _get_param_operands(
145
- self, k: str
146
- ) -> tuple[Key, Float[Array, " n"], int, int | None, None]:
147
- return (
148
- self.keys[k],
149
- self.param_n_samples[k],
150
- self.curr_param_idx[k],
151
- self.param_batch_size,
152
- None,
153
- )
166
+ return key, param_n_samples
154
167
 
155
- def param_batch(self):
168
+ def param_batch(self) -> tuple[Self, eqx.Module]:
156
169
  """
157
- Return a dictionary with batches of parameters
158
- If all the batches have been seen, we reshuffle them,
159
- otherwise we just return the next unseen batch.
170
+ Return an `eqx.Module` with batches of parameters at its leafs.
171
+ If all the batches have been seen, we reshuffle them (or rather
172
+ their indices), otherwise we just return the next unseen batch.
160
173
  """
161
174
 
162
175
  if self.param_batch_size is None or self.param_batch_size == self.n:
176
+ # Full batch mode: nothing to do.
163
177
  return self, self.param_n_samples
178
+ else:
179
+
180
+ def _reset_or_increment_wrapper(
181
+ param_k: Array, idx_k: int, key_k: PRNGKeyArray
182
+ ):
183
+ everything_but_key = _reset_or_increment(
184
+ idx_k + self.param_batch_size, # type: ignore
185
+ self.n,
186
+ (key_k, param_k, idx_k, self.param_batch_size, None), # type: ignore
187
+ )[1:]
188
+ return everything_but_key
164
189
 
165
- def _reset_or_increment_wrapper(param_k, idx_k, key_k):
166
- return _reset_or_increment(
167
- idx_k + self.param_batch_size,
168
- self.n,
169
- (key_k, param_k, idx_k, self.param_batch_size, None), # type: ignore
170
- # ignore since the case self.param_batch_size is None has been
171
- # handled above
190
+ new_key, *subkeys = jax.random.split(
191
+ self.key, len(self._all_params_keys) + 1
172
192
  )
173
193
 
174
- res = jax.tree_util.tree_map(
175
- _reset_or_increment_wrapper,
176
- self.param_n_samples,
177
- self.curr_param_idx,
178
- self.keys,
179
- )
180
- # we must transpose the pytrees because keys are merged in res
181
- # https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#transposing-trees
182
- new_attributes = jax.tree_util.tree_transpose(
183
- jax.tree_util.tree_structure(self.keys),
184
- jax.tree_util.tree_structure([0, 0, 0]),
185
- res,
186
- )
187
-
188
- new = eqx.tree_at(
189
- lambda m: (m.keys, m.param_n_samples, m.curr_param_idx),
190
- self,
191
- new_attributes,
192
- )
193
-
194
- return new, jax.tree_util.tree_map(
195
- lambda p, q: jax.lax.dynamic_slice(
196
- p, start_indices=(q, 0), slice_sizes=(new.param_batch_size, 1)
197
- ),
198
- new.param_n_samples,
199
- new.curr_param_idx,
200
- )
201
-
202
- def get_batch(self):
194
+ # From PRNGKeyArray to a pytree of keys with adequate structure
195
+ subkeys = jax.tree.unflatten(
196
+ jax.tree.structure(self.param_n_samples), subkeys
197
+ )
198
+
199
+ res = jax.tree.map(
200
+ _reset_or_increment_wrapper,
201
+ self.param_n_samples,
202
+ self.curr_param_idx,
203
+ subkeys,
204
+ )
205
+ # we must transpose the pytrees because both params and curr_idx # are merged in res
206
+ # https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#transposing-trees
207
+
208
+ new_attributes = jax.tree.transpose(
209
+ jax.tree.structure(self.param_n_samples),
210
+ jax.tree.structure([0, 0]),
211
+ res,
212
+ )
213
+
214
+ new = eqx.tree_at(
215
+ lambda m: (m.param_n_samples, m.curr_param_idx),
216
+ self,
217
+ new_attributes,
218
+ )
219
+
220
+ new = eqx.tree_at(lambda m: m.key, new, new_key)
221
+
222
+ return new, jax.tree_util.tree_map(
223
+ lambda p, q: jax.lax.dynamic_slice(
224
+ p, start_indices=(q, 0), slice_sizes=(new.param_batch_size, 1)
225
+ ),
226
+ new.param_n_samples,
227
+ new.curr_param_idx,
228
+ )
229
+
230
+ def get_batch(self) -> tuple[Self, eqx.Module]:
203
231
  """
204
232
  Generic method to return a batch
205
233
  """
206
234
  return self.param_batch()
235
+
236
+
237
+ if __name__ == "__main__":
238
+ key = jax.random.PRNGKey(2)
239
+ key, subkey = jax.random.split(key)
240
+
241
+ n = 64
242
+ param_batch_size = 32
243
+ method = "uniform"
244
+ param_ranges = {"theta": (10.0, 11.0)}
245
+ user_data = {"nu": jnp.ones((n, 1))}
246
+
247
+ x = DataGeneratorParameter(
248
+ key=subkey,
249
+ n=n,
250
+ param_batch_size=param_batch_size,
251
+ param_ranges=param_ranges,
252
+ method=method,
253
+ user_data=user_data,
254
+ )
255
+ print(x.key)
256
+ x, batch = x.get_batch()
257
+ print(x.key)
jinns/data/__init__.py CHANGED
@@ -2,7 +2,7 @@ from ._DataGeneratorODE import DataGeneratorODE
2
2
  from ._CubicMeshPDEStatio import CubicMeshPDEStatio
3
3
  from ._CubicMeshPDENonStatio import CubicMeshPDENonStatio
4
4
  from ._DataGeneratorObservations import DataGeneratorObservations
5
- from ._DataGeneratorParameter import DataGeneratorParameter
5
+ from ._DataGeneratorParameter import DataGeneratorParameter, DGParams
6
6
  from ._Batchs import ODEBatch, PDEStatioBatch, PDENonStatioBatch
7
7
 
8
8
  from ._utils import append_obs_batch, append_param_batch
@@ -12,6 +12,7 @@ __all__ = [
12
12
  "CubicMeshPDEStatio",
13
13
  "CubicMeshPDENonStatio",
14
14
  "DataGeneratorParameter",
15
+ "DGParams",
15
16
  "DataGeneratorObservations",
16
17
  "ODEBatch",
17
18
  "PDEStatioBatch",
jinns/data/_utils.py CHANGED
@@ -3,19 +3,19 @@ Utility functions for DataGenerators
3
3
  """
4
4
 
5
5
  from __future__ import annotations
6
-
6
+ import warnings
7
7
  from typing import TYPE_CHECKING
8
8
  import equinox as eqx
9
9
  import jax
10
10
  import jax.numpy as jnp
11
- from jaxtyping import Key, Array, Float
11
+ from jaxtyping import PRNGKeyArray, Array, Float
12
12
 
13
13
  if TYPE_CHECKING:
14
14
  from jinns.utils._types import AnyBatch
15
15
  from jinns.data._Batchs import ObsBatchDict
16
16
 
17
17
 
18
- def append_param_batch(batch: AnyBatch, param_batch_dict: dict[str, Array]) -> AnyBatch:
18
+ def append_param_batch(batch: AnyBatch, param_batch_dict: eqx.Module) -> AnyBatch:
19
19
  """
20
20
  Utility function that fills the field `batch.param_batch_dict` of a batch object.
21
21
  """
@@ -53,8 +53,10 @@ def make_cartesian_product(
53
53
 
54
54
 
55
55
  def _reset_batch_idx_and_permute(
56
- operands: tuple[Key, Float[Array, " n dimension"], int, None, Float[Array, " n"]],
57
- ) -> tuple[Key, Float[Array, " n dimension"], int]:
56
+ operands: tuple[
57
+ PRNGKeyArray, Float[Array, " n dimension"], int, None, Float[Array, " n"] | None
58
+ ],
59
+ ) -> tuple[PRNGKeyArray, Float[Array, " n dimension"], int]:
58
60
  key, domain, curr_idx, _, p = operands
59
61
  # resetting counter
60
62
  curr_idx = 0
@@ -77,8 +79,10 @@ def _reset_batch_idx_and_permute(
77
79
 
78
80
 
79
81
  def _increment_batch_idx(
80
- operands: tuple[Key, Float[Array, " n dimension"], int, int, Float[Array, " n"]],
81
- ) -> tuple[Key, Float[Array, " n dimension"], int]:
82
+ operands: tuple[
83
+ PRNGKeyArray, Float[Array, " n dimension"], int, int, Float[Array, " n"] | None
84
+ ],
85
+ ) -> tuple[PRNGKeyArray, Float[Array, " n dimension"], int]:
82
86
  key, domain, curr_idx, batch_size, _ = operands
83
87
  # simply increases counter and get the batch
84
88
  curr_idx += batch_size
@@ -88,8 +92,10 @@ def _increment_batch_idx(
88
92
  def _reset_or_increment(
89
93
  bend: int,
90
94
  n_eff: int,
91
- operands: tuple[Key, Float[Array, " n dimension"], int, int, Float[Array, " n"]],
92
- ) -> tuple[Key, Float[Array, " n dimension"], int]:
95
+ operands: tuple[
96
+ PRNGKeyArray, Float[Array, " n dimension"], int, int, Float[Array, " n"] | None
97
+ ],
98
+ ) -> tuple[PRNGKeyArray, Float[Array, " n dimension"], int]:
93
99
  """
94
100
  Factorize the code of the jax.lax.cond which checks if we have seen all the
95
101
  batches in an epoch
@@ -119,7 +125,7 @@ def _reset_or_increment(
119
125
 
120
126
 
121
127
  def _check_and_set_rar_parameters(
122
- rar_parameters: dict, n: int, n_start: int
128
+ rar_parameters: None | dict, n: int, n_start: None | int
123
129
  ) -> tuple[int, Float[Array, " n"] | None, int | None, int | None]:
124
130
  if rar_parameters is not None and n_start is None:
125
131
  raise ValueError(
@@ -127,6 +133,12 @@ def _check_and_set_rar_parameters(
127
133
  )
128
134
 
129
135
  if rar_parameters is not None:
136
+ if n_start is None:
137
+ n_start = 0
138
+ warnings.warn(
139
+ "You asked for RAR sampling but didn't provide"
140
+ f"a proper `n_start` {n_start=}. Setting it to 0."
141
+ )
130
142
  # Default p is None. However, in the RAR sampling scheme we use 0
131
143
  # probability to specify non-used collocation points (i.e. points
132
144
  # above n_start). Thus, p is a vector of probability of shape (nt, 1).