jinns: jinns-1.5.1-py3-none-any.whl → jinns-1.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41):
  1. jinns/data/_AbstractDataGenerator.py +1 -1
  2. jinns/data/_Batchs.py +47 -13
  3. jinns/data/_CubicMeshPDENonStatio.py +55 -34
  4. jinns/data/_CubicMeshPDEStatio.py +63 -35
  5. jinns/data/_DataGeneratorODE.py +48 -22
  6. jinns/data/_DataGeneratorObservations.py +86 -32
  7. jinns/data/_DataGeneratorParameter.py +152 -101
  8. jinns/data/__init__.py +2 -1
  9. jinns/data/_utils.py +22 -10
  10. jinns/loss/_DynamicLoss.py +21 -20
  11. jinns/loss/_DynamicLossAbstract.py +51 -36
  12. jinns/loss/_LossODE.py +139 -184
  13. jinns/loss/_LossPDE.py +440 -358
  14. jinns/loss/_abstract_loss.py +60 -25
  15. jinns/loss/_loss_components.py +4 -25
  16. jinns/loss/_loss_weight_updates.py +6 -7
  17. jinns/loss/_loss_weights.py +34 -35
  18. jinns/nn/_abstract_pinn.py +0 -2
  19. jinns/nn/_hyperpinn.py +34 -23
  20. jinns/nn/_mlp.py +5 -4
  21. jinns/nn/_pinn.py +1 -16
  22. jinns/nn/_ppinn.py +5 -16
  23. jinns/nn/_save_load.py +11 -4
  24. jinns/nn/_spinn.py +1 -16
  25. jinns/nn/_spinn_mlp.py +5 -5
  26. jinns/nn/_utils.py +33 -38
  27. jinns/parameters/__init__.py +3 -1
  28. jinns/parameters/_derivative_keys.py +99 -41
  29. jinns/parameters/_params.py +50 -25
  30. jinns/solver/_solve.py +3 -3
  31. jinns/utils/_DictToModuleMeta.py +66 -0
  32. jinns/utils/_ItemizableModule.py +19 -0
  33. jinns/utils/__init__.py +2 -1
  34. jinns/utils/_types.py +25 -15
  35. {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/METADATA +2 -2
  36. jinns-1.6.1.dist-info/RECORD +57 -0
  37. jinns-1.5.1.dist-info/RECORD +0 -55
  38. {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/WHEEL +0 -0
  39. {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/licenses/AUTHORS +0 -0
  40. {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/licenses/LICENSE +0 -0
  41. {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/top_level.txt +0 -0
jinns/data/_AbstractDataGenerator.py CHANGED
@@ -15,5 +15,5 @@ class AbstractDataGenerator(eqx.Module):
15
15
  """
16
16
 
17
17
  @abc.abstractmethod
18
- def get_batch(self) -> tuple[type[Self], AnyBatch]: # type: ignore
18
+ def get_batch(self) -> tuple[Self, AnyBatch]:
19
19
  pass
jinns/data/_Batchs.py CHANGED
@@ -18,25 +18,59 @@ class ObsBatchDict(TypedDict):
18
18
 
19
19
  pinn_in: Float[Array, " obs_batch_size input_dim"]
20
20
  val: Float[Array, " obs_batch_size output_dim"]
21
- eq_params: dict[str, Float[Array, " obs_batch_size 1"]]
21
+ eq_params: (
22
+ eqx.Module | None
23
+ ) # None cause sometime user don't provide observed params
22
24
 
23
25
 
24
26
  class ODEBatch(eqx.Module):
25
27
  temporal_batch: Float[Array, " batch_size"]
26
- param_batch_dict: dict[str, Array] = eqx.field(default=None)
27
- obs_batch_dict: ObsBatchDict = eqx.field(default=None)
28
-
29
-
30
- class PDENonStatioBatch(eqx.Module):
31
- domain_batch: Float[Array, " batch_size 1+dimension"]
32
- border_batch: Float[Array, " batch_size dimension n_facets"] | None
33
- initial_batch: Float[Array, " batch_size dimension"] | None
34
- param_batch_dict: dict[str, Array] = eqx.field(default=None)
35
- obs_batch_dict: ObsBatchDict = eqx.field(default=None)
28
+ param_batch_dict: eqx.Module | None = eqx.field(default=None)
29
+ obs_batch_dict: ObsBatchDict | None = eqx.field(default=None)
36
30
 
37
31
 
38
32
  class PDEStatioBatch(eqx.Module):
39
33
  domain_batch: Float[Array, " batch_size dimension"]
40
34
  border_batch: Float[Array, " batch_size dimension n_facets"] | None
41
- param_batch_dict: dict[str, Array] = eqx.field(default=None)
42
- obs_batch_dict: ObsBatchDict = eqx.field(default=None)
35
+ param_batch_dict: eqx.Module | None
36
+ obs_batch_dict: ObsBatchDict | None
37
+
38
+ # rewrite __init__ to be able to use inheritance for the NonStatio case
39
+ # below. That way PDENonStatioBatch is a subtype of PDEStatioBatch which
40
+ # 1) makes more sense and 2) CubicMeshPDENonStatio.get_batch passes pyright.
41
+ def __init__(
42
+ self,
43
+ *,
44
+ domain_batch: Float[Array, " batch_size dimension"],
45
+ border_batch: Float[Array, " batch_size dimension n_facets"] | None,
46
+ param_batch_dict: eqx.Module | None = None,
47
+ obs_batch_dict: ObsBatchDict | None = None,
48
+ ):
49
+ # TODO: document this ?
50
+ self.domain_batch = domain_batch
51
+ self.border_batch = border_batch
52
+ self.param_batch_dict = param_batch_dict
53
+ self.obs_batch_dict = obs_batch_dict
54
+
55
+
56
+ class PDENonStatioBatch(PDEStatioBatch):
57
+ # TODO: document this ?
58
+ domain_batch: Float[Array, " batch_size 1+dimension"] # Override type
59
+ initial_batch: (
60
+ Float[Array, " batch_size dimension"] | None
61
+ ) # why can it be None ? Examples?
62
+
63
+ def __init__(
64
+ self,
65
+ *,
66
+ domain_batch: Float[Array, " batch_size 1+dimension"],
67
+ border_batch: Float[Array, " batch_size dimension n_facets"] | None,
68
+ initial_batch: Float[Array, " batch_size dimension"] | None,
69
+ param_batch_dict: eqx.Module | None = None,
70
+ obs_batch_dict: ObsBatchDict | None = None,
71
+ ):
72
+ self.domain_batch = domain_batch
73
+ self.border_batch = border_batch
74
+ self.initial_batch = initial_batch
75
+ self.param_batch_dict = param_batch_dict
76
+ self.obs_batch_dict = obs_batch_dict
jinns/data/_CubicMeshPDENonStatio.py CHANGED
@@ -11,7 +11,7 @@ import numpy as np
11
11
  import jax
12
12
  import jax.numpy as jnp
13
13
  from scipy.stats import qmc
14
- from jaxtyping import Key, Array, Float
14
+ from jaxtyping import PRNGKeyArray, Array, Float
15
15
  from jinns.data._Batchs import PDENonStatioBatch
16
16
  from jinns.data._utils import (
17
17
  make_cartesian_product,
@@ -29,7 +29,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
29
29
 
30
30
  Parameters
31
31
  ----------
32
- key : Key
32
+ key : PRNGKeyArray
33
33
  Jax random key to sample new time points and to shuffle batches
34
34
  n : int
35
35
  The number of total $I\times \Omega$ points that will be divided in
@@ -50,9 +50,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
50
50
  among the `nb` points. If None, `domain_batch_size` no
51
51
  mini-batches are used.
52
52
  initial_batch_size : int | None, default=None
53
- The size of the batch of randomly selected points among
54
- the `ni` points. If None no
55
- mini-batches are used.
53
+ The number of randomly selected points among the `ni` initial spatial
54
+ points used for initial condition. If None, no mini-batches are used.
56
55
  dim : int
57
56
  An integer. Dimension of $\Omega$ domain.
58
57
  min_pts : tuple[tuple[Float, Float], ...]
@@ -94,13 +93,14 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
94
93
  then corresponds to the initial number of omega points we train the PINN.
95
94
  """
96
95
 
97
- tmin: Float = eqx.field(kw_only=True)
98
- tmax: Float = eqx.field(kw_only=True)
99
- ni: int = eqx.field(kw_only=True, static=True)
100
- domain_batch_size: int | None = eqx.field(kw_only=True, static=True, default=None)
101
- initial_batch_size: int | None = eqx.field(kw_only=True, static=True, default=None)
102
- border_batch_size: int | None = eqx.field(kw_only=True, static=True, default=None)
96
+ tmin: float
97
+ tmax: float
98
+ ni: int = eqx.field(static=True)
99
+ domain_batch_size: int | None = eqx.field(static=True)
100
+ initial_batch_size: int | None = eqx.field(static=True)
101
+ border_batch_size: int | None = eqx.field(static=True)
103
102
 
103
+ # --- Below fields are not passed as arguments to __init__
104
104
  curr_domain_idx: int = eqx.field(init=False)
105
105
  curr_initial_idx: int = eqx.field(init=False)
106
106
  curr_border_idx: int = eqx.field(init=False)
@@ -110,13 +110,32 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
110
110
  )
111
111
  initial: Float[Array, " ni dim"] | None = eqx.field(init=False)
112
112
 
113
- def __post_init__(self):
113
+ def __init__(
114
+ self,
115
+ tmin: float,
116
+ tmax: float,
117
+ ni: int,
118
+ domain_batch_size: int | None = None,
119
+ initial_batch_size: int | None = None,
120
+ border_batch_size: int | None = None,
121
+ **kwargs, # kwargs for CubicMeshPDEStatio.__init__
122
+ ):
114
123
  """
115
124
  Note that neither __init__ or __post_init__ are called when udating a
116
125
  Module with eqx.tree_at!
117
126
  """
118
- super().__post_init__() # because __init__ or __post_init__ of Base
119
- # class is not automatically called
127
+ # sanity check
128
+ if ni is None:
129
+ raise ValueError("`ni` cannot be None.")
130
+
131
+ super().__init__(**kwargs)
132
+ self.tmin = tmin
133
+ self.tmax = tmax
134
+ self.ni = ni
135
+
136
+ self.domain_batch_size = domain_batch_size
137
+ self.initial_batch_size = initial_batch_size
138
+ self.border_batch_size = border_batch_size
120
139
 
121
140
  if self.method == "grid":
122
141
  # NOTE we must redo the sampling with the square root number of samples
@@ -144,7 +163,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
144
163
  )
145
164
  self.domain = make_cartesian_product(half_domain_times, half_domain_omega)
146
165
 
147
- # NOTE
166
+ # NOTE below re-do CubicMeshPDE.__init__() ? Maybe useless?
148
167
  (
149
168
  self.n_start,
150
169
  self.p,
@@ -178,7 +197,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
178
197
  " a multiple of 2xd (the # of faces of a d-dimensional cube)"
179
198
  )
180
199
  # the check below concern omega_border_batch_size for dim > 1 in
181
- # super.__post_init__. Here it concerns all dim values since our
200
+ # super.__init__. Here it concerns all dim values since our
182
201
  # border_batch is the concatenation or cartesian product with times
183
202
  if (
184
203
  self.border_batch_size is not None
@@ -221,7 +240,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
221
240
  self.border_batch_size = None
222
241
  self.curr_border_idx = 0
223
242
 
224
- if self.ni is not None:
243
+ if ni is not None:
225
244
  if self.method == "grid":
226
245
  perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
227
246
  if self.ni != perfect_sq:
@@ -235,17 +254,17 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
235
254
  log2_n = jnp.log2(self.ni)
236
255
  lower_pow = 2 ** jnp.floor(log2_n)
237
256
  higher_pow = 2 ** jnp.ceil(log2_n)
238
- closest_two_power = (
257
+ closest_power_of_two = (
239
258
  lower_pow
240
259
  if (self.ni - lower_pow) < (higher_pow - self.ni)
241
260
  else higher_pow
242
261
  )
243
- if self.n != closest_two_power:
262
+ if self.n != closest_power_of_two:
244
263
  warnings.warn(
245
264
  f"QuasiMonteCarlo sampling with {self.method} requires sample size to be a power fo 2."
246
- f"Modfiying self.n from {self.ni} to {closest_two_power}.",
265
+ f"Modfiying self.n from {self.ni} to {closest_power_of_two}.",
247
266
  )
248
- self.ni = int(closest_two_power)
267
+ self.ni = int(closest_power_of_two)
249
268
  self.key, self.initial = self.generate_omega_data(
250
269
  self.key, data_size=self.ni
251
270
  )
@@ -264,8 +283,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
264
283
  self.omega_border = None
265
284
 
266
285
  def generate_time_data(
267
- self, key: Key, nt: int
268
- ) -> tuple[Key, Float[Array, " nt 1"]]:
286
+ self, key: PRNGKeyArray, nt: int
287
+ ) -> tuple[PRNGKeyArray, Float[Array, " nt 1"]]:
269
288
  """
270
289
  Construct a complete set of `nt` time points according to the
271
290
  specified `self.method`
@@ -278,12 +297,14 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
278
297
  return key, self.sample_in_time_domain(subkey, nt)
279
298
  raise ValueError("Method " + self.method + " is not implemented.")
280
299
 
281
- def sample_in_time_domain(self, key: Key, nt: int) -> Float[Array, " nt 1"]:
300
+ def sample_in_time_domain(
301
+ self, key: PRNGKeyArray, nt: int
302
+ ) -> Float[Array, " nt 1"]:
282
303
  return jax.random.uniform(key, (nt, 1), minval=self.tmin, maxval=self.tmax)
283
304
 
284
305
  def qmc_in_time_omega_domain(
285
- self, key: Key, sample_size: int
286
- ) -> tuple[Key, Float[Array, "n 1+dim"]]:
306
+ self, key: PRNGKeyArray, sample_size: int
307
+ ) -> tuple[PRNGKeyArray, Float[Array, "n 1+dim"]]:
287
308
  """
288
309
  Because in Quasi-Monte Carlo sampling we cannot concatenate two vectors generated independently
289
310
  We generate time and omega samples jointly
@@ -300,8 +321,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
300
321
  return key, jnp.array(samples)
301
322
 
302
323
  def qmc_in_time_omega_border_domain(
303
- self, key: Key, sample_size: int | None = None
304
- ) -> tuple[Key, Float[Array, "n 1+dim"]] | None:
324
+ self, key: PRNGKeyArray, sample_size: int | None = None
325
+ ) -> tuple[PRNGKeyArray, Float[Array, "n 1+dim"]] | None:
305
326
  """
306
327
  For each facet of the border we generate Quasi-MonteCarlo sequences jointy with time.
307
328
 
@@ -387,7 +408,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
387
408
 
388
409
  def _get_domain_operands(
389
410
  self,
390
- ) -> tuple[Key, Float[Array, " n 1+dim"], int, int | None, Array | None]:
411
+ ) -> tuple[PRNGKeyArray, Float[Array, " n 1+dim"], int, int | None, Array | None]:
391
412
  return (
392
413
  self.key,
393
414
  self.domain,
@@ -424,7 +445,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
424
445
  # handled above
425
446
  )
426
447
  new = eqx.tree_at(
427
- lambda m: (m.key, m.domain, m.curr_domain_idx),
448
+ lambda m: (m.key, m.domain, m.curr_domain_idx), # type: ignore
428
449
  self,
429
450
  new_attributes,
430
451
  )
@@ -437,7 +458,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
437
458
  def _get_border_operands(
438
459
  self,
439
460
  ) -> tuple[
440
- Key,
461
+ PRNGKeyArray,
441
462
  Float[Array, " nb 1+1 2"] | Float[Array, " (nb//4) 2+1 4"] | None,
442
463
  int,
443
464
  int | None,
@@ -483,7 +504,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
483
504
  # handled above
484
505
  )
485
506
  new = eqx.tree_at(
486
- lambda m: (m.key, m.border, m.curr_border_idx),
507
+ lambda m: (m.key, m.border, m.curr_border_idx), # type: ignore
487
508
  self,
488
509
  new_attributes,
489
510
  )
@@ -500,7 +521,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
500
521
 
501
522
  def _get_initial_operands(
502
523
  self,
503
- ) -> tuple[Key, Float[Array, " ni dim"] | None, int, int | None, None]:
524
+ ) -> tuple[PRNGKeyArray, Float[Array, " ni dim"] | None, int, int | None, None]:
504
525
  return (
505
526
  self.key,
506
527
  self.initial,
@@ -529,7 +550,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
529
550
  # handled above
530
551
  )
531
552
  new = eqx.tree_at(
532
- lambda m: (m.key, m.initial, m.curr_initial_idx),
553
+ lambda m: (m.key, m.initial, m.curr_initial_idx), # type: ignore
533
554
  self,
534
555
  new_attributes,
535
556
  )
jinns/data/_CubicMeshPDEStatio.py CHANGED
@@ -11,7 +11,7 @@ import jax
11
11
  import numpy as np
12
12
  import jax.numpy as jnp
13
13
  from scipy.stats import qmc
14
- from jaxtyping import Key, Array, Float
14
+ from jaxtyping import PRNGKeyArray, Array, Float
15
15
  from typing import Literal
16
16
  from jinns.data._Batchs import PDEStatioBatch
17
17
  from jinns.data._utils import _check_and_set_rar_parameters, _reset_or_increment
@@ -25,7 +25,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
25
25
 
26
26
  Parameters
27
27
  ----------
28
- key : Key
28
+ key : PRNGKeyArray
29
29
  Jax random key to sample new time points and to shuffle batches
30
30
  n : int
31
31
  The number of total $\Omega$ points that will be divided in
@@ -80,32 +80,28 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
80
80
  then corresponds to the initial number of points we train the PINN on.
81
81
  """
82
82
 
83
- # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
84
- key: Key = eqx.field(kw_only=True)
85
- n: int = eqx.field(kw_only=True, static=True)
83
+ key: PRNGKeyArray
84
+ n: int = eqx.field(static=True)
86
85
  nb: int | None = eqx.field(kw_only=True, static=True, default=None)
87
86
  omega_batch_size: int | None = eqx.field(
88
- kw_only=True,
89
87
  static=True,
90
- default=None, # can be None as
88
+ # can be None as
91
89
  # CubicMeshPDENonStatio inherits but also if omega_batch_size=n
92
90
  ) # static cause used as a
93
91
  # shape in jax.lax.dynamic_slice
94
92
  omega_border_batch_size: int | None = eqx.field(
95
- kw_only=True, static=True, default=None
93
+ static=True,
96
94
  ) # static cause used as a
97
95
  # shape in jax.lax.dynamic_slice
98
- dim: int = eqx.field(kw_only=True, static=True) # static cause used as a
96
+ dim: int = eqx.field(static=True) # static cause used as a
99
97
  # shape in jax.lax.dynamic_slice
100
- min_pts: tuple[float, ...] = eqx.field(kw_only=True)
101
- max_pts: tuple[float, ...] = eqx.field(kw_only=True)
102
- method: Literal["grid", "uniform", "sobol", "halton"] = eqx.field(
103
- kw_only=True, static=True, default_factory=lambda: "uniform"
104
- )
105
- rar_parameters: dict[str, int] = eqx.field(kw_only=True, default=None)
106
- n_start: int = eqx.field(kw_only=True, default=None, static=True)
98
+ min_pts: tuple[float, ...]
99
+ max_pts: tuple[float, ...]
100
+ method: Literal["grid", "uniform", "sobol", "halton"] = eqx.field(static=True)
101
+ rar_parameters: None | dict[str, int]
102
+ n_start: int = eqx.field(static=True)
107
103
 
108
- # all the init=False fields are set in __post_init__
104
+ # --- Below fields are not passed as arguments to __init__
109
105
  p: Float[Array, " n"] | None = eqx.field(init=False)
110
106
  rar_iter_from_last_sampling: int | None = eqx.field(init=False)
111
107
  rar_iter_nb: int | None = eqx.field(init=False)
@@ -116,7 +112,32 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
116
112
  eqx.field(init=False)
117
113
  )
118
114
 
119
- def __post_init__(self):
115
+ def __init__(
116
+ self,
117
+ *,
118
+ key: PRNGKeyArray,
119
+ n: int,
120
+ nb: int | None = None,
121
+ omega_batch_size: int | None = None,
122
+ omega_border_batch_size: int | None = None,
123
+ dim: int,
124
+ min_pts: tuple[float, ...],
125
+ max_pts: tuple[float, ...],
126
+ method: Literal["grid", "uniform", "sobol", "halton"] = "uniform",
127
+ rar_parameters: dict[str, int] | None = None,
128
+ n_start: int | None = None,
129
+ ):
130
+ self.key = key
131
+ self.n = n
132
+ self.nb = nb
133
+ self.omega_batch_size = omega_batch_size
134
+ self.omega_border_batch_size = omega_border_batch_size
135
+ self.dim = dim
136
+ self.min_pts = min_pts
137
+ self.max_pts = max_pts
138
+ self.method = method
139
+ self.rar_parameters = rar_parameters
140
+
120
141
  assert self.dim == len(self.min_pts) and isinstance(self.min_pts, tuple)
121
142
  assert self.dim == len(self.max_pts) and isinstance(self.max_pts, tuple)
122
143
 
@@ -125,7 +146,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
125
146
  self.p,
126
147
  self.rar_iter_from_last_sampling,
127
148
  self.rar_iter_nb,
128
- ) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)
149
+ ) = _check_and_set_rar_parameters(self.rar_parameters, self.n, n_start)
129
150
 
130
151
  if self.method == "grid" and self.dim == 2:
131
152
  perfect_sq = int(jnp.round(jnp.sqrt(self.n)) ** 2)
@@ -195,13 +216,13 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
195
216
  self.key, self.omega_border = self.generate_omega_border_data(self.key)
196
217
 
197
218
  def sample_in_omega_domain(
198
- self, keys: Key, sample_size: int
219
+ self, keys: list[PRNGKeyArray], sample_size: int
199
220
  ) -> Float[Array, " n dim"]:
200
221
  if self.method == "uniform":
201
222
  if self.dim == 1:
202
223
  xmin, xmax = self.min_pts[0], self.max_pts[0]
203
224
  return jax.random.uniform(
204
- keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
225
+ *keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
205
226
  )
206
227
 
207
228
  return jnp.concatenate(
@@ -217,10 +238,10 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
217
238
  axis=-1,
218
239
  )
219
240
  else:
220
- return self._qmc_in_omega_domain(keys, sample_size)
241
+ return self._qmc_in_omega_domain(keys[0], sample_size)
221
242
 
222
243
  def _qmc_in_omega_domain(
223
- self, subkey: Key, sample_size: int
244
+ self, subkey: PRNGKeyArray, sample_size: int
224
245
  ) -> Float[Array, "n dim"]:
225
246
  qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
226
247
  if self.dim == 1:
@@ -241,7 +262,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
241
262
  return jnp.array(samples)
242
263
 
243
264
  def sample_in_omega_border_domain(
244
- self, keys: Key, sample_size: int | None = None
265
+ self, keys: list[PRNGKeyArray] | None, sample_size: int | None = None
245
266
  ) -> Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None:
246
267
  sample_size = self.nb if sample_size is None else sample_size
247
268
  if sample_size is None:
@@ -251,6 +272,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
251
272
  xmax = self.max_pts[0]
252
273
  return jnp.array([xmin, xmax]).astype(float)
253
274
  if self.dim == 2:
275
+ assert keys is not None
254
276
  # currently hard-coded the 4 edges for d==2
255
277
  # TODO : find a general & efficient way to sample from the border
256
278
  # (facets) of the hypercube in general dim.
@@ -306,7 +328,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
306
328
  )
307
329
 
308
330
  def qmc_in_omega_border_domain(
309
- self, keys: Key, sample_size: int | None = None
331
+ self, keys: list[PRNGKeyArray] | None, sample_size: int | None = None
310
332
  ) -> Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None:
311
333
  qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
312
334
  sample_size = self.nb if sample_size is None else sample_size
@@ -317,6 +339,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
317
339
  xmax = self.max_pts[0]
318
340
  return jnp.array([xmin, xmax]).astype(float)
319
341
  if self.dim == 2:
342
+ assert keys is not None
320
343
  # currently hard-coded the 4 edges for d==2
321
344
  # TODO : find a general & efficient way to sample from the border
322
345
  # (facets) of the hypercube in general dim.
@@ -362,9 +385,9 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
362
385
  )
363
386
 
364
387
  def generate_omega_data(
365
- self, key: Key, data_size: int | None = None
388
+ self, key: PRNGKeyArray, data_size: int | None = None
366
389
  ) -> tuple[
367
- Key,
390
+ PRNGKeyArray,
368
391
  Float[Array, " n dim"],
369
392
  ]:
370
393
  r"""
@@ -393,18 +416,19 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
393
416
  omega = jnp.concatenate(xyz_, axis=-1)
394
417
  elif self.method in ["uniform", "sobol", "halton"]:
395
418
  if self.dim == 1 or self.method in ["sobol", "halton"]:
396
- key, subkeys = jax.random.split(key, 2)
419
+ key, subkey = jax.random.split(key, 2)
420
+ omega = self.sample_in_omega_domain([subkey], sample_size=data_size)
397
421
  else:
398
422
  key, *subkeys = jax.random.split(key, self.dim + 1)
399
- omega = self.sample_in_omega_domain(subkeys, sample_size=data_size)
423
+ omega = self.sample_in_omega_domain(subkeys, sample_size=data_size)
400
424
  else:
401
425
  raise ValueError("Method " + self.method + " is not implemented.")
402
426
  return key, omega
403
427
 
404
428
  def generate_omega_border_data(
405
- self, key: Key, data_size: int | None = None
429
+ self, key: PRNGKeyArray, data_size: int | None = None
406
430
  ) -> tuple[
407
- Key,
431
+ PRNGKeyArray,
408
432
  Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None,
409
433
  ]:
410
434
  r"""
@@ -433,7 +457,9 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
433
457
 
434
458
  def _get_omega_operands(
435
459
  self,
436
- ) -> tuple[Key, Float[Array, " n dim"], int, int | None, Float[Array, " n"] | None]:
460
+ ) -> tuple[
461
+ PRNGKeyArray, Float[Array, " n dim"], int, int | None, Float[Array, " n"] | None
462
+ ]:
437
463
  return (
438
464
  self.key,
439
465
  self.omega,
@@ -475,7 +501,9 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
475
501
  # handled above
476
502
  )
477
503
  new = eqx.tree_at(
478
- lambda m: (m.key, m.omega, m.curr_omega_idx), self, new_attributes
504
+ lambda m: (m.key, m.omega, m.curr_omega_idx), # type: ignore
505
+ self,
506
+ new_attributes,
479
507
  )
480
508
 
481
509
  return new, jax.lax.dynamic_slice(
@@ -487,7 +515,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
487
515
  def _get_omega_border_operands(
488
516
  self,
489
517
  ) -> tuple[
490
- Key,
518
+ PRNGKeyArray,
491
519
  Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None,
492
520
  int,
493
521
  int | None,
@@ -551,7 +579,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
551
579
  # handled above
552
580
  )
553
581
  new = eqx.tree_at(
554
- lambda m: (m.key, m.omega_border, m.curr_omega_border_idx),
582
+ lambda m: (m.key, m.omega_border, m.curr_omega_border_idx), # type: ignore
555
583
  self,
556
584
  new_attributes,
557
585
  )