jinns 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +146 -520
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_utils.py +78 -159
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -74
  22. jinns/nn/__init__.py +15 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +94 -57
  25. jinns/nn/_mlp.py +50 -25
  26. jinns/nn/_pinn.py +33 -19
  27. jinns/nn/_ppinn.py +70 -34
  28. jinns/nn/_save_load.py +21 -51
  29. jinns/nn/_spinn.py +33 -16
  30. jinns/nn/_spinn_mlp.py +28 -22
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +35 -34
  37. jinns/solver/_rar.py +80 -63
  38. jinns/solver/_solve.py +89 -63
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -0
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns-1.3.0.dist-info/RECORD +0 -44
  51. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  52. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  53. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/__init__.py CHANGED
@@ -1,8 +1,18 @@
1
- import jinns.data
2
- import jinns.loss
3
- import jinns.solver
4
- import jinns.utils
5
- import jinns.experimental
6
- import jinns.parameters
7
- import jinns.plot
1
+ # import jinns.data
2
+ # import jinns.loss
3
+ # import jinns.solver
4
+ # import jinns.utils
5
+ # import jinns.experimental
6
+ # import jinns.parameters
7
+ # import jinns.plot
8
+ from jinns import data as data
9
+ from jinns import loss as loss
10
+ from jinns import solver as solver
11
+ from jinns import utils as utils
12
+ from jinns import experimental as experimental
13
+ from jinns import parameters as parameters
14
+ from jinns import plot as plot
15
+ from jinns import nn as nn
8
16
  from jinns.solver._solve import solve
17
+
18
+ __all__ = ["nn", "solve"]
@@ -0,0 +1,19 @@
1
+ from __future__ import annotations
2
+ import abc
3
+ from typing import Self, TYPE_CHECKING
4
+ import equinox as eqx
5
+
6
+ if TYPE_CHECKING:
7
+ from jinns.utils._types import AnyBatch
8
+
9
+
10
+ class AbstractDataGenerator(eqx.Module):
11
+ """
12
+ Basically just a way to add a get_batch() to an eqx.Module.
13
+ The way to go for correct type hints apparently
14
+ https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
15
+ """
16
+
17
+ @abc.abstractmethod
18
+ def get_batch(self) -> tuple[type[Self], AnyBatch]: # type: ignore
19
+ pass
jinns/data/_Batchs.py CHANGED
@@ -1,23 +1,42 @@
1
+ from typing import TypedDict
1
2
  import equinox as eqx
2
3
  from jaxtyping import Float, Array
3
4
 
4
5
 
6
+ class ObsBatchDict(TypedDict):
7
+ """
8
+ Keys:
9
+ -pinn_in, a mini batch of pinn inputs
10
+ -val, a mini batch of corresponding observations
11
+ -eq_params, a dictionary with entry names found in `params["eq_params"]`
12
+ and values giving the corresponding parameter value for the couple (input,
13
+ value) mentioned before.
14
+
15
+ A TypedDict is the correct way to handle type hints for dict with fixed set of keys
16
+ https://peps.python.org/pep-0589/
17
+ """
18
+
19
+ pinn_in: Float[Array, " obs_batch_size input_dim"]
20
+ val: Float[Array, " obs_batch_size output_dim"]
21
+ eq_params: dict[str, Float[Array, " obs_batch_size 1"]]
22
+
23
+
5
24
  class ODEBatch(eqx.Module):
6
- temporal_batch: Float[Array, "batch_size"]
7
- param_batch_dict: dict = eqx.field(default=None)
8
- obs_batch_dict: dict = eqx.field(default=None)
25
+ temporal_batch: Float[Array, " batch_size"]
26
+ param_batch_dict: dict[str, Array] = eqx.field(default=None)
27
+ obs_batch_dict: ObsBatchDict = eqx.field(default=None)
9
28
 
10
29
 
11
30
  class PDENonStatioBatch(eqx.Module):
12
- domain_batch: Float[Array, "batch_size 1+dimension"]
13
- border_batch: Float[Array, "batch_size dimension n_facets"]
14
- initial_batch: Float[Array, "batch_size dimension"]
15
- param_batch_dict: dict = eqx.field(default=None)
16
- obs_batch_dict: dict = eqx.field(default=None)
31
+ domain_batch: Float[Array, " batch_size 1+dimension"]
32
+ border_batch: Float[Array, " batch_size dimension n_facets"] | None
33
+ initial_batch: Float[Array, " batch_size dimension"] | None
34
+ param_batch_dict: dict[str, Array] = eqx.field(default=None)
35
+ obs_batch_dict: ObsBatchDict = eqx.field(default=None)
17
36
 
18
37
 
19
38
  class PDEStatioBatch(eqx.Module):
20
- domain_batch: Float[Array, "batch_size dimension"]
21
- border_batch: Float[Array, "batch_size dimension n_facets"]
22
- param_batch_dict: dict = eqx.field(default=None)
23
- obs_batch_dict: dict = eqx.field(default=None)
39
+ domain_batch: Float[Array, " batch_size dimension"]
40
+ border_batch: Float[Array, " batch_size dimension n_facets"] | None
41
+ param_batch_dict: dict[str, Array] = eqx.field(default=None)
42
+ obs_batch_dict: ObsBatchDict = eqx.field(default=None)
@@ -0,0 +1,431 @@
1
+ """
2
+ Define the DataGenerators modules
3
+ """
4
+
5
+ from __future__ import (
6
+ annotations,
7
+ ) # https://docs.python.org/3/library/typing.html#constant
8
+ import warnings
9
+ import equinox as eqx
10
+ import jax
11
+ import jax.numpy as jnp
12
+ from jaxtyping import Key, Array, Float
13
+ from jinns.data._Batchs import PDENonStatioBatch
14
+ from jinns.data._utils import (
15
+ make_cartesian_product,
16
+ _check_and_set_rar_parameters,
17
+ _reset_or_increment,
18
+ )
19
+ from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
20
+
21
+
22
+ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
23
+ r"""
24
+ A class implementing data generator object for non stationary partial
25
+ differential equations. Formally, it extends `CubicMeshPDEStatio`
26
+ to include a temporal batch.
27
+
28
+ Parameters
29
+ ----------
30
+ key : Key
31
+ Jax random key to sample new time points and to shuffle batches
32
+ n : int
33
+ The number of total $I\times \Omega$ points that will be divided in
34
+ batches. Batches are made so that each data point is seen only
35
+ once during 1 epoch.
36
+ nb : int | None
37
+ The total number of points in $\partial\Omega$. Can be None if no
38
+ boundary condition is specified.
39
+ ni : int
40
+ The number of total $\Omega$ points at $t=0$ that will be divided in
41
+ batches. Batches are made so that each data point is seen only
42
+ once during 1 epoch.
43
+ domain_batch_size : int | None, default=None
44
+ The size of the batch of randomly selected points among
45
+ the `n` points. If None no mini-batches are used.
46
+ border_batch_size : int | None, default=None
47
+ The size of the batch of points randomly selected
48
+ among the `nb` points. If None, no
49
+ mini-batches are used.
50
+ initial_batch_size : int | None, default=None
51
+ The size of the batch of randomly selected points among
52
+ the `ni` points. If None no
53
+ mini-batches are used.
54
+ dim : int
55
+ An integer. Dimension of $\Omega$ domain.
56
+ min_pts : tuple[tuple[Float, Float], ...]
57
+ A tuple of minimum values of the domain along each dimension. For a sampling
58
+ in `n` dimension, this represents $(x_{1, min}, x_{2,min}, ...,
59
+ x_{n, min})$
60
+ max_pts : tuple[tuple[Float, Float], ...]
61
+ A tuple of maximum values of the domain along each dimension. For a sampling
62
+ in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
63
+ x_{n,max})$
64
+ tmin : float
65
+ The minimum value of the time domain to consider
66
+ tmax : float
67
+ The maximum value of the time domain to consider
68
+ method : str, default="uniform"
69
+ Either `grid` or `uniform`, default is `uniform`.
70
+ The method that generates the `nt` time points. `grid` means
71
+ regularly spaced points over the domain. `uniform` means uniformly
72
+ sampled points over the domain
73
+ rar_parameters : Dict[str, int], default=None
74
+ Defaults to None: do not use Residual Adaptative Resampling.
75
+ Otherwise a dictionary with keys
76
+ - `start_iter`: the iteration at which we start the RAR sampling scheme (we first have a "burn-in" period).
77
+ - `update_every`: the number of gradient steps taken between
78
+ each update of collocation points in the RAR algo.
79
+ - `sample_size`: the size of the sample from which we will select new
80
+ collocation points.
81
+ - `selected_sample_size`: the number of selected
82
+ points from the sample to be added to the current collocation
83
+ points.
84
+ n_start : int, default=None
85
+ Defaults to None. The effective size of n used at start time.
86
+ This value must be
87
+ provided when rar_parameters is not None. Otherwise we set internally
88
+ n_start = n and this is hidden from the user.
89
+ In RAR, n_start
90
+ then corresponds to the initial number of omega points we train the PINN.
91
+ """
92
+
93
+ tmin: Float = eqx.field(kw_only=True)
94
+ tmax: Float = eqx.field(kw_only=True)
95
+ ni: int = eqx.field(kw_only=True, static=True)
96
+ domain_batch_size: int | None = eqx.field(kw_only=True, static=True, default=None)
97
+ initial_batch_size: int | None = eqx.field(kw_only=True, static=True, default=None)
98
+ border_batch_size: int | None = eqx.field(kw_only=True, static=True, default=None)
99
+
100
+ curr_domain_idx: int = eqx.field(init=False)
101
+ curr_initial_idx: int = eqx.field(init=False)
102
+ curr_border_idx: int = eqx.field(init=False)
103
+ domain: Float[Array, " n 1+dim"] = eqx.field(init=False)
104
+ border: Float[Array, " (nb//2) 1+1 2"] | Float[Array, " (nb//4) 2+1 4"] | None = (
105
+ eqx.field(init=False)
106
+ )
107
+ initial: Float[Array, " ni dim"] | None = eqx.field(init=False)
108
+
109
+ def __post_init__(self):
110
+ """
111
+ Note that neither __init__ nor __post_init__ is called when updating a
112
+ Module with eqx.tree_at!
113
+ """
114
+ super().__post_init__() # because __init__ or __post_init__ of Base
115
+ # class is not automatically called
116
+
117
+ if self.method == "grid":
118
+ # NOTE we must redo the sampling with the square root number of samples
119
+ # and then take the cartesian product
120
+ self.n = int(jnp.round(jnp.sqrt(self.n)) ** 2)
121
+ if self.dim == 2:
122
+ # in the case of grid sampling in 2D in dim 2 in non-statio,
123
+ # self.n needs to be a perfect ^4, because there is the
124
+ # cartesian product with time domain which is also present
125
+ perfect_4 = int(jnp.round(self.n**0.25) ** 4)
126
+ if self.n != perfect_4:
127
+ warnings.warn(
128
+ "Grid sampling is requested in dimension 2 in non"
129
+ " stationary setting with a non"
130
+ f" perfect square dataset size (self.n = {self.n})."
131
+ f" Modifying self.n to self.n = {perfect_4}."
132
+ )
133
+ self.n = perfect_4
134
+ self.key, half_domain_times = self.generate_time_data(
135
+ self.key, int(jnp.round(jnp.sqrt(self.n)))
136
+ )
137
+
138
+ self.key, half_domain_omega = self.generate_omega_data(
139
+ self.key, data_size=int(jnp.round(jnp.sqrt(self.n)))
140
+ )
141
+ self.domain = make_cartesian_product(half_domain_times, half_domain_omega)
142
+
143
+ # NOTE
144
+ (
145
+ self.n_start,
146
+ self.p,
147
+ self.rar_iter_from_last_sampling,
148
+ self.rar_iter_nb,
149
+ ) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)
150
+ elif self.method == "uniform":
151
+ self.key, domain_times = self.generate_time_data(self.key, self.n)
152
+ self.domain = jnp.concatenate([domain_times, self.omega], axis=1)
153
+ else:
154
+ raise ValueError(
155
+ f'Bad value for method. Got {self.method}, expected "grid" or "uniform"'
156
+ )
157
+
158
+ if self.domain_batch_size is None:
159
+ self.curr_domain_idx = 0
160
+ else:
161
+ self.curr_domain_idx = self.n + self.domain_batch_size
162
+ # to be sure there is a shuffling at first get_batch()
163
+ if self.nb is not None:
164
+ assert (
165
+ self.omega_border is not None
166
+ ) # this needs to have been instantiated in super.__post_init__()
167
+ # the check below has already been done in super.__post_init__ if
168
+ # dim > 1. Here we retest it in whatever dim
169
+ if self.nb % (2 * self.dim) != 0 or self.nb < 2 * self.dim:
170
+ raise ValueError(
171
+ "number of border point must be"
172
+ " a multiple of 2xd (the # of faces of a d-dimensional cube)"
173
+ )
174
+ # the check below concerns omega_border_batch_size for dim > 1 in
175
+ # super.__post_init__. Here it concerns all dim values since our
176
+ # border_batch is the concatenation or cartesian product with times
177
+ if (
178
+ self.border_batch_size is not None
179
+ and self.nb // (2 * self.dim) < self.border_batch_size
180
+ ):
181
+ raise ValueError(
182
+ "number of points per facets (nb//2*self.dim)"
183
+ " cannot be lower than border batch size"
184
+ )
185
+ self.key, boundary_times = self.generate_time_data(
186
+ self.key, self.nb // (2 * self.dim)
187
+ )
188
+ boundary_times = boundary_times.reshape(-1, 1, 1)
189
+ boundary_times = jnp.repeat(
190
+ boundary_times, self.omega_border.shape[-1], axis=2
191
+ )
192
+ if self.dim == 1:
193
+ self.border = make_cartesian_product(
194
+ boundary_times, self.omega_border[None, None]
195
+ )
196
+ else:
197
+ self.border = jnp.concatenate(
198
+ [boundary_times, self.omega_border], axis=1
199
+ )
200
+ if self.border_batch_size is None:
201
+ self.curr_border_idx = 0
202
+ else:
203
+ self.curr_border_idx = self.nb + self.border_batch_size
204
+ # to be sure there is a shuffling at first get_batch()
205
+
206
+ else:
207
+ self.border = None
208
+ self.border_batch_size = None
209
+ self.curr_border_idx = 0
210
+
211
+ if self.ni is not None:
212
+ perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
213
+ if self.ni != perfect_sq:
214
+ warnings.warn(
215
+ "Grid sampling is requested in dimension 2 with a non"
216
+ f" perfect square dataset size (self.ni = {self.ni})."
217
+ f" Modifying self.ni to self.ni = {perfect_sq}."
218
+ )
219
+ self.ni = perfect_sq
220
+ self.key, self.initial = self.generate_omega_data(
221
+ self.key, data_size=self.ni
222
+ )
223
+
224
+ if self.initial_batch_size is None or self.initial_batch_size == self.ni:
225
+ self.curr_initial_idx = 0
226
+ else:
227
+ self.curr_initial_idx = self.ni + self.initial_batch_size
228
+ # to be sure there is a shuffling at first get_batch()
229
+ else:
230
+ self.initial = None
231
+ self.initial_batch_size = None
232
+
233
+ # the following attributes will not be used anymore
234
+ self.omega = None # type: ignore
235
+ self.omega_border = None
236
+
237
+ def generate_time_data(
238
+ self, key: Key, nt: int
239
+ ) -> tuple[Key, Float[Array, " nt 1"]]:
240
+ """
241
+ Construct a complete set of `nt` time points according to the
242
+ specified `self.method`
243
+ """
244
+ key, subkey = jax.random.split(key, 2)
245
+ if self.method == "grid":
246
+ partial_times = (self.tmax - self.tmin) / nt
247
+ return key, jnp.arange(self.tmin, self.tmax, partial_times)[:, None]
248
+ if self.method == "uniform":
249
+ return key, self.sample_in_time_domain(subkey, nt)
250
+ raise ValueError("Method " + self.method + " is not implemented.")
251
+
252
+ def sample_in_time_domain(self, key: Key, nt: int) -> Float[Array, " nt 1"]:
253
+ return jax.random.uniform(
254
+ key,
255
+ (nt, 1),
256
+ minval=self.tmin,
257
+ maxval=self.tmax,
258
+ )
259
+
260
+ def _get_domain_operands(
261
+ self,
262
+ ) -> tuple[Key, Float[Array, " n 1+dim"], int, int | None, Array | None]:
263
+ return (
264
+ self.key,
265
+ self.domain,
266
+ self.curr_domain_idx,
267
+ self.domain_batch_size,
268
+ self.p,
269
+ )
270
+
271
+ def domain_batch(
272
+ self,
273
+ ) -> tuple[CubicMeshPDENonStatio, Float[Array, " domain_batch_size 1+dim"]]:
274
+ if self.domain_batch_size is None or self.domain_batch_size == self.n:
275
+ # Avoid unnecessary reshuffling
276
+ return self, self.domain
277
+
278
+ bstart = self.curr_domain_idx
279
+ bend = bstart + self.domain_batch_size
280
+
281
+ # Compute the effective number of used collocation points
282
+ if self.rar_parameters is not None:
283
+ n_eff = (
284
+ self.n_start
285
+ + self.rar_iter_nb # type: ignore
286
+ * self.rar_parameters["selected_sample_size"]
287
+ )
288
+ else:
289
+ n_eff = self.n
290
+
291
+ new_attributes = _reset_or_increment(
292
+ bend,
293
+ n_eff,
294
+ self._get_domain_operands(), # type: ignore
295
+ # ignore since the case self.domain_batch_size is None has been
296
+ # handled above
297
+ )
298
+ new = eqx.tree_at(
299
+ lambda m: (m.key, m.domain, m.curr_domain_idx),
300
+ self,
301
+ new_attributes,
302
+ )
303
+ return new, jax.lax.dynamic_slice(
304
+ new.domain,
305
+ start_indices=(new.curr_domain_idx, 0),
306
+ slice_sizes=(new.domain_batch_size, new.dim + 1),
307
+ )
308
+
309
+ def _get_border_operands(
310
+ self,
311
+ ) -> tuple[
312
+ Key,
313
+ Float[Array, " nb 1+1 2"] | Float[Array, " (nb//4) 2+1 4"] | None,
314
+ int,
315
+ int | None,
316
+ None,
317
+ ]:
318
+ return (
319
+ self.key,
320
+ self.border,
321
+ self.curr_border_idx,
322
+ self.border_batch_size,
323
+ None,
324
+ )
325
+
326
+ def border_batch(
327
+ self,
328
+ ) -> tuple[
329
+ CubicMeshPDENonStatio,
330
+ Float[Array, " border_batch_size 1+1 2"]
331
+ | Float[Array, " border_batch_size 2+1 4"]
332
+ | None,
333
+ ]:
334
+ if self.nb is None or self.border is None:
335
+ # Avoid unnecessary reshuffling
336
+ return self, None
337
+
338
+ if (
339
+ self.border_batch_size is None
340
+ or self.border_batch_size == self.nb // 2**self.dim
341
+ ):
342
+ # Avoid unnecessary reshuffling
343
+ return self, self.border
344
+
345
+ bstart = self.curr_border_idx
346
+ bend = bstart + self.border_batch_size
347
+
348
+ n_eff = self.border.shape[0]
349
+
350
+ new_attributes = _reset_or_increment(
351
+ bend,
352
+ n_eff,
353
+ self._get_border_operands(), # type: ignore
354
+ # ignore since the case self.border_batch_size is None has been
355
+ # handled above
356
+ )
357
+ new = eqx.tree_at(
358
+ lambda m: (m.key, m.border, m.curr_border_idx),
359
+ self,
360
+ new_attributes,
361
+ )
362
+
363
+ return new, jax.lax.dynamic_slice(
364
+ new.border,
365
+ start_indices=(new.curr_border_idx, 0, 0),
366
+ slice_sizes=(
367
+ new.border_batch_size,
368
+ new.dim + 1,
369
+ 2 * new.dim,
370
+ ),
371
+ )
372
+
373
+ def _get_initial_operands(
374
+ self,
375
+ ) -> tuple[Key, Float[Array, " ni dim"] | None, int, int | None, None]:
376
+ return (
377
+ self.key,
378
+ self.initial,
379
+ self.curr_initial_idx,
380
+ self.initial_batch_size,
381
+ None,
382
+ )
383
+
384
+ def initial_batch(
385
+ self,
386
+ ) -> tuple[CubicMeshPDENonStatio, Float[Array, " initial_batch_size dim"] | None]:
387
+ if self.initial_batch_size is None or self.initial_batch_size == self.ni:
388
+ # Avoid unnecessary reshuffling
389
+ return self, self.initial
390
+
391
+ bstart = self.curr_initial_idx
392
+ bend = bstart + self.initial_batch_size
393
+
394
+ n_eff = self.ni
395
+
396
+ new_attributes = _reset_or_increment(
397
+ bend,
398
+ n_eff,
399
+ self._get_initial_operands(), # type: ignore
400
+ # ignore since the case self.initial_batch_size is None has been
401
+ # handled above
402
+ )
403
+ new = eqx.tree_at(
404
+ lambda m: (m.key, m.initial, m.curr_initial_idx),
405
+ self,
406
+ new_attributes,
407
+ )
408
+ return new, jax.lax.dynamic_slice(
409
+ new.initial,
410
+ start_indices=(new.curr_initial_idx, 0),
411
+ slice_sizes=(new.initial_batch_size, new.dim),
412
+ )
413
+
414
+ def get_batch(self) -> tuple[CubicMeshPDENonStatio, PDENonStatioBatch]:
415
+ """
416
+ Generic method to return a batch. Here we call `self.domain_batch()`,
417
+ `self.border_batch()` and `self.initial_batch()`
418
+ """
419
+ new, domain = self.domain_batch()
420
+ if self.border is not None:
421
+ new, border = new.border_batch()
422
+ else:
423
+ border = None
424
+ if self.initial is not None:
425
+ new, initial = new.initial_batch()
426
+ else:
427
+ initial = None
428
+
429
+ return new, PDENonStatioBatch(
430
+ domain_batch=domain, border_batch=border, initial_batch=initial
431
+ )