jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +116 -189
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +176 -513
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +22 -21
- jinns/loss/_loss_utils.py +98 -173
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -76
- jinns/nn/__init__.py +22 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +434 -0
- jinns/nn/_mlp.py +217 -0
- jinns/nn/_pinn.py +204 -0
- jinns/nn/_ppinn.py +239 -0
- jinns/{utils → nn}/_save_load.py +39 -53
- jinns/nn/_spinn.py +123 -0
- jinns/nn/_spinn_mlp.py +202 -0
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +38 -37
- jinns/solver/_rar.py +82 -65
- jinns/solver/_solve.py +111 -71
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -5
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns/utils/_hyperpinn.py +0 -420
- jinns/utils/_pinn.py +0 -324
- jinns/utils/_ppinn.py +0 -227
- jinns/utils/_spinn.py +0 -249
- jinns-1.2.0.dist-info/RECORD +0 -41
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/__init__.py
CHANGED
|
@@ -1,8 +1,18 @@
|
|
|
1
|
-
import jinns.data
|
|
2
|
-
import jinns.loss
|
|
3
|
-
import jinns.solver
|
|
4
|
-
import jinns.utils
|
|
5
|
-
import jinns.experimental
|
|
6
|
-
import jinns.parameters
|
|
7
|
-
import jinns.plot
|
|
1
|
+
# import jinns.data
|
|
2
|
+
# import jinns.loss
|
|
3
|
+
# import jinns.solver
|
|
4
|
+
# import jinns.utils
|
|
5
|
+
# import jinns.experimental
|
|
6
|
+
# import jinns.parameters
|
|
7
|
+
# import jinns.plot
|
|
8
|
+
from jinns import data as data
|
|
9
|
+
from jinns import loss as loss
|
|
10
|
+
from jinns import solver as solver
|
|
11
|
+
from jinns import utils as utils
|
|
12
|
+
from jinns import experimental as experimental
|
|
13
|
+
from jinns import parameters as parameters
|
|
14
|
+
from jinns import plot as plot
|
|
15
|
+
from jinns import nn as nn
|
|
8
16
|
from jinns.solver._solve import solve
|
|
17
|
+
|
|
18
|
+
__all__ = ["nn", "solve"]
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import abc
|
|
3
|
+
from typing import Self, TYPE_CHECKING
|
|
4
|
+
import equinox as eqx
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from jinns.utils._types import AnyBatch
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AbstractDataGenerator(eqx.Module):
    """
    Abstract base class for data generators: an `eqx.Module` that is
    required to expose a `get_batch()` method.

    Declaring the method on an abstract Module is the pattern used to obtain
    correct type hints, see
    https://github.com/patrick-kidger/equinox/issues/1002 and
    https://docs.kidger.site/equinox/pattern/
    """

    @abc.abstractmethod
    def get_batch(self) -> tuple[Self, AnyBatch]:  # type: ignore
        """
        Return a data generator together with a batch.

        Returns
        -------
        tuple[Self, AnyBatch]
            A data generator *instance* of the same type as `self` (hence
            `Self` and not `type[Self]`, which would denote the class object)
            and the batch itself.
        """
|
jinns/data/_Batchs.py
CHANGED
|
@@ -1,23 +1,42 @@
|
|
|
1
|
+
from typing import TypedDict
|
|
1
2
|
import equinox as eqx
|
|
2
3
|
from jaxtyping import Float, Array
|
|
3
4
|
|
|
4
5
|
|
|
6
|
+
class ObsBatchDict(TypedDict):
    """
    Dictionary describing a mini batch of observations.

    Keys:
    - `pinn_in`: a mini batch of pinn inputs.
    - `val`: a mini batch of the corresponding observations.
    - `eq_params`: a dictionary whose entry names are found in
      `params["eq_params"]` and whose values give the corresponding
      parameter value for each (input, value) couple mentioned above.

    A TypedDict is the correct way to handle type hints for a dict with a
    fixed set of keys, see https://peps.python.org/pep-0589/
    """

    pinn_in: Float[Array, " obs_batch_size input_dim"]
    val: Float[Array, " obs_batch_size output_dim"]
    eq_params: dict[str, Float[Array, " obs_batch_size 1"]]
|
|
22
|
+
|
|
23
|
+
|
|
5
24
|
class ODEBatch(eqx.Module):
    """
    Batch container holding a temporal batch and, optionally, a batch of
    parameters and a batch of observations.
    """

    # mini batch of time points
    temporal_batch: Float[Array, " batch_size"]
    # optional; defaults to None, hence the `| None` in the annotation
    param_batch_dict: dict[str, Array] | None = eqx.field(default=None)
    # optional; defaults to None, hence the `| None` in the annotation
    obs_batch_dict: ObsBatchDict | None = eqx.field(default=None)
|
|
9
28
|
|
|
10
29
|
|
|
11
30
|
class PDENonStatioBatch(eqx.Module):
    """
    Batch container holding a domain batch, a border batch and an initial
    batch (the last two may be None) and, optionally, a batch of parameters
    and a batch of observations.
    """

    # mini batch of domain points (time is the extra leading coordinate,
    # as per the `1+dimension` shape annotation)
    domain_batch: Float[Array, " batch_size 1+dimension"]
    # mini batch of border points, or None
    border_batch: Float[Array, " batch_size dimension n_facets"] | None
    # mini batch of initial points, or None
    initial_batch: Float[Array, " batch_size dimension"] | None
    # optional; defaults to None, hence the `| None` in the annotation
    param_batch_dict: dict[str, Array] | None = eqx.field(default=None)
    # optional; defaults to None, hence the `| None` in the annotation
    obs_batch_dict: ObsBatchDict | None = eqx.field(default=None)
|
|
17
36
|
|
|
18
37
|
|
|
19
38
|
class PDEStatioBatch(eqx.Module):
    """
    Batch container holding a domain batch and a border batch (the latter
    may be None) and, optionally, a batch of parameters and a batch of
    observations.
    """

    # mini batch of domain points
    domain_batch: Float[Array, " batch_size dimension"]
    # mini batch of border points, or None
    border_batch: Float[Array, " batch_size dimension n_facets"] | None
    # optional; defaults to None, hence the `| None` in the annotation
    param_batch_dict: dict[str, Array] | None = eqx.field(default=None)
    # optional; defaults to None, hence the `| None` in the annotation
    obs_batch_dict: ObsBatchDict | None = eqx.field(default=None)
|
|
@@ -0,0 +1,431 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Define the DataGenerators modules
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import (
|
|
6
|
+
annotations,
|
|
7
|
+
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
|
+
import warnings
|
|
9
|
+
import equinox as eqx
|
|
10
|
+
import jax
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
from jaxtyping import Key, Array, Float
|
|
13
|
+
from jinns.data._Batchs import PDENonStatioBatch
|
|
14
|
+
from jinns.data._utils import (
|
|
15
|
+
make_cartesian_product,
|
|
16
|
+
_check_and_set_rar_parameters,
|
|
17
|
+
_reset_or_increment,
|
|
18
|
+
)
|
|
19
|
+
from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class CubicMeshPDENonStatio(CubicMeshPDEStatio):
    r"""
    A class implementing data generator object for non stationary partial
    differential equations. Formally, it extends `CubicMeshPDEStatio`
    to include a temporal batch.

    Parameters
    ----------
    key : Key
        Jax random key to sample new time points and to shuffle batches
    n : int
        The number of total $I\times \Omega$ points that will be divided in
        batches. Batches are made so that each data point is seen only
        once during 1 epoch.
    nb : int | None
        The total number of points in $\partial\Omega$. Can be None if no
        boundary condition is specified.
    ni : int | None
        The number of total $\Omega$ points at $t=0$ that will be divided in
        batches. Batches are made so that each data point is seen only
        once during 1 epoch. If None, no initial points are generated
        (`initial` is set to None).
    domain_batch_size : int | None, default=None
        The size of the batch of randomly selected points among
        the `n` points. If None no mini-batches are used.
    border_batch_size : int | None, default=None
        The size of the batch of points randomly selected
        among the `nb` points. If None, no mini-batches are used.
    initial_batch_size : int | None, default=None
        The size of the batch of randomly selected points among
        the `ni` points. If None no
        mini-batches are used.
    dim : int
        An integer. Dimension of $\Omega$ domain.
    min_pts : tuple[tuple[Float, Float], ...]
        A tuple of minimum values of the domain along each dimension. For a sampling
        in `n` dimension, this represents $(x_{1, min}, x_{2,min}, ...,
        x_{n, min})$
    max_pts : tuple[tuple[Float, Float], ...]
        A tuple of maximum values of the domain along each dimension. For a sampling
        in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
        x_{n,max})$
    tmin : float
        The minimum value of the time domain to consider
    tmax : float
        The maximum value of the time domain to consider
    method : str, default="uniform"
        Either `grid` or `uniform`, default is `uniform`.
        The method that generates the `nt` time points. `grid` means
        regularly spaced points over the domain. `uniform` means uniformly
        sampled points over the domain
    rar_parameters : Dict[str, int], default=None
        Defaults to None: do not use Residual Adaptive Resampling.
        Otherwise a dictionary with keys

        - `start_iter`: the iteration at which we start the RAR sampling scheme (we first have a "burn-in" period).
        - `update_every`: the number of gradient steps taken between
        each update of collocation points in the RAR algo.
        - `sample_size`: the size of the sample from which we will select new
        collocation points.
        - `selected_sample_size`: the number of selected
        points from the sample to be added to the current collocation
        points.
    n_start : int, default=None
        Defaults to None. The effective size of n used at start time.
        This value must be
        provided when rar_parameters is not None. Otherwise we set internally
        n_start = n and this is hidden from the user.
        In RAR, n_start
        then corresponds to the initial number of omega points we train the PINN.
    """

    tmin: Float = eqx.field(kw_only=True)
    tmax: Float = eqx.field(kw_only=True)
    # `None` is explicitly handled in __post_init__, hence `int | None`
    ni: int | None = eqx.field(kw_only=True, static=True)
    domain_batch_size: int | None = eqx.field(kw_only=True, static=True, default=None)
    initial_batch_size: int | None = eqx.field(kw_only=True, static=True, default=None)
    border_batch_size: int | None = eqx.field(kw_only=True, static=True, default=None)

    # The fields below are not constructor arguments: they are all assigned
    # in __post_init__.
    curr_domain_idx: int = eqx.field(init=False)
    curr_initial_idx: int = eqx.field(init=False)
    curr_border_idx: int = eqx.field(init=False)
    domain: Float[Array, " n 1+dim"] = eqx.field(init=False)
    border: Float[Array, " (nb//2) 1+1 2"] | Float[Array, " (nb//4) 2+1 4"] | None = (
        eqx.field(init=False)
    )
    initial: Float[Array, " ni dim"] | None = eqx.field(init=False)

    def __post_init__(self):
        """
        Build the `domain`, `border` and `initial` point sets and initialise
        the three batch cursors (`curr_*_idx`).

        Note that neither __init__ nor __post_init__ is called when updating
        a Module with eqx.tree_at!

        Raises
        ------
        ValueError
            If `method` is neither "grid" nor "uniform", if `nb` is not a
            positive multiple of `2 * dim`, or if the number of points per
            facet is lower than `border_batch_size`.
        """
        # __init__ / __post_init__ of the base class is not called
        # automatically
        super().__post_init__()

        # ---- domain points: (time, omega) ----
        if self.method == "grid":
            # NOTE we must redo the sampling with the square root number of
            # samples and then take the cartesian product
            self.n = int(jnp.round(jnp.sqrt(self.n)) ** 2)
            if self.dim == 2:
                # in the case of grid sampling in dim 2 in non-statio,
                # self.n needs to be a perfect ^4, because there is the
                # cartesian product with time domain which is also present
                perfect_4 = int(jnp.round(self.n**0.25) ** 4)
                if self.n != perfect_4:
                    warnings.warn(
                        "Grid sampling is requested in dimension 2 in non"
                        " stationary setting with a non"
                        f" perfect square dataset size (self.n = {self.n})."
                        f" Modifying self.n to self.n = {perfect_4}."
                    )
                self.n = perfect_4
            n_per_axis = int(jnp.round(jnp.sqrt(self.n)))
            self.key, half_domain_times = self.generate_time_data(
                self.key, n_per_axis
            )
            self.key, half_domain_omega = self.generate_omega_data(
                self.key, data_size=n_per_axis
            )
            self.domain = make_cartesian_product(half_domain_times, half_domain_omega)

            # NOTE self.n may have been modified above, so the RAR related
            # attributes are recomputed from the updated value
            (
                self.n_start,
                self.p,
                self.rar_iter_from_last_sampling,
                self.rar_iter_nb,
            ) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)
        elif self.method == "uniform":
            self.key, domain_times = self.generate_time_data(self.key, self.n)
            self.domain = jnp.concatenate([domain_times, self.omega], axis=1)
        else:
            raise ValueError(
                f'Bad value for method. Got {self.method}, expected "grid" or "uniform"'
            )

        if self.domain_batch_size is None:
            self.curr_domain_idx = 0
        else:
            # to be sure there is a shuffling at first get_batch()
            self.curr_domain_idx = self.n + self.domain_batch_size

        # ---- border points ----
        if self.nb is not None:
            assert (
                self.omega_border is not None
            )  # this needs to have been instantiated in super.__post_init__()
            n_facets = 2 * self.dim
            # the check below has already been done in super.__post_init__ if
            # dim > 1. Here we retest it in whatever dim
            if self.nb % n_facets != 0 or self.nb < n_facets:
                raise ValueError(
                    "number of border point must be"
                    " a multiple of 2xd (the # of faces of a d-dimensional cube)"
                )
            # the check below concerns omega_border_batch_size for dim > 1 in
            # super.__post_init__. Here it concerns all dim values since our
            # border_batch is the concatenation or cartesian product with times
            if (
                self.border_batch_size is not None
                and self.nb // n_facets < self.border_batch_size
            ):
                raise ValueError(
                    "number of points per facets (nb//2*self.dim)"
                    " cannot be lower than border batch size"
                )
            self.key, boundary_times = self.generate_time_data(
                self.key, self.nb // n_facets
            )
            # one copy of the time points along the last axis of omega_border
            boundary_times = boundary_times.reshape(-1, 1, 1)
            boundary_times = jnp.repeat(
                boundary_times, self.omega_border.shape[-1], axis=2
            )
            if self.dim == 1:
                self.border = make_cartesian_product(
                    boundary_times, self.omega_border[None, None]
                )
            else:
                self.border = jnp.concatenate(
                    [boundary_times, self.omega_border], axis=1
                )
            if self.border_batch_size is None:
                self.curr_border_idx = 0
            else:
                # to be sure there is a shuffling at first get_batch()
                self.curr_border_idx = self.nb + self.border_batch_size
        else:
            self.border = None
            self.border_batch_size = None
            self.curr_border_idx = 0

        # ---- initial points ----
        if self.ni is not None:
            # NOTE(review): this rounding (and the wording of the warning) is
            # applied whatever `method` and `dim` are -- confirm it is intended
            perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
            if self.ni != perfect_sq:
                warnings.warn(
                    "Grid sampling is requested in dimension 2 with a non"
                    f" perfect square dataset size (self.ni = {self.ni})."
                    f" Modifying self.ni to self.ni = {perfect_sq}."
                )
            self.ni = perfect_sq
            self.key, self.initial = self.generate_omega_data(
                self.key, data_size=self.ni
            )

            if self.initial_batch_size is None or self.initial_batch_size == self.ni:
                self.curr_initial_idx = 0
            else:
                # to be sure there is a shuffling at first get_batch()
                self.curr_initial_idx = self.ni + self.initial_batch_size
        else:
            self.initial = None
            self.initial_batch_size = None
            # Mirror the border branch: `curr_initial_idx` is a declared field
            # and was previously left unassigned on this path.
            self.curr_initial_idx = 0

        # the following attributes will not be used anymore
        self.omega = None  # type: ignore
        self.omega_border = None

    def generate_time_data(
        self, key: Key, nt: int
    ) -> tuple[Key, Float[Array, " nt 1"]]:
        """
        Construct a complete set of `nt` time points according to the
        specified `self.method`

        Returns the new random key together with the time points as a column.

        Raises
        ------
        ValueError
            If `self.method` is neither "grid" nor "uniform".
        """
        key, subkey = jax.random.split(key, 2)
        if self.method == "grid":
            # Regularly spaced points tmin, tmin + dt, ... (tmax excluded).
            # linspace guarantees exactly `nt` points, whereas arange with a
            # floating point step may return one extra point due to rounding.
            times = jnp.linspace(self.tmin, self.tmax, nt, endpoint=False)
            return key, times[:, None]
        if self.method == "uniform":
            return key, self.sample_in_time_domain(subkey, nt)
        raise ValueError("Method " + self.method + " is not implemented.")

    def sample_in_time_domain(self, key: Key, nt: int) -> Float[Array, " nt 1"]:
        """
        Draw `nt` time points uniformly between `self.tmin` and `self.tmax`,
        returned as a column.
        """
        return jax.random.uniform(
            key,
            (nt, 1),
            minval=self.tmin,
            maxval=self.tmax,
        )

    def _get_domain_operands(
        self,
    ) -> tuple[Key, Float[Array, " n 1+dim"], int, int | None, Array | None]:
        """
        Gather the operands passed to `_reset_or_increment` for the domain
        points: key, points, current index, batch size and `self.p`.
        """
        return (
            self.key,
            self.domain,
            self.curr_domain_idx,
            self.domain_batch_size,
            self.p,
        )

    def domain_batch(
        self,
    ) -> tuple[CubicMeshPDENonStatio, Float[Array, " domain_batch_size 1+dim"]]:
        """
        Return the updated data generator and the next batch of domain
        points. The full `self.domain` is returned, with `self` unchanged,
        when no mini-batching is requested.
        """
        if self.domain_batch_size is None or self.domain_batch_size == self.n:
            # Avoid unnecessary reshuffling
            return self, self.domain

        bend = self.curr_domain_idx + self.domain_batch_size

        # Compute the effective number of used collocation points
        if self.rar_parameters is not None:
            n_eff = (
                self.n_start
                + self.rar_iter_nb  # type: ignore
                * self.rar_parameters["selected_sample_size"]
            )
        else:
            n_eff = self.n

        new_attributes = _reset_or_increment(
            bend,
            n_eff,
            self._get_domain_operands(),  # type: ignore
            # ignore since the case self.domain_batch_size is None has been
            # handled above
        )
        # Out-of-place update of (key, domain, curr_domain_idx)
        new = eqx.tree_at(
            lambda m: (m.key, m.domain, m.curr_domain_idx),
            self,
            new_attributes,
        )
        return new, jax.lax.dynamic_slice(
            new.domain,
            start_indices=(new.curr_domain_idx, 0),
            slice_sizes=(new.domain_batch_size, new.dim + 1),
        )

    def _get_border_operands(
        self,
    ) -> tuple[
        Key,
        Float[Array, " (nb//2) 1+1 2"] | Float[Array, " (nb//4) 2+1 4"] | None,
        int,
        int | None,
        None,
    ]:
        """
        Gather the operands passed to `_reset_or_increment` for the border
        points: key, points, current index, batch size and None (no `p`).
        """
        return (
            self.key,
            self.border,
            self.curr_border_idx,
            self.border_batch_size,
            None,
        )

    def border_batch(
        self,
    ) -> tuple[
        CubicMeshPDENonStatio,
        Float[Array, " border_batch_size 1+1 2"]
        | Float[Array, " border_batch_size 2+1 4"]
        | None,
    ]:
        """
        Return the updated data generator and the next batch of border
        points. The batch is None when there are no border points, and the
        full `self.border` is returned, with `self` unchanged, when no
        mini-batching is requested.
        """
        if self.nb is None or self.border is None:
            # Avoid unnecessary reshuffling
            return self, None

        # A d-dimensional cube has 2 * d facets; this is the number of points
        # per facet, consistent with the checks made in __post_init__
        if (
            self.border_batch_size is None
            or self.border_batch_size == self.nb // (2 * self.dim)
        ):
            # Avoid unnecessary reshuffling
            return self, self.border

        bend = self.curr_border_idx + self.border_batch_size
        n_eff = self.border.shape[0]

        new_attributes = _reset_or_increment(
            bend,
            n_eff,
            self._get_border_operands(),  # type: ignore
            # ignore since the case self.border_batch_size is None has been
            # handled above
        )
        # Out-of-place update of (key, border, curr_border_idx)
        new = eqx.tree_at(
            lambda m: (m.key, m.border, m.curr_border_idx),
            self,
            new_attributes,
        )

        return new, jax.lax.dynamic_slice(
            new.border,
            start_indices=(new.curr_border_idx, 0, 0),
            slice_sizes=(
                new.border_batch_size,
                new.dim + 1,
                2 * new.dim,
            ),
        )

    def _get_initial_operands(
        self,
    ) -> tuple[Key, Float[Array, " ni dim"] | None, int, int | None, None]:
        """
        Gather the operands passed to `_reset_or_increment` for the initial
        points: key, points, current index, batch size and None (no `p`).
        """
        return (
            self.key,
            self.initial,
            self.curr_initial_idx,
            self.initial_batch_size,
            None,
        )

    def initial_batch(
        self,
    ) -> tuple[CubicMeshPDENonStatio, Float[Array, " initial_batch_size dim"] | None]:
        """
        Return the updated data generator and the next batch of initial
        points. `self.initial` (possibly None) is returned, with `self`
        unchanged, when no mini-batching is requested.
        """
        if self.initial_batch_size is None or self.initial_batch_size == self.ni:
            # Avoid unnecessary reshuffling
            return self, self.initial

        bend = self.curr_initial_idx + self.initial_batch_size
        n_eff = self.ni

        new_attributes = _reset_or_increment(
            bend,
            n_eff,
            self._get_initial_operands(),  # type: ignore
            # ignore since the case self.initial_batch_size is None has been
            # handled above
        )
        # Out-of-place update of (key, initial, curr_initial_idx)
        new = eqx.tree_at(
            lambda m: (m.key, m.initial, m.curr_initial_idx),
            self,
            new_attributes,
        )
        return new, jax.lax.dynamic_slice(
            new.initial,
            start_indices=(new.curr_initial_idx, 0),
            slice_sizes=(new.initial_batch_size, new.dim),
        )

    def get_batch(self) -> tuple[CubicMeshPDENonStatio, PDENonStatioBatch]:
        """
        Generic method to return a batch. Here we call `self.domain_batch()`,
        `self.border_batch()` and `self.initial_batch()`

        The generator returned alongside the batch carries the updates made
        by each of these calls; it is the one to use for the next call.
        """
        new, domain = self.domain_batch()
        if self.border is not None:
            new, border = new.border_batch()
        else:
            border = None
        if self.initial is not None:
            new, initial = new.initial_batch()
        else:
            initial = None

        return new, PDENonStatioBatch(
            domain_batch=domain, border_batch=border, initial_batch=initial
        )
|