jinns 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +74 -69
- jinns/loss/_LossODE.py +132 -348
- jinns/loss/_LossPDE.py +262 -549
- jinns/loss/__init__.py +32 -6
- jinns/loss/_abstract_loss.py +128 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_components.py +43 -0
- jinns/loss/_loss_utils.py +85 -179
- jinns/loss/_loss_weight_updates.py +202 -0
- jinns/loss/_loss_weights.py +64 -40
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +207 -92
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +16 -10
- jinns/utils/_types.py +20 -54
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
- jinns-1.5.0.dist-info/RECORD +55 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,464 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Define the DataGenerators modules
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import (
|
|
6
|
+
annotations,
|
|
7
|
+
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
|
+
import warnings
|
|
9
|
+
import equinox as eqx
|
|
10
|
+
import jax
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
from jaxtyping import Key, Array, Float
|
|
13
|
+
from jinns.data._Batchs import PDEStatioBatch
|
|
14
|
+
from jinns.data._utils import _check_and_set_rar_parameters, _reset_or_increment
|
|
15
|
+
from jinns.data._AbstractDataGenerator import AbstractDataGenerator
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CubicMeshPDEStatio(AbstractDataGenerator):
    r"""
    A class implementing data generator object for stationary partial
    differential equations.

    Parameters
    ----------
    key : Key
        Jax random key to sample new points and to shuffle batches
    n : int
        The number of total $\Omega$ points that will be divided in
        batches. Batches are made so that each data point is seen only
        once during 1 epoch.
    nb : int | None
        The total number of points in $\partial\Omega$. Can be None if no
        boundary condition is specified.
    omega_batch_size : int | None, default=None
        The size of the batch of randomly selected points among
        the `n` points. If None no minibatches are used.
    omega_border_batch_size : int | None, default=None
        The size of the batch of points randomly selected
        among the `nb` points. If None, no minibatches are used.
        In dimension 1, minibatches are never used since the boundary is
        composed of two singletons.
    dim : int
        Dimension of $\Omega$ domain
    min_pts : tuple[float, ...]
        A tuple of minimum values of the domain along each dimension. For a sampling
        in `n` dimension, this represents $(x_{1, min}, x_{2,min}, ...,
        x_{n, min})$
    max_pts : tuple[float, ...]
        A tuple of maximum values of the domain along each dimension. For a sampling
        in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
        x_{n,max})$
    method : str, default="uniform"
        Either `grid` or `uniform`, default is `uniform`.
        The method that generates the `n` points of $\Omega$. `grid` means
        regularly spaced points over the domain. `uniform` means uniformly
        sampled points over the domain
    rar_parameters : dict[str, int], default=None
        Defaults to None: do not use Residual Adaptative Resampling.
        Otherwise a dictionary with keys

        - `start_iter`: the iteration at which we start the RAR sampling scheme (we first have a "burn-in" period).
        - `update_every`: the number of gradient steps taken between
          each update of collocation points in the RAR algo.
        - `sample_size`: the size of the sample from which we will select new
          collocation points.
        - `selected_sample_size`: the number of selected
          points from the sample to be added to the current collocation
          points.
    n_start : int, default=None
        Defaults to None. The effective size of n used at start time.
        This value must be
        provided when rar_parameters is not None. Otherwise we set internally
        n_start = n and this is hidden from the user.
        In RAR, n_start
        then corresponds to the initial number of points we train the PINN on.
    """

    # kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
    key: Key = eqx.field(kw_only=True)
    n: int = eqx.field(kw_only=True, static=True)
    nb: int | None = eqx.field(kw_only=True, static=True, default=None)
    # omega_batch_size can be None since CubicMeshPDENonStatio inherits from
    # this class, but also when omega_batch_size == n. It is static because it
    # is used as a shape in jax.lax.dynamic_slice.
    omega_batch_size: int | None = eqx.field(
        kw_only=True,
        static=True,
        default=None,
    )
    # static because used as a shape in jax.lax.dynamic_slice
    omega_border_batch_size: int | None = eqx.field(
        kw_only=True, static=True, default=None
    )
    # static because used as a shape in jax.lax.dynamic_slice
    dim: int = eqx.field(kw_only=True, static=True)
    min_pts: tuple[float, ...] = eqx.field(kw_only=True)
    max_pts: tuple[float, ...] = eqx.field(kw_only=True)
    method: str = eqx.field(
        kw_only=True, static=True, default_factory=lambda: "uniform"
    )
    rar_parameters: dict[str, int] | None = eqx.field(kw_only=True, default=None)
    n_start: int | None = eqx.field(kw_only=True, default=None, static=True)

    # all the init=False fields are set in __post_init__
    p: Float[Array, " n"] | None = eqx.field(init=False)
    rar_iter_from_last_sampling: int | None = eqx.field(init=False)
    rar_iter_nb: int | None = eqx.field(init=False)
    curr_omega_idx: int = eqx.field(init=False)
    curr_omega_border_idx: int = eqx.field(init=False)
    omega: Float[Array, " n dim"] = eqx.field(init=False)
    omega_border: Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None = (
        eqx.field(init=False)
    )

    def __post_init__(self):
        """
        Validate the user inputs, initialise the RAR bookkeeping and the batch
        cursors, then generate the full sets of interior and border points.

        Raises
        ------
        AssertionError
            If `min_pts` / `max_pts` are not tuples of length `dim`.
        ValueError
            If `nb` is not a positive multiple of `2 * dim` (for `dim` > 1) or
            if the border batch size exceeds the number of points per facet.
        """
        assert self.dim == len(self.min_pts) and isinstance(self.min_pts, tuple)
        assert self.dim == len(self.max_pts) and isinstance(self.max_pts, tuple)

        # The grid adjustment of `n` must happen BEFORE the RAR initialisation:
        # the latter receives `self.n`, so it must see the final number of
        # rows of `self.omega`, not the value requested by the user.
        if self.method == "grid" and self.dim == 2:
            perfect_sq = int(jnp.round(jnp.sqrt(self.n)) ** 2)
            if self.n != perfect_sq:
                warnings.warn(
                    "Grid sampling is requested in dimension 2 with a non"
                    f" perfect square dataset size (self.n = {self.n})."
                    f" Modifying self.n to self.n = {perfect_sq}."
                )
            self.n = perfect_sq

        (
            self.n_start,
            self.p,
            self.rar_iter_from_last_sampling,
            self.rar_iter_nb,
        ) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)

        if self.omega_batch_size is None:
            self.curr_omega_idx = 0
        else:
            # start past the end to be sure there is a shuffling at first
            # get_batch()
            self.curr_omega_idx = self.n + self.omega_batch_size

        if self.nb is not None:
            if self.dim == 1:
                # We are in 1-D case => omega_border_batch_size is
                # ignored since borders of Omega are singletons.
                # self.border_batch() will return [xmin, xmax]
                self.omega_border_batch_size = None
            else:
                if self.nb % (2 * self.dim) != 0 or self.nb < 2 * self.dim:
                    raise ValueError(
                        f"number of border point must be"
                        f" a multiple of 2xd = {2 * self.dim} (the # of faces of"
                        f" a d-dimensional cube). Got {self.nb=}."
                    )
                if (
                    self.omega_border_batch_size is not None
                    and self.nb // (2 * self.dim) < self.omega_border_batch_size
                ):
                    raise ValueError(
                        f"number of points per facets ({self.nb // (2 * self.dim)})"
                        f" cannot be lower than border batch size "
                        f" ({self.omega_border_batch_size})."
                    )
                self.nb = int((2 * self.dim) * (self.nb // (2 * self.dim)))

            if self.omega_border_batch_size is None:
                self.curr_omega_border_idx = 0
            else:
                # start past the end to be sure there is a shuffling at first
                # get_batch()
                self.curr_omega_border_idx = self.nb + self.omega_border_batch_size
        else:  # self.nb is None
            self.curr_omega_border_idx = 0

        self.key, self.omega = self.generate_omega_data(self.key)
        self.key, self.omega_border = self.generate_omega_border_data(self.key)

    def sample_in_omega_domain(
        self, keys: Key, sample_size: int
    ) -> Float[Array, " n dim"]:
        r"""
        Draw `sample_size` points uniformly in $\Omega$.

        Parameters
        ----------
        keys
            A single Jax random key when `self.dim` is 1, otherwise a sequence
            of at least `self.dim` keys (one per coordinate).
        sample_size
            The number of points to draw.

        Returns
        -------
        An array of shape `(sample_size, self.dim)`.
        """
        if self.dim == 1:
            return jax.random.uniform(
                keys,
                shape=(sample_size, 1),
                minval=self.min_pts[0],
                maxval=self.max_pts[0],
            )
        # one independent uniform draw per coordinate, stacked as columns
        columns = [
            jax.random.uniform(
                keys[i],
                (sample_size, 1),
                minval=self.min_pts[i],
                maxval=self.max_pts[i],
            )
            for i in range(self.dim)
        ]
        return jnp.concatenate(columns, axis=-1)

    def sample_in_omega_border_domain(
        self, keys: Key, sample_size: int | None = None
    ) -> Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None:
        r"""
        Draw points on $\partial\Omega$.

        Parameters
        ----------
        keys
            Unused when `self.dim` is 1. A sequence of 4 Jax random keys (one
            per edge) when `self.dim` is 2.
        sample_size
            Total number of border points; defaults to `self.nb`.

        Returns
        -------
        - `None` if neither `sample_size` nor `self.nb` is given.
        - The two end points `[xmin, xmax]` if `self.dim` is 1.
        - An array of shape `(sample_size // 4, 2, 4)` if `self.dim` is 2,
          the last axis indexing the edges in the order
          (xmin, xmax, ymin, ymax).

        Raises
        ------
        NotImplementedError
            If `self.dim` > 2.
        """
        sample_size = self.nb if sample_size is None else sample_size
        if sample_size is None:
            return None
        if self.dim == 1:
            xmin = self.min_pts[0]
            xmax = self.max_pts[0]
            return jnp.array([xmin, xmax]).astype(float)
        if self.dim == 2:
            # currently hard-coded the 4 edges for d==2
            # TODO : find a general & efficient way to sample from the border
            # (facets) of the hypercube in general dim.
            facet_n = sample_size // (2 * self.dim)
            x_lo, y_lo = self.min_pts
            x_hi, y_hi = self.max_pts

            def _pinned(value):
                # coordinate held constant along an edge
                return value * jnp.ones((facet_n, 1))

            def _free(edge_key, lower, upper):
                # coordinate sampled uniformly along an edge
                return jax.random.uniform(
                    edge_key, (facet_n, 1), minval=lower, maxval=upper
                )

            xmin = jnp.hstack([_pinned(x_lo), _free(keys[0], y_lo, y_hi)])
            xmax = jnp.hstack([_pinned(x_hi), _free(keys[1], y_lo, y_hi)])
            ymin = jnp.hstack([_free(keys[2], x_lo, x_hi), _pinned(y_lo)])
            ymax = jnp.hstack([_free(keys[3], x_lo, x_hi), _pinned(y_hi)])
            return jnp.stack([xmin, xmax, ymin, ymax], axis=-1)
        raise NotImplementedError(
            "Generation of the border of a cube in dimension > 2 is not "
            + f"implemented yet. You are asking for generation in dimension d={self.dim}."
        )

    def generate_omega_data(
        self, key: Key, data_size: int | None = None
    ) -> tuple[
        Key,
        Float[Array, " n dim"],
    ]:
        r"""
        Construct a complete set of `self.n` $\Omega$ points according to the
        specified `self.method`.

        Parameters
        ----------
        key
            Jax random key. It is split (and the new key returned) only for
            the `uniform` method.
        data_size
            Number of points to generate; defaults to `self.n`. With the
            `grid` method in dimension d > 1 it must be a perfect d-th power.

        Returns
        -------
        The (possibly updated) key and an array of shape `(data_size, dim)`.

        Raises
        ------
        ValueError
            If `self.method` is neither `grid` nor `uniform`.
        """
        data_size = self.n if data_size is None else data_size
        # Generate Omega
        if self.method == "grid":
            if self.dim == 1:
                xmin, xmax = self.min_pts[0], self.max_pts[0]
                ## shape (n, 1)
                omega = jnp.linspace(xmin, xmax, data_size)[:, None]
            else:
                # A regular grid with `data_size` nodes in dimension d has
                # data_size ** (1 / d) nodes per axis (sqrt only when d == 2).
                pts_per_axis = int(round(data_size ** (1 / self.dim)))
                xyz_ = jnp.meshgrid(
                    *[
                        jnp.linspace(self.min_pts[i], self.max_pts[i], pts_per_axis)
                        for i in range(self.dim)
                    ]
                )
                # fails (as before) if data_size is not a perfect d-th power
                xyz_ = [a.reshape((data_size, 1)) for a in xyz_]
                omega = jnp.concatenate(xyz_, axis=-1)
        elif self.method == "uniform":
            if self.dim == 1:
                key, subkeys = jax.random.split(key, 2)
            else:
                key, *subkeys = jax.random.split(key, self.dim + 1)
            omega = self.sample_in_omega_domain(subkeys, sample_size=data_size)
        else:
            raise ValueError("Method " + self.method + " is not implemented.")
        return key, omega

    def generate_omega_border_data(
        self, key: Key, data_size: int | None = None
    ) -> tuple[
        Key,
        Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None,
    ]:
        r"""
        Construct a complete set of $\partial\Omega$ points.

        `data_size` defaults to `self.nb`; when both are `None`, `None` is
        returned instead of an array. The key is only consumed (split) in
        dimension 2, the only case where border points are random.
        """
        # Generate border of omega
        data_size = self.nb if data_size is None else data_size
        if self.dim == 2:
            # one key per edge of the square
            key, *subkeys = jax.random.split(key, 5)
        else:
            subkeys = None
        omega_border = self.sample_in_omega_border_domain(
            subkeys, sample_size=data_size
        )

        return key, omega_border

    def _get_omega_operands(
        self,
    ) -> tuple[Key, Float[Array, " n dim"], int, int | None, Float[Array, " n"] | None]:
        """
        Pack the interior-point state handed to `_reset_or_increment`:
        (key, points, current index, batch size, sampling probabilities).
        """
        return (
            self.key,
            self.omega,
            self.curr_omega_idx,
            self.omega_batch_size,
            self.p,
        )

    def inside_batch(
        self,
    ) -> tuple[CubicMeshPDEStatio, Float[Array, " omega_batch_size dim"]]:
        r"""
        Return a batch of points in $\Omega$.
        If all the batches have been seen, we reshuffle them,
        otherwise we just return the next unseen batch.

        Returns
        -------
        The updated data generator and the batch. When no minibatching is
        requested the generator is returned unchanged with all of
        `self.omega`.
        """
        if self.omega_batch_size is None or self.omega_batch_size == self.n:
            # Avoid unnecessary reshuffling
            return self, self.omega

        # Compute the effective number of used collocation points
        if self.rar_parameters is not None:
            n_eff = (
                self.n_start
                + self.rar_iter_nb  # type: ignore
                * self.rar_parameters["selected_sample_size"]
            )
        else:
            n_eff = self.n

        bstart = self.curr_omega_idx
        bend = bstart + self.omega_batch_size

        new_attributes = _reset_or_increment(
            bend,
            n_eff,
            self._get_omega_operands(),  # type: ignore
            # ignore since the case self.omega_batch_size is None has been
            # handled above
        )
        new = eqx.tree_at(
            lambda m: (m.key, m.omega, m.curr_omega_idx), self, new_attributes
        )

        # start index can be dynamic but the slice shape is fixed (static)
        return new, jax.lax.dynamic_slice(
            new.omega,
            start_indices=(new.curr_omega_idx, 0),
            slice_sizes=(new.omega_batch_size, new.dim),
        )

    def _get_omega_border_operands(
        self,
    ) -> tuple[
        Key,
        Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None,
        int,
        int | None,
        None,
    ]:
        """
        Pack the border-point state handed to `_reset_or_increment`:
        (key, points, current index, batch size, None). The last entry is
        `None` because no sampling probabilities are kept for the border.
        """
        return (
            self.key,
            self.omega_border,
            self.curr_omega_border_idx,
            self.omega_border_batch_size,
            None,
        )

    def border_batch(
        self,
    ) -> tuple[
        CubicMeshPDEStatio,
        Float[Array, " 1 1 2"] | Float[Array, " omega_border_batch_size 2 4"] | None,
    ]:
        r"""
        Return

        - The value `None` if `self.nb` is `None` (no border points).

        - a jnp array with two fixed values $(x_{min}, x_{max})$ if
          `self.dim` = 1. There is no sampling here, we return the entire
          $\partial\Omega$

        - a batch of points in $\partial\Omega$ otherwise, stacked by
          facet on the last axis.
          If all the batches have been seen, we reshuffle them,
          otherwise we just return the next unseen batch.


        """
        if self.nb is None or self.omega_border is None:
            # Avoid unnecessary reshuffling
            return self, None

        if self.dim == 1:
            # Avoid unnecessary reshuffling
            # 1-D case, no randomness : we always return the whole omega border,
            # i.e. (1, 1, 2) shape jnp.array([[[xmin], [xmax]]]).
            return self, self.omega_border[None, None]  # shape is (1, 1, 2)

        # A d-dimensional cube has 2 * d facets (not 2 ** d; both equal 4 only
        # when d == 2), consistently with __post_init__.
        pts_per_facet = self.nb // (2 * self.dim)
        if (
            self.omega_border_batch_size is None
            or self.omega_border_batch_size == pts_per_facet
        ):
            # Avoid unnecessary reshuffling
            return self, self.omega_border

        bstart = self.curr_omega_border_idx
        bend = bstart + self.omega_border_batch_size

        # NOTE(review): `self.nb` is passed as the epoch size although
        # `self.omega_border` only has nb // (2 * dim) rows; confirm against
        # `_reset_or_increment` whether the per-facet count is intended here.
        new_attributes = _reset_or_increment(
            bend,
            self.nb,
            self._get_omega_border_operands(),  # type: ignore
            # ignore since the case self.omega_border_batch_size is None has been
            # handled above
        )
        new = eqx.tree_at(
            lambda m: (m.key, m.omega_border, m.curr_omega_border_idx),
            self,
            new_attributes,
        )

        # start index can be dynamic but the slice shape is fixed (static)
        return new, jax.lax.dynamic_slice(
            new.omega_border,
            start_indices=(new.curr_omega_border_idx, 0, 0),
            slice_sizes=(new.omega_border_batch_size, new.dim, 2 * new.dim),
        )

    def get_batch(self) -> tuple[CubicMeshPDEStatio, PDEStatioBatch]:
        """
        Generic method to return a batch. Here we call `self.inside_batch()`
        and `self.border_batch()`
        """
        new, inside_batch = self.inside_batch()
        new, border_batch = new.border_batch()
        return new, PDEStatioBatch(domain_batch=inside_batch, border_batch=border_batch)
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Define the DataGenerators modules
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import (
|
|
6
|
+
annotations,
|
|
7
|
+
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
import equinox as eqx
|
|
10
|
+
import jax
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
from jaxtyping import Key, Array, Float
|
|
13
|
+
from jinns.data._Batchs import ODEBatch
|
|
14
|
+
from jinns.data._utils import _check_and_set_rar_parameters, _reset_or_increment
|
|
15
|
+
from jinns.data._AbstractDataGenerator import AbstractDataGenerator
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
pass
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class DataGeneratorODE(AbstractDataGenerator):
    """
    A class implementing data generator object for ordinary differential equations.

    Parameters
    ----------
    key : Key
        Jax random key to sample new time points and to shuffle batches
    nt : int
        The number of total time points that will be divided in
        batches. Batches are made so that each data point is seen only
        once during 1 epoch.
    tmin : float
        The minimum value of the time domain to consider
    tmax : float
        The maximum value of the time domain to consider
    temporal_batch_size : int | None, default=None
        The size of the batch of randomly selected points among
        the `nt` points. If None, no minibatches are used.
    method : str, default="uniform"
        Either `grid` or `uniform`, default is `uniform`.
        The method that generates the `nt` time points. `grid` means
        regularly spaced points over the domain. `uniform` means uniformly
        sampled points over the domain
    rar_parameters : RarParameterDict, default=None
        A TypedDict to specify the Residual Adaptative Resampling procedure. See
        the docstring from RarParameterDict
    n_start : int, default=None
        Defaults to None. The effective size of nt used at start time.
        This value must be
        provided when rar_parameters is not None. Otherwise we set internally
        n_start = nt and this is hidden from the user.
        In RAR, n_start
        then corresponds to the initial number of points we train the PINN.
    """

    key: Key = eqx.field(kw_only=True)
    nt: int = eqx.field(kw_only=True, static=True)
    tmin: Float = eqx.field(kw_only=True)
    tmax: Float = eqx.field(kw_only=True)
    # static because used as a shape in jax.lax.dynamic_slice
    temporal_batch_size: int | None = eqx.field(static=True, default=None, kw_only=True)
    method: str = eqx.field(
        static=True, kw_only=True, default_factory=lambda: "uniform"
    )
    rar_parameters: dict[str, int] | None = eqx.field(default=None, kw_only=True)
    n_start: int | None = eqx.field(static=True, default=None, kw_only=True)

    # all the init=False fields are set in __post_init__
    p: Float[Array, " nt 1"] | None = eqx.field(init=False)
    rar_iter_from_last_sampling: int | None = eqx.field(init=False)
    rar_iter_nb: int | None = eqx.field(init=False)
    curr_time_idx: int = eqx.field(init=False)
    times: Float[Array, " nt 1"] = eqx.field(init=False)

    def __post_init__(self):
        """
        Initialise the RAR bookkeeping and the batch cursor, then generate the
        full set of `self.nt` time points.
        """
        (
            self.n_start,
            self.p,
            self.rar_iter_from_last_sampling,
            self.rar_iter_nb,
        ) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.n_start)

        if self.temporal_batch_size is None:
            self.curr_time_idx = 0
        else:
            # Start past the end to be sure there is a shuffling at first
            # get_batch().
            # NOTE in the extreme case we could do:
            # self.curr_time_idx=jnp.iinfo(jnp.int32).max - self.temporal_batch_size - 1
            # but we do not test for such extreme values. Where we subtract
            # self.temporal_batch_size - 1 because otherwise when computing
            # `bend` we do not want to overflow the max int32 with unwanted behaviour
            self.curr_time_idx = self.nt + self.temporal_batch_size

        # Note that, here, in __init__ (and __post_init__), this is the
        # only place where self assignment are authorized so we do the
        # below way for the key.
        self.key, self.times = self.generate_time_data(self.key)

    def sample_in_time_domain(
        self, key: Key, sample_size: int | None = None
    ) -> Float[Array, " nt 1"]:
        """
        Draw time points uniformly in [`self.tmin`, `self.tmax`).

        Parameters
        ----------
        key
            Jax random key consumed by the draw.
        sample_size
            Number of points to draw; defaults to `self.nt`.

        Returns
        -------
        An array of shape `(sample_size, 1)`.
        """
        n_points = self.nt if sample_size is None else sample_size
        return jax.random.uniform(
            key,
            (n_points, 1),
            minval=self.tmin,
            maxval=self.tmax,
        )

    def generate_time_data(self, key: Key) -> tuple[Key, Float[Array, " nt 1"]]:
        """
        Construct a complete set of `self.nt` time points according to the
        specified `self.method`

        Note that self.times has always size self.nt and not self.n_start, even
        in RAR scheme, we must allocate all the collocation points

        Parameters
        ----------
        key
            Jax random key; it is split and the new key is returned.

        Returns
        -------
        The updated key and an array of shape `(self.nt, 1)`.

        Raises
        ------
        ValueError
            If `self.method` is neither `grid` nor `uniform`.
        """
        # Split the key given by the caller (not self.key) so that the
        # argument is actually honoured.
        key, subkey = jax.random.split(key)
        if self.method == "grid":
            # Regularly spaced points tmin, tmin + dt, ..., tmax - dt with
            # dt = (tmax - tmin) / nt. linspace(endpoint=False) always returns
            # exactly nt points, whereas arange with a float step may return
            # nt + 1 points because of rounding.
            grid = jnp.linspace(self.tmin, self.tmax, self.nt, endpoint=False)
            return key, grid[:, None]
        if self.method == "uniform":
            return key, self.sample_in_time_domain(subkey)
        raise ValueError("Method " + self.method + " is not implemented.")

    def _get_time_operands(
        self,
    ) -> tuple[
        Key, Float[Array, " nt 1"], int, int | None, Float[Array, " nt 1"] | None
    ]:
        """
        Pack the time-point state handed to `_reset_or_increment`:
        (key, points, current index, batch size, sampling probabilities).
        """
        return (
            self.key,
            self.times,
            self.curr_time_idx,
            self.temporal_batch_size,
            self.p,
        )

    def temporal_batch(
        self,
    ) -> tuple[DataGeneratorODE, Float[Array, " temporal_batch_size 1"]]:
        """
        Return a batch of time points. If all the batches have been seen, we
        reshuffle them, otherwise we just return the next unseen batch.

        Returns
        -------
        The updated data generator and the batch. When no minibatching is
        requested the generator is returned unchanged with all of
        `self.times`.
        """
        if self.temporal_batch_size is None or self.temporal_batch_size == self.nt:
            # Avoid unnecessary reshuffling
            return self, self.times

        bstart = self.curr_time_idx
        bend = bstart + self.temporal_batch_size

        # Compute the effective number of used collocation points
        if self.rar_parameters is not None:
            nt_eff = (
                self.n_start
                + self.rar_iter_nb  # type: ignore
                * self.rar_parameters["selected_sample_size"]
            )
        else:
            nt_eff = self.nt

        new_attributes = _reset_or_increment(
            bend,
            nt_eff,
            self._get_time_operands(),  # type: ignore
            # ignore since the case self.temporal_batch_size is None has been
            # handled above
        )
        new = eqx.tree_at(
            lambda m: (m.key, m.times, m.curr_time_idx), self, new_attributes
        )

        # commands below are equivalent to
        # return self.times[i:(i+t_batch_size)]
        # start indices can be dynamic but the slice shape is fixed
        return new, jax.lax.dynamic_slice(
            new.times,
            start_indices=(new.curr_time_idx, 0),
            slice_sizes=(new.temporal_batch_size, 1),
        )

    def get_batch(self) -> tuple[DataGeneratorODE, ODEBatch]:
        """
        Generic method to return a batch. Here we call `self.temporal_batch()`
        """
        new, temporal_batch = self.temporal_batch()
        return new, ODEBatch(temporal_batch=temporal_batch)
|