jinns 1.5.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. jinns/__init__.py +7 -7
  2. jinns/data/_AbstractDataGenerator.py +1 -1
  3. jinns/data/_Batchs.py +47 -13
  4. jinns/data/_CubicMeshPDENonStatio.py +203 -54
  5. jinns/data/_CubicMeshPDEStatio.py +190 -54
  6. jinns/data/_DataGeneratorODE.py +48 -22
  7. jinns/data/_DataGeneratorObservations.py +75 -32
  8. jinns/data/_DataGeneratorParameter.py +152 -101
  9. jinns/data/__init__.py +2 -1
  10. jinns/data/_utils.py +22 -10
  11. jinns/loss/_DynamicLoss.py +21 -20
  12. jinns/loss/_DynamicLossAbstract.py +51 -36
  13. jinns/loss/_LossODE.py +210 -191
  14. jinns/loss/_LossPDE.py +441 -368
  15. jinns/loss/_abstract_loss.py +60 -25
  16. jinns/loss/_loss_components.py +4 -25
  17. jinns/loss/_loss_utils.py +23 -0
  18. jinns/loss/_loss_weight_updates.py +6 -7
  19. jinns/loss/_loss_weights.py +34 -35
  20. jinns/nn/_abstract_pinn.py +0 -2
  21. jinns/nn/_hyperpinn.py +34 -23
  22. jinns/nn/_mlp.py +5 -4
  23. jinns/nn/_pinn.py +1 -16
  24. jinns/nn/_ppinn.py +5 -16
  25. jinns/nn/_save_load.py +11 -4
  26. jinns/nn/_spinn.py +1 -16
  27. jinns/nn/_spinn_mlp.py +5 -5
  28. jinns/nn/_utils.py +33 -38
  29. jinns/parameters/__init__.py +3 -1
  30. jinns/parameters/_derivative_keys.py +99 -41
  31. jinns/parameters/_params.py +58 -25
  32. jinns/solver/_solve.py +14 -8
  33. jinns/utils/_DictToModuleMeta.py +66 -0
  34. jinns/utils/_ItemizableModule.py +19 -0
  35. jinns/utils/__init__.py +2 -1
  36. jinns/utils/_types.py +25 -15
  37. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
  38. jinns-1.6.0.dist-info/RECORD +57 -0
  39. jinns-1.5.0.dist-info/RECORD +0 -55
  40. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
  41. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
  42. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
  43. {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
jinns/__init__.py CHANGED
@@ -1,10 +1,3 @@
1
- # import jinns.data
2
- # import jinns.loss
3
- # import jinns.solver
4
- # import jinns.utils
5
- # import jinns.experimental
6
- # import jinns.parameters
7
- # import jinns.plot
8
1
  from jinns import data as data
9
2
  from jinns import loss as loss
10
3
  from jinns import solver as solver
@@ -16,3 +9,10 @@ from jinns import nn as nn
16
9
  from jinns.solver._solve import solve
17
10
 
18
11
  __all__ = ["nn", "solve"]
12
+
13
+ import warnings
14
+
15
+ warnings.filterwarnings(
16
+ action="ignore",
17
+ message=r"Using `field\(init=False\)`",
18
+ )
@@ -15,5 +15,5 @@ class AbstractDataGenerator(eqx.Module):
15
15
  """
16
16
 
17
17
  @abc.abstractmethod
18
- def get_batch(self) -> tuple[type[Self], AnyBatch]: # type: ignore
18
+ def get_batch(self) -> tuple[Self, AnyBatch]:
19
19
  pass
jinns/data/_Batchs.py CHANGED
@@ -18,25 +18,59 @@ class ObsBatchDict(TypedDict):
18
18
 
19
19
  pinn_in: Float[Array, " obs_batch_size input_dim"]
20
20
  val: Float[Array, " obs_batch_size output_dim"]
21
- eq_params: dict[str, Float[Array, " obs_batch_size 1"]]
21
+ eq_params: (
22
+ eqx.Module | None
23
+ ) # None cause sometime user don't provide observed params
22
24
 
23
25
 
24
26
  class ODEBatch(eqx.Module):
25
27
  temporal_batch: Float[Array, " batch_size"]
26
- param_batch_dict: dict[str, Array] = eqx.field(default=None)
27
- obs_batch_dict: ObsBatchDict = eqx.field(default=None)
28
-
29
-
30
- class PDENonStatioBatch(eqx.Module):
31
- domain_batch: Float[Array, " batch_size 1+dimension"]
32
- border_batch: Float[Array, " batch_size dimension n_facets"] | None
33
- initial_batch: Float[Array, " batch_size dimension"] | None
34
- param_batch_dict: dict[str, Array] = eqx.field(default=None)
35
- obs_batch_dict: ObsBatchDict = eqx.field(default=None)
28
+ param_batch_dict: eqx.Module | None = eqx.field(default=None)
29
+ obs_batch_dict: ObsBatchDict | None = eqx.field(default=None)
36
30
 
37
31
 
38
32
  class PDEStatioBatch(eqx.Module):
39
33
  domain_batch: Float[Array, " batch_size dimension"]
40
34
  border_batch: Float[Array, " batch_size dimension n_facets"] | None
41
- param_batch_dict: dict[str, Array] = eqx.field(default=None)
42
- obs_batch_dict: ObsBatchDict = eqx.field(default=None)
35
+ param_batch_dict: eqx.Module | None
36
+ obs_batch_dict: ObsBatchDict | None
37
+
38
+ # rewrite __init__ to be able to use inheritance for the NonStatio case
39
+ # below. That way PDENonStatioBatch is a subtype of PDEStatioBatch which
40
+ # 1) makes more sense and 2) CubicMeshPDENonStatio.get_batch passes pyright.
41
+ def __init__(
42
+ self,
43
+ *,
44
+ domain_batch: Float[Array, " batch_size dimension"],
45
+ border_batch: Float[Array, " batch_size dimension n_facets"] | None,
46
+ param_batch_dict: eqx.Module | None = None,
47
+ obs_batch_dict: ObsBatchDict | None = None,
48
+ ):
49
+ # TODO: document this ?
50
+ self.domain_batch = domain_batch
51
+ self.border_batch = border_batch
52
+ self.param_batch_dict = param_batch_dict
53
+ self.obs_batch_dict = obs_batch_dict
54
+
55
+
56
+ class PDENonStatioBatch(PDEStatioBatch):
57
+ # TODO: document this ?
58
+ domain_batch: Float[Array, " batch_size 1+dimension"] # Override type
59
+ initial_batch: (
60
+ Float[Array, " batch_size dimension"] | None
61
+ ) # why can it be None ? Examples?
62
+
63
+ def __init__(
64
+ self,
65
+ *,
66
+ domain_batch: Float[Array, " batch_size 1+dimension"],
67
+ border_batch: Float[Array, " batch_size dimension n_facets"] | None,
68
+ initial_batch: Float[Array, " batch_size dimension"] | None,
69
+ param_batch_dict: eqx.Module | None = None,
70
+ obs_batch_dict: ObsBatchDict | None = None,
71
+ ):
72
+ self.domain_batch = domain_batch
73
+ self.border_batch = border_batch
74
+ self.initial_batch = initial_batch
75
+ self.param_batch_dict = param_batch_dict
76
+ self.obs_batch_dict = obs_batch_dict
@@ -7,9 +7,11 @@ from __future__ import (
7
7
  ) # https://docs.python.org/3/library/typing.html#constant
8
8
  import warnings
9
9
  import equinox as eqx
10
+ import numpy as np
10
11
  import jax
11
12
  import jax.numpy as jnp
12
- from jaxtyping import Key, Array, Float
13
+ from scipy.stats import qmc
14
+ from jaxtyping import PRNGKeyArray, Array, Float
13
15
  from jinns.data._Batchs import PDENonStatioBatch
14
16
  from jinns.data._utils import (
15
17
  make_cartesian_product,
@@ -27,7 +29,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
27
29
 
28
30
  Parameters
29
31
  ----------
30
- key : Key
32
+ key : PRNGKeyArray
31
33
  Jax random key to sample new time points and to shuffle batches
32
34
  n : int
33
35
  The number of total $I\times \Omega$ points that will be divided in
@@ -48,9 +50,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
48
50
  among the `nb` points. If None, `domain_batch_size` no
49
51
  mini-batches are used.
50
52
  initial_batch_size : int | None, default=None
51
- The size of the batch of randomly selected points among
52
- the `ni` points. If None no
53
- mini-batches are used.
53
+ The number of randomly selected points among the `ni` initial spatial
54
+ points used for initial condition. If None, no mini-batches are used.
54
55
  dim : int
55
56
  An integer. Dimension of $\Omega$ domain.
56
57
  min_pts : tuple[tuple[Float, Float], ...]
@@ -65,11 +66,13 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
65
66
  The minimum value of the time domain to consider
66
67
  tmax : float
67
68
  The maximum value of the time domain to consider
68
- method : str, default="uniform"
69
+ method : Literal["uniform", "grid", "sobol", "halton"], default="uniform"
69
70
  Either `grid` or `uniform`, default is `uniform`.
70
71
  The method that generates the `nt` time points. `grid` means
71
72
  regularly spaced points over the domain. `uniform` means uniformly
72
- sampled points over the domain
73
+ sampled points over the domain.
74
+ **Note** that Sobol and Halton approaches use scipy modules and will not
75
+ be JIT compatible.
73
76
  rar_parameters : Dict[str, int], default=None
74
77
  Defaults to None: do not use Residual Adaptative Resampling.
75
78
  Otherwise a dictionary with keys
@@ -90,13 +93,14 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
90
93
  then corresponds to the initial number of omega points we train the PINN.
91
94
  """
92
95
 
93
- tmin: Float = eqx.field(kw_only=True)
94
- tmax: Float = eqx.field(kw_only=True)
95
- ni: int = eqx.field(kw_only=True, static=True)
96
- domain_batch_size: int | None = eqx.field(kw_only=True, static=True, default=None)
97
- initial_batch_size: int | None = eqx.field(kw_only=True, static=True, default=None)
98
- border_batch_size: int | None = eqx.field(kw_only=True, static=True, default=None)
96
+ tmin: float
97
+ tmax: float
98
+ ni: int = eqx.field(static=True)
99
+ domain_batch_size: int | None = eqx.field(static=True)
100
+ initial_batch_size: int | None = eqx.field(static=True)
101
+ border_batch_size: int | None = eqx.field(static=True)
99
102
 
103
+ # --- Below fields are not passed as arguments to __init__
100
104
  curr_domain_idx: int = eqx.field(init=False)
101
105
  curr_initial_idx: int = eqx.field(init=False)
102
106
  curr_border_idx: int = eqx.field(init=False)
@@ -106,13 +110,32 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
106
110
  )
107
111
  initial: Float[Array, " ni dim"] | None = eqx.field(init=False)
108
112
 
109
- def __post_init__(self):
113
+ def __init__(
114
+ self,
115
+ tmin: float,
116
+ tmax: float,
117
+ ni: int,
118
+ domain_batch_size: int | None = None,
119
+ initial_batch_size: int | None = None,
120
+ border_batch_size: int | None = None,
121
+ **kwargs, # kwargs for CubicMeshPDEStatio.__init__
122
+ ):
110
123
  """
111
124
  Note that neither __init__ or __post_init__ are called when udating a
112
125
  Module with eqx.tree_at!
113
126
  """
114
- super().__post_init__() # because __init__ or __post_init__ of Base
115
- # class is not automatically called
127
+ # sanity check
128
+ if ni is None:
129
+ raise ValueError("`ni` cannot be None.")
130
+
131
+ super().__init__(**kwargs)
132
+ self.tmin = tmin
133
+ self.tmax = tmax
134
+ self.ni = ni
135
+
136
+ self.domain_batch_size = domain_batch_size
137
+ self.initial_batch_size = initial_batch_size
138
+ self.border_batch_size = border_batch_size
116
139
 
117
140
  if self.method == "grid":
118
141
  # NOTE we must redo the sampling with the square root number of samples
@@ -140,7 +163,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
140
163
  )
141
164
  self.domain = make_cartesian_product(half_domain_times, half_domain_omega)
142
165
 
143
- # NOTE
166
+ # NOTE below re-do CubicMeshPDE.__init__() ? Maybe useless?
144
167
  (
145
168
  self.n_start,
146
169
  self.p,
@@ -150,9 +173,11 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
150
173
  elif self.method == "uniform":
151
174
  self.key, domain_times = self.generate_time_data(self.key, self.n)
152
175
  self.domain = jnp.concatenate([domain_times, self.omega], axis=1)
176
+ elif self.method in ["sobol", "halton"]:
177
+ self.key, self.domain = self.qmc_in_time_omega_domain(self.key, self.n)
153
178
  else:
154
179
  raise ValueError(
155
- f'Bad value for method. Got {self.method}, expected "grid" or "uniform"'
180
+ f'Bad value for method. Got {self.method}, expected "grid" or "uniform" or "sobol" or "halton"'
156
181
  )
157
182
 
158
183
  if self.domain_batch_size is None:
@@ -172,7 +197,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
172
197
  " a multiple of 2xd (the # of faces of a d-dimensional cube)"
173
198
  )
174
199
  # the check below concern omega_border_batch_size for dim > 1 in
175
- # super.__post_init__. Here it concerns all dim values since our
200
+ # super.__init__. Here it concerns all dim values since our
176
201
  # border_batch is the concatenation or cartesian product with times
177
202
  if (
178
203
  self.border_batch_size is not None
@@ -182,21 +207,28 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
182
207
  "number of points per facets (nb//2*self.dim)"
183
208
  " cannot be lower than border batch size"
184
209
  )
185
- self.key, boundary_times = self.generate_time_data(
186
- self.key, self.nb // (2 * self.dim)
187
- )
188
- boundary_times = boundary_times.reshape(-1, 1, 1)
189
- boundary_times = jnp.repeat(
190
- boundary_times, self.omega_border.shape[-1], axis=2
191
- )
192
- if self.dim == 1:
193
- self.border = make_cartesian_product(
194
- boundary_times, self.omega_border[None, None]
210
+ if self.method in ["grid", "uniform"]:
211
+ self.key, boundary_times = self.generate_time_data(
212
+ self.key, self.nb // (2 * self.dim)
213
+ )
214
+ boundary_times = boundary_times.reshape(-1, 1, 1)
215
+ boundary_times = jnp.repeat(
216
+ boundary_times, self.omega_border.shape[-1], axis=2
195
217
  )
218
+ if self.dim == 1:
219
+ self.border = make_cartesian_product(
220
+ boundary_times, self.omega_border[None, None]
221
+ )
222
+ else:
223
+ self.border = jnp.concatenate(
224
+ [boundary_times, self.omega_border], axis=1
225
+ )
196
226
  else:
197
- self.border = jnp.concatenate(
198
- [boundary_times, self.omega_border], axis=1
227
+ self.key, self.border = self.qmc_in_time_omega_border_domain(
228
+ self.key,
229
+ self.nb, # type: ignore (see inside the fun)
199
230
  )
231
+
200
232
  if self.border_batch_size is None:
201
233
  self.curr_border_idx = 0
202
234
  else:
@@ -208,15 +240,31 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
208
240
  self.border_batch_size = None
209
241
  self.curr_border_idx = 0
210
242
 
211
- if self.ni is not None:
212
- perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
213
- if self.ni != perfect_sq:
214
- warnings.warn(
215
- "Grid sampling is requested in dimension 2 with a non"
216
- f" perfect square dataset size (self.ni = {self.ni})."
217
- f" Modifying self.ni to self.ni = {perfect_sq}."
243
+ if ni is not None:
244
+ if self.method == "grid":
245
+ perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
246
+ if self.ni != perfect_sq:
247
+ warnings.warn(
248
+ "Grid sampling is requested in dimension 2 with a non"
249
+ f" perfect square dataset size (self.ni = {self.ni})."
250
+ f" Modifying self.ni to self.ni = {perfect_sq}."
251
+ )
252
+ self.ni = perfect_sq
253
+ if self.method in ["sobol", "halton"]:
254
+ log2_n = jnp.log2(self.ni)
255
+ lower_pow = 2 ** jnp.floor(log2_n)
256
+ higher_pow = 2 ** jnp.ceil(log2_n)
257
+ closest_power_of_two = (
258
+ lower_pow
259
+ if (self.ni - lower_pow) < (higher_pow - self.ni)
260
+ else higher_pow
218
261
  )
219
- self.ni = perfect_sq
262
+ if self.n != closest_power_of_two:
263
+ warnings.warn(
264
+ f"QuasiMonteCarlo sampling with {self.method} requires sample size to be a power fo 2."
265
+ f"Modfiying self.n from {self.ni} to {closest_power_of_two}.",
266
+ )
267
+ self.ni = int(closest_power_of_two)
220
268
  self.key, self.initial = self.generate_omega_data(
221
269
  self.key, data_size=self.ni
222
270
  )
@@ -235,8 +283,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
235
283
  self.omega_border = None
236
284
 
237
285
  def generate_time_data(
238
- self, key: Key, nt: int
239
- ) -> tuple[Key, Float[Array, " nt 1"]]:
286
+ self, key: PRNGKeyArray, nt: int
287
+ ) -> tuple[PRNGKeyArray, Float[Array, " nt 1"]]:
240
288
  """
241
289
  Construct a complete set of `nt` time points according to the
242
290
  specified `self.method`
@@ -245,21 +293,122 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
245
293
  if self.method == "grid":
246
294
  partial_times = (self.tmax - self.tmin) / nt
247
295
  return key, jnp.arange(self.tmin, self.tmax, partial_times)[:, None]
248
- if self.method == "uniform":
296
+ elif self.method in ["uniform", "sobol", "halton"]:
249
297
  return key, self.sample_in_time_domain(subkey, nt)
250
298
  raise ValueError("Method " + self.method + " is not implemented.")
251
299
 
252
- def sample_in_time_domain(self, key: Key, nt: int) -> Float[Array, " nt 1"]:
253
- return jax.random.uniform(
254
- key,
255
- (nt, 1),
256
- minval=self.tmin,
257
- maxval=self.tmax,
300
+ def sample_in_time_domain(
301
+ self, key: PRNGKeyArray, nt: int
302
+ ) -> Float[Array, " nt 1"]:
303
+ return jax.random.uniform(key, (nt, 1), minval=self.tmin, maxval=self.tmax)
304
+
305
+ def qmc_in_time_omega_domain(
306
+ self, key: PRNGKeyArray, sample_size: int
307
+ ) -> tuple[PRNGKeyArray, Float[Array, "n 1+dim"]]:
308
+ """
309
+ Because in Quasi-Monte Carlo sampling we cannot concatenate two vectors generated independently
310
+ We generate time and omega samples jointly
311
+ """
312
+ key, subkey = jax.random.split(key, 2)
313
+ qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
314
+ sampler = qmc_generator(
315
+ d=self.dim + 1, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
316
+ )
317
+ samples = sampler.random(n=sample_size)
318
+ samples[:, 1:] = qmc.scale(
319
+ samples[:, 1:], l_bounds=self.min_pts, u_bounds=self.max_pts
320
+ ) # We scale omega domain to be in (min_pts, max_pts)
321
+ return key, jnp.array(samples)
322
+
323
+ def qmc_in_time_omega_border_domain(
324
+ self, key: PRNGKeyArray, sample_size: int | None = None
325
+ ) -> tuple[PRNGKeyArray, Float[Array, "n 1+dim"]] | None:
326
+ """
327
+ For each facet of the border we generate Quasi-MonteCarlo sequences jointy with time.
328
+
329
+ We need to do some type ignore in this function because we have lost
330
+ the type narrowing from post_init, type checkers only narrow at function level and because we cannot narrow a class attribute.
331
+ """
332
+ qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
333
+ sample_size = self.nb if sample_size is None else sample_size
334
+ if sample_size is None:
335
+ return None
336
+ if self.dim == 1:
337
+ key, subkey = jax.random.split(key, 2)
338
+ qmc_seq = qmc_generator(
339
+ d=1, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
340
+ )
341
+ boundary_times = jnp.array(
342
+ qmc_seq.random(self.nb // (2 * self.dim)) # type: ignore
343
+ )
344
+ boundary_times = boundary_times.reshape(-1, 1, 1)
345
+ boundary_times = jnp.repeat(
346
+ boundary_times,
347
+ self.omega_border.shape[-1], # type: ignore
348
+ axis=2,
349
+ )
350
+ return key, make_cartesian_product(
351
+ boundary_times,
352
+ self.omega_border[None, None], # type: ignore
353
+ )
354
+ if self.dim == 2:
355
+ # currently hard-coded the 4 edges for d==2
356
+ # TODO : find a general & efficient way to sample from the border
357
+ # (facets) of the hypercube in general dim.
358
+ key, *subkeys = jax.random.split(key, 5)
359
+ facet_n = sample_size // (2 * self.dim)
360
+
361
+ def generate_qmc_sample(key, min_val, max_val):
362
+ qmc_seq = qmc_generator(
363
+ d=2,
364
+ scramble=True,
365
+ rng=np.random.default_rng(np.uint32(key)),
366
+ )
367
+ u = qmc_seq.random(n=facet_n)
368
+ u[:, 1:2] = qmc.scale(u[:, 1:2], l_bounds=min_val, u_bounds=max_val)
369
+ return jnp.array(u)
370
+
371
+ xmin_sample = generate_qmc_sample(
372
+ subkeys[0], self.min_pts[1], self.max_pts[1]
373
+ ) # [t,x,y]
374
+ xmin = jnp.hstack(
375
+ [
376
+ xmin_sample[:, 0:1],
377
+ self.min_pts[0] * jnp.ones((facet_n, 1)),
378
+ xmin_sample[:, 1:2],
379
+ ]
380
+ )
381
+ xmax_sample = generate_qmc_sample(
382
+ subkeys[1], self.min_pts[1], self.max_pts[1]
383
+ )
384
+ xmax = jnp.hstack(
385
+ [
386
+ xmax_sample[:, 0:1],
387
+ self.max_pts[0] * jnp.ones((facet_n, 1)),
388
+ xmax_sample[:, 1:2],
389
+ ]
390
+ )
391
+ ymin = jnp.hstack(
392
+ [
393
+ generate_qmc_sample(subkeys[2], self.min_pts[0], self.max_pts[0]),
394
+ self.min_pts[1] * jnp.ones((facet_n, 1)),
395
+ ]
396
+ )
397
+ ymax = jnp.hstack(
398
+ [
399
+ generate_qmc_sample(subkeys[3], self.min_pts[0], self.max_pts[0]),
400
+ self.max_pts[1] * jnp.ones((facet_n, 1)),
401
+ ]
402
+ )
403
+ return key, jnp.stack([xmin, xmax, ymin, ymax], axis=-1)
404
+ raise NotImplementedError(
405
+ "Generation of the border of a cube in dimension > 2 is not "
406
+ + f"implemented yet. You are asking for generation in dimension d={self.dim}."
258
407
  )
259
408
 
260
409
  def _get_domain_operands(
261
410
  self,
262
- ) -> tuple[Key, Float[Array, " n 1+dim"], int, int | None, Array | None]:
411
+ ) -> tuple[PRNGKeyArray, Float[Array, " n 1+dim"], int, int | None, Array | None]:
263
412
  return (
264
413
  self.key,
265
414
  self.domain,
@@ -296,7 +445,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
296
445
  # handled above
297
446
  )
298
447
  new = eqx.tree_at(
299
- lambda m: (m.key, m.domain, m.curr_domain_idx),
448
+ lambda m: (m.key, m.domain, m.curr_domain_idx), # type: ignore
300
449
  self,
301
450
  new_attributes,
302
451
  )
@@ -309,7 +458,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
309
458
  def _get_border_operands(
310
459
  self,
311
460
  ) -> tuple[
312
- Key,
461
+ PRNGKeyArray,
313
462
  Float[Array, " nb 1+1 2"] | Float[Array, " (nb//4) 2+1 4"] | None,
314
463
  int,
315
464
  int | None,
@@ -355,7 +504,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
355
504
  # handled above
356
505
  )
357
506
  new = eqx.tree_at(
358
- lambda m: (m.key, m.border, m.curr_border_idx),
507
+ lambda m: (m.key, m.border, m.curr_border_idx), # type: ignore
359
508
  self,
360
509
  new_attributes,
361
510
  )
@@ -372,7 +521,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
372
521
 
373
522
  def _get_initial_operands(
374
523
  self,
375
- ) -> tuple[Key, Float[Array, " ni dim"] | None, int, int | None, None]:
524
+ ) -> tuple[PRNGKeyArray, Float[Array, " ni dim"] | None, int, int | None, None]:
376
525
  return (
377
526
  self.key,
378
527
  self.initial,
@@ -401,7 +550,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
401
550
  # handled above
402
551
  )
403
552
  new = eqx.tree_at(
404
- lambda m: (m.key, m.initial, m.curr_initial_idx),
553
+ lambda m: (m.key, m.initial, m.curr_initial_idx), # type: ignore
405
554
  self,
406
555
  new_attributes,
407
556
  )