jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +904 -1203
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +322 -167
- jinns/loss/_LossODE.py +324 -322
- jinns/loss/_LossPDE.py +652 -1027
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +87 -41
- jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +521 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +183 -39
- jinns/solver/_solve.py +151 -124
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -44
- jinns/utils/_hyperpinn.py +224 -119
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +113 -86
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +48 -140
- jinns-1.1.0.dist-info/AUTHORS +2 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
- jinns-1.1.0.dist-info/RECORD +39 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.9.0.dist-info/RECORD +0 -36
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
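The headline change of this release shows up in the `jinns/data/_DataGenerators.py` diff below: the data generators are rewritten from hand-registered pytree classes (with `tree_flatten`/`tree_unflatten` and a `data_exists` flag) into immutable `equinox.Module` dataclasses, so every batch-drawing method now returns the updated generator alongside the batch instead of mutating internal state. A minimal usage sketch of the new convention, assuming `DataGeneratorODE` is re-exported from `jinns.data` (see the `jinns/data/__init__.py` entry above):

    import jax
    from jinns.data import DataGeneratorODE  # assumed re-export location

    data_gen = DataGeneratorODE(
        key=jax.random.PRNGKey(0),
        nt=1000,                 # total number of time collocation points
        tmin=0.0,
        tmax=1.0,
        temporal_batch_size=32,  # static field: used as a slice shape
    )

    # 0.9.0 mutated the generator in place: batch = data_gen.get_batch()
    # 1.1.0 is functional: keep the returned, updated generator
    data_gen, batch = data_gen.get_batch()
    print(batch.temporal_batch.shape)  # (32,)

Because the generator is now an ordinary JAX pytree whose array leaves are swapped functionally between calls, it can be threaded through `jax.lax.scan` or a training loop without the old `data_exists` machinery.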
jinns/data/_DataGenerators.py
CHANGED
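In the diff below, the new `append_param_batch`/`append_obs_batch` helpers fill a batch's `None` placeholder field via `equinox.tree_at` with `is_leaf=lambda x: x is None`, which lets `tree_at` address a leaf that is currently `None`. A self-contained sketch of that pattern (the `Holder` module is illustrative, not part of jinns):

    import equinox as eqx

    class Holder(eqx.Module):
        # None placeholder to be filled later, mirroring the
        # param_batch_dict field of the batch objects in this diff
        param_batch_dict: dict = None

    holder = Holder()
    # Without is_leaf, a None field is not addressable as a leaf;
    # with it, tree_at swaps the None for the given dictionary.
    filled = eqx.tree_at(
        lambda m: m.param_batch_dict,
        holder,
        {"nu": 0.5},
        is_leaf=lambda x: x is None,
    )
    print(filled.param_batch_dict)  # {'nu': 0.5}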
|
@@ -1,52 +1,49 @@
|
|
|
1
|
+
# pylint: disable=unsubscriptable-object
|
|
1
2
|
"""
|
|
2
|
-
|
|
3
|
+
Define the DataGeneratorODE equinox module
|
|
3
4
|
"""
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
5
|
+
from __future__ import (
|
|
6
|
+
annotations,
|
|
7
|
+
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
|
+
|
|
9
|
+
from typing import TYPE_CHECKING, Dict
|
|
10
|
+
from dataclasses import InitVar
|
|
11
|
+
import equinox as eqx
|
|
12
|
+
import jax
|
|
7
13
|
import jax.numpy as jnp
|
|
8
|
-
from
|
|
9
|
-
from
|
|
10
|
-
import jax.lax
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class ODEBatch(NamedTuple):
|
|
14
|
-
temporal_batch: ArrayLike
|
|
15
|
-
param_batch_dict: dict = None
|
|
16
|
-
obs_batch_dict: dict = None
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class PDENonStatioBatch(NamedTuple):
|
|
20
|
-
times_x_inside_batch: ArrayLike
|
|
21
|
-
times_x_border_batch: ArrayLike
|
|
22
|
-
param_batch_dict: dict = None
|
|
23
|
-
obs_batch_dict: dict = None
|
|
24
|
-
|
|
14
|
+
from jaxtyping import Key, Int, PyTree, Array, Float, Bool
|
|
15
|
+
from jinns.data._Batchs import *
|
|
25
16
|
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
border_batch: ArrayLike
|
|
29
|
-
param_batch_dict: dict = None
|
|
30
|
-
obs_batch_dict: dict = None
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from jinns.utils._types import *
|
|
31
19
|
|
|
32
20
|
|
|
33
|
-
def append_param_batch(batch, param_batch_dict):
|
|
21
|
+
def append_param_batch(batch: AnyBatch, param_batch_dict: dict) -> AnyBatch:
|
|
34
22
|
"""
|
|
35
23
|
Utility function that fill the param_batch_dict of a batch object with a
|
|
36
24
|
param_batch_dict
|
|
37
25
|
"""
|
|
38
|
-
return
|
|
26
|
+
return eqx.tree_at(
|
|
27
|
+
lambda m: m.param_batch_dict,
|
|
28
|
+
batch,
|
|
29
|
+
param_batch_dict,
|
|
30
|
+
is_leaf=lambda x: x is None,
|
|
31
|
+
)
|
|
39
32
|
|
|
40
33
|
|
|
41
|
-
def append_obs_batch(batch, obs_batch_dict):
|
|
34
|
+
def append_obs_batch(batch: AnyBatch, obs_batch_dict: dict) -> AnyBatch:
|
|
42
35
|
"""
|
|
43
36
|
Utility function that fill the obs_batch_dict of a batch object with a
|
|
44
37
|
obs_batch_dict
|
|
45
38
|
"""
|
|
46
|
-
return
|
|
39
|
+
return eqx.tree_at(
|
|
40
|
+
lambda m: m.obs_batch_dict, batch, obs_batch_dict, is_leaf=lambda x: x is None
|
|
41
|
+
)
|
|
47
42
|
|
|
48
43
|
|
|
49
|
-
def make_cartesian_product(
|
|
44
|
+
def make_cartesian_product(
|
|
45
|
+
b1: Float[Array, "batch_size dim1"], b2: Float[Array, "batch_size dim2"]
|
|
46
|
+
) -> Float[Array, "(batch_size*batch_size) (dim1+dim2)"]:
|
|
50
47
|
"""
|
|
51
48
|
Create the cartesian product of a time and a border omega batches
|
|
52
49
|
by tiling and repeating
|
|
@@ -58,29 +55,39 @@ def make_cartesian_product(b1, b2):
|
|
|
58
55
|
return jnp.concatenate([b1, b2], axis=1)
|
|
59
56
|
|
|
60
57
|
|
|
61
|
-
def _reset_batch_idx_and_permute(
|
|
58
|
+
def _reset_batch_idx_and_permute(
|
|
59
|
+
operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]]
|
|
60
|
+
) -> tuple[Key, Float[Array, "n dimension"], Int]:
|
|
62
61
|
key, domain, curr_idx, _, p = operands
|
|
63
62
|
# resetting counter
|
|
64
63
|
curr_idx = 0
|
|
65
64
|
# reshuffling
|
|
66
|
-
key, subkey = random.split(key)
|
|
65
|
+
key, subkey = jax.random.split(key)
|
|
67
66
|
# domain = random.permutation(subkey, domain, axis=0, independent=False)
|
|
68
67
|
# we want that permutation = choice when p=None
|
|
69
68
|
# otherwise p is used to avoid collocation points not in nt_start
|
|
70
|
-
domain = random.choice(
|
|
69
|
+
domain = jax.random.choice(
|
|
70
|
+
subkey, domain, shape=(domain.shape[0],), replace=False, p=p
|
|
71
|
+
)
|
|
71
72
|
|
|
72
73
|
# return updated
|
|
73
74
|
return (key, domain, curr_idx)
|
|
74
75
|
|
|
75
76
|
|
|
76
|
-
def _increment_batch_idx(
|
|
77
|
+
def _increment_batch_idx(
|
|
78
|
+
operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]]
|
|
79
|
+
) -> tuple[Key, Float[Array, "n dimension"], Int]:
|
|
77
80
|
key, domain, curr_idx, batch_size, _ = operands
|
|
78
81
|
# simply increases counter and get the batch
|
|
79
82
|
curr_idx += batch_size
|
|
80
83
|
return (key, domain, curr_idx)
|
|
81
84
|
|
|
82
85
|
|
|
83
|
-
def _reset_or_increment(
|
|
86
|
+
def _reset_or_increment(
|
|
87
|
+
bend: Int,
|
|
88
|
+
n_eff: Int,
|
|
89
|
+
operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]],
|
|
90
|
+
) -> tuple[Key, Float[Array, "n dimension"], Int]:
|
|
84
91
|
"""
|
|
85
92
|
Factorize the code of the jax.lax.cond which checks if we have seen all the
|
|
86
93
|
batches in an epoch
|
|
@@ -109,15 +116,18 @@ def _reset_or_increment(bend, n_eff, operands):
|
|
|
109
116
|
)
|
|
110
117
|
|
|
111
118
|
|
|
112
|
-
def _check_and_set_rar_parameters(
|
|
119
|
+
def _check_and_set_rar_parameters(
|
|
120
|
+
rar_parameters: dict, n: Int, n_start: Int
|
|
121
|
+
) -> tuple[Int, Float[Array, "n"], Int, Int]:
|
|
113
122
|
if rar_parameters is not None and n_start is None:
|
|
114
123
|
raise ValueError(
|
|
115
|
-
|
|
124
|
+
"nt_start must be provided in the context of RAR sampling scheme"
|
|
116
125
|
)
|
|
126
|
+
|
|
117
127
|
if rar_parameters is not None:
|
|
118
128
|
# Default p is None. However, in the RAR sampling scheme we use 0
|
|
119
129
|
# probability to specify non-used collocation points (i.e. points
|
|
120
|
-
# above
|
|
130
|
+
# above nt_start). Thus, p is a vector of probability of shape (nt, 1).
|
|
121
131
|
p = jnp.zeros((n,))
|
|
122
132
|
p = p.at[:n_start].set(1 / n_start)
|
|
123
133
|
# set internal counter for the number of gradient steps since the
|
|
@@ -129,6 +139,7 @@ def _check_and_set_rar_parameters(rar_parameters, n, n_start):
|
|
|
129
139
|
# have been added
|
|
130
140
|
rar_iter_nb = 0
|
|
131
141
|
else:
|
|
142
|
+
n_start = n
|
|
132
143
|
p = None
|
|
133
144
|
rar_iter_from_last_sampling = None
|
|
134
145
|
rar_iter_nb = None
|
|
@@ -136,109 +147,102 @@ def _check_and_set_rar_parameters(rar_parameters, n, n_start):
|
|
|
136
147
|
return n_start, p, rar_iter_from_last_sampling, rar_iter_nb
|
|
137
148
|
|
|
138
149
|
|
|
139
|
-
|
|
140
|
-
# DataGenerator for ODE : only returns time_batches
|
|
141
|
-
#####################################################
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
@register_pytree_node_class
|
|
145
|
-
class DataGeneratorODE:
|
|
150
|
+
class DataGeneratorODE(eqx.Module):
|
|
146
151
|
"""
|
|
147
152
|
A class implementing data generator object for ordinary differential equations.
|
|
148
153
|
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
154
|
+
Parameters
|
|
155
|
+
----------
|
|
156
|
+
key : Key
|
|
157
|
+
Jax random key to sample new time points and to shuffle batches
|
|
158
|
+
nt : Int
|
|
159
|
+
The number of total time points that will be divided in
|
|
160
|
+
batches. Batches are made so that each data point is seen only
|
|
161
|
+
once during 1 epoch.
|
|
162
|
+
tmin : float
|
|
163
|
+
The minimum value of the time domain to consider
|
|
164
|
+
tmax : float
|
|
165
|
+
The maximum value of the time domain to consider
|
|
166
|
+
temporal_batch_size : int
|
|
167
|
+
The size of the batch of randomly selected points among
|
|
168
|
+
the `nt` points.
|
|
169
|
+
method : str, default="uniform"
|
|
170
|
+
Either `grid` or `uniform`, default is `uniform`.
|
|
171
|
+
The method that generates the `nt` time points. `grid` means
|
|
172
|
+
regularly spaced points over the domain. `uniform` means uniformly
|
|
173
|
+
sampled points over the domain
|
|
174
|
+
rar_parameters : Dict[str, Int], default=None
|
|
175
|
+
Default to None: do not use Residual Adaptative Resampling.
|
|
176
|
+
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
177
|
+
which we start the RAR sampling scheme (we first have a burn in
|
|
178
|
+
period). `update_rate`: the number of gradient steps taken between
|
|
179
|
+
each appending of collocation points in the RAR algo.
|
|
180
|
+
`sample_size`: the size of the sample from which we will select new
|
|
181
|
+
collocation points. `selected_sample_size_times`: the number of selected
|
|
182
|
+
points from the sample to be added to the current collocation
|
|
183
|
+
points
|
|
184
|
+
nt_start : Int, default=None
|
|
185
|
+
Defaults to None. The effective size of nt used at start time.
|
|
186
|
+
This value must be
|
|
187
|
+
provided when rar_parameters is not None. Otherwise we set internally
|
|
188
|
+
nt_start = nt and this is hidden from the user.
|
|
189
|
+
In RAR, nt_start
|
|
190
|
+
then corresponds to the initial number of points we train the PINN.
|
|
152
191
|
"""
|
|
153
192
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
batches. Batches are made so that each data point is seen only
|
|
174
|
-
once during 1 epoch.
|
|
175
|
-
tmin
|
|
176
|
-
A float. The minimum value of the time domain to consider
|
|
177
|
-
tmax
|
|
178
|
-
A float. The maximum value of the time domain to consider
|
|
179
|
-
temporal_batch_size
|
|
180
|
-
An integer. The size of the batch of randomly selected points among
|
|
181
|
-
the `nt` points.
|
|
182
|
-
method
|
|
183
|
-
Either `grid` or `uniform`, default is `uniform`.
|
|
184
|
-
The method that generates the `nt` time points. `grid` means
|
|
185
|
-
regularly spaced points over the domain. `uniform` means uniformly
|
|
186
|
-
sampled points over the domain
|
|
187
|
-
rar_parameters
|
|
188
|
-
Default to None: do not use Residual Adaptative Resampling.
|
|
189
|
-
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
190
|
-
which we start the RAR sampling scheme (we first have a burn in
|
|
191
|
-
period). `update_every`: the number of gradient steps taken between
|
|
192
|
-
each appending of collocation points in the RAR algo.
|
|
193
|
-
`sample_size_times`: the size of the sample from which we will select new
|
|
194
|
-
collocation points. `selected_sample_size_times`: the number of selected
|
|
195
|
-
points from the sample to be added to the current collocation
|
|
196
|
-
points
|
|
197
|
-
"DeepXDE: A deep learning library for solving differential
|
|
198
|
-
equations", L. Lu, SIAM Review, 2021
|
|
199
|
-
nt_start
|
|
200
|
-
Defaults to None. The effective size of nt used at start time.
|
|
201
|
-
This value must be
|
|
202
|
-
provided when rar_parameters is not None. Otherwise we set internally
|
|
203
|
-
nt_start = nt and this is hidden from the user.
|
|
204
|
-
In RAR, nt_start
|
|
205
|
-
then corresponds to the initial number of points we train the PINN.
|
|
206
|
-
data_exists
|
|
207
|
-
Must be left to `False` when created by the user. Avoids the
|
|
208
|
-
regeneration of the `nt` time points at each pytree flattening and
|
|
209
|
-
unflattening.
|
|
210
|
-
"""
|
|
211
|
-
self.data_exists = data_exists
|
|
212
|
-
self._key = key
|
|
213
|
-
self.nt = nt
|
|
214
|
-
self.tmin = tmin
|
|
215
|
-
self.tmax = tmax
|
|
216
|
-
self.temporal_batch_size = temporal_batch_size
|
|
217
|
-
self.method = method
|
|
218
|
-
self.rar_parameters = rar_parameters
|
|
219
|
-
|
|
220
|
-
# Set-up for RAR (if used)
|
|
193
|
+
key: Key
|
|
194
|
+
nt: Int
|
|
195
|
+
tmin: Float
|
|
196
|
+
tmax: Float
|
|
197
|
+
temporal_batch_size: Int = eqx.field(static=True) # static cause used as a
|
|
198
|
+
# shape in jax.lax.dynamic_slice
|
|
199
|
+
method: str = eqx.field(static=True, default_factory=lambda: "uniform")
|
|
200
|
+
rar_parameters: Dict[str, Int] = None
|
|
201
|
+
nt_start: Int = eqx.field(static=True, default=None)
|
|
202
|
+
|
|
203
|
+
# all the init=False fields are set in __post_init__, even after a _replace
|
|
204
|
+
# or eqx.tree_at __post_init__ is called
|
|
205
|
+
p_times: Float[Array, "nt"] = eqx.field(init=False)
|
|
206
|
+
rar_iter_from_last_sampling: Int = eqx.field(init=False)
|
|
207
|
+
rar_iter_nb: Int = eqx.field(init=False)
|
|
208
|
+
curr_time_idx: Int = eqx.field(init=False)
|
|
209
|
+
times: Float[Array, "nt"] = eqx.field(init=False)
|
|
210
|
+
|
|
211
|
+
def __post_init__(self):
|
|
221
212
|
(
|
|
222
213
|
self.nt_start,
|
|
223
214
|
self.p_times,
|
|
224
215
|
self.rar_iter_from_last_sampling,
|
|
225
216
|
self.rar_iter_nb,
|
|
226
|
-
) = _check_and_set_rar_parameters(rar_parameters,
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
217
|
+
) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.nt_start)
|
|
218
|
+
|
|
219
|
+
self.curr_time_idx = jnp.iinfo(jnp.int32).max - self.temporal_batch_size - 1
|
|
220
|
+
# to be sure there is a
|
|
221
|
+
# shuffling at first get_batch() we do not call
|
|
222
|
+
# _reset_batch_idx_and_permute in __init__ or __post_init__ because it
|
|
223
|
+
# would return a copy of self and we have not investigate what would
|
|
224
|
+
# happen
|
|
225
|
+
# NOTE the (- self.temporal_batch_size - 1) because otherwise when computing
|
|
226
|
+
# `bend` we overflow the max int32 with unwanted behaviour
|
|
227
|
+
|
|
228
|
+
self.key, self.times = self.generate_time_data(self.key)
|
|
229
|
+
# Note that, here, in __init__ (and __post_init__), this is the
|
|
230
|
+
# only place where self assignment are authorized so we do the
|
|
231
|
+
# above way for the key. Note that one of the motivation to return the
|
|
232
|
+
# key from generate_*_data is to easily align key with legacy
|
|
233
|
+
# DataGenerators to use same unit tests
|
|
234
|
+
|
|
235
|
+
def sample_in_time_domain(
|
|
236
|
+
self, key: Key, sample_size: Int = None
|
|
237
|
+
) -> Float[Array, "nt"]:
|
|
238
|
+
return jax.random.uniform(
|
|
239
|
+
key,
|
|
240
|
+
(self.nt if sample_size is None else sample_size,),
|
|
241
|
+
minval=self.tmin,
|
|
242
|
+
maxval=self.tmax,
|
|
243
|
+
)
|
|
240
244
|
|
|
241
|
-
def generate_time_data(self):
|
|
245
|
+
def generate_time_data(self, key: Key) -> tuple[Key, Float[Array, "nt"]]:
|
|
242
246
|
"""
|
|
243
247
|
Construct a complete set of `self.nt` time points according to the
|
|
244
248
|
specified `self.method`
|
|
@@ -246,24 +250,28 @@ class DataGeneratorODE:
|
|
|
246
250
|
Note that self.times has always size self.nt and not self.nt_start, even
|
|
247
251
|
in RAR scheme, we must allocate all the collocation points
|
|
248
252
|
"""
|
|
253
|
+
key, subkey = jax.random.split(self.key)
|
|
249
254
|
if self.method == "grid":
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
raise ValueError("Method " + self.method + " is not implemented.")
|
|
255
|
+
partial_times = (self.tmax - self.tmin) / self.nt
|
|
256
|
+
return key, jnp.arange(self.tmin, self.tmax, partial_times)
|
|
257
|
+
if self.method == "uniform":
|
|
258
|
+
return key, self.sample_in_time_domain(subkey)
|
|
259
|
+
raise ValueError("Method " + self.method + " is not implemented.")
|
|
256
260
|
|
|
257
|
-
def _get_time_operands(
|
|
261
|
+
def _get_time_operands(
|
|
262
|
+
self,
|
|
263
|
+
) -> tuple[Key, Float[Array, "nt"], Int, Int, Float[Array, "nt"]]:
|
|
258
264
|
return (
|
|
259
|
-
self.
|
|
265
|
+
self.key,
|
|
260
266
|
self.times,
|
|
261
267
|
self.curr_time_idx,
|
|
262
268
|
self.temporal_batch_size,
|
|
263
269
|
self.p_times,
|
|
264
270
|
)
|
|
265
271
|
|
|
266
|
-
def temporal_batch(
|
|
272
|
+
def temporal_batch(
|
|
273
|
+
self,
|
|
274
|
+
) -> tuple["DataGeneratorODE", Float[Array, "temporal_batch_size"]]:
|
|
267
275
|
"""
|
|
268
276
|
Return a batch of time points. If all the batches have been seen, we
|
|
269
277
|
reshuffle them, otherwise we just return the next unseen batch.
|
|
@@ -275,210 +283,142 @@ class DataGeneratorODE:
|
|
|
275
283
|
if self.rar_parameters is not None:
|
|
276
284
|
nt_eff = (
|
|
277
285
|
self.nt_start
|
|
278
|
-
+ self.rar_iter_nb * self.rar_parameters["
|
|
286
|
+
+ self.rar_iter_nb * self.rar_parameters["selected_sample_size_times"]
|
|
279
287
|
)
|
|
280
288
|
else:
|
|
281
289
|
nt_eff = self.nt
|
|
282
|
-
|
|
283
|
-
|
|
290
|
+
|
|
291
|
+
new_attributes = _reset_or_increment(bend, nt_eff, self._get_time_operands())
|
|
292
|
+
new = eqx.tree_at(
|
|
293
|
+
lambda m: (m.key, m.times, m.curr_time_idx), self, new_attributes
|
|
284
294
|
)
|
|
285
295
|
|
|
286
296
|
# commands below are equivalent to
|
|
287
297
|
# return self.times[i:(i+t_batch_size)]
|
|
288
298
|
# start indices can be dynamic be the slice shape is fixed
|
|
289
|
-
return jax.lax.dynamic_slice(
|
|
290
|
-
|
|
291
|
-
start_indices=(
|
|
292
|
-
slice_sizes=(
|
|
299
|
+
return new, jax.lax.dynamic_slice(
|
|
300
|
+
new.times,
|
|
301
|
+
start_indices=(new.curr_time_idx,),
|
|
302
|
+
slice_sizes=(new.temporal_batch_size,),
|
|
293
303
|
)
|
|
294
304
|
|
|
295
|
-
def get_batch(self):
|
|
305
|
+
def get_batch(self) -> tuple["DataGeneratorODE", ODEBatch]:
|
|
296
306
|
"""
|
|
297
307
|
Generic method to return a batch. Here we call `self.temporal_batch()`
|
|
298
308
|
"""
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
def tree_flatten(self):
|
|
302
|
-
children = (
|
|
303
|
-
self._key,
|
|
304
|
-
self.times,
|
|
305
|
-
self.curr_time_idx,
|
|
306
|
-
self.tmin,
|
|
307
|
-
self.tmax,
|
|
308
|
-
self.p_times,
|
|
309
|
-
self.rar_iter_from_last_sampling,
|
|
310
|
-
self.rar_iter_nb,
|
|
311
|
-
) # arrays / dynamic values
|
|
312
|
-
aux_data = {
|
|
313
|
-
k: vars(self)[k]
|
|
314
|
-
for k in [
|
|
315
|
-
"temporal_batch_size",
|
|
316
|
-
"method",
|
|
317
|
-
"nt",
|
|
318
|
-
"rar_parameters",
|
|
319
|
-
"nt_start",
|
|
320
|
-
]
|
|
321
|
-
} # static values
|
|
322
|
-
return (children, aux_data)
|
|
323
|
-
|
|
324
|
-
@classmethod
|
|
325
|
-
def tree_unflatten(cls, aux_data, children):
|
|
326
|
-
"""
|
|
327
|
-
**Note:** When reconstructing the class, we force ``data_exists=True``
|
|
328
|
-
in order not to re-generate the data at each flattening and
|
|
329
|
-
unflattening that happens e.g. during the gradient descent in the
|
|
330
|
-
optimization process
|
|
331
|
-
"""
|
|
332
|
-
(
|
|
333
|
-
key,
|
|
334
|
-
times,
|
|
335
|
-
curr_time_idx,
|
|
336
|
-
tmin,
|
|
337
|
-
tmax,
|
|
338
|
-
p_times,
|
|
339
|
-
rar_iter_from_last_sampling,
|
|
340
|
-
rar_iter_nb,
|
|
341
|
-
) = children
|
|
342
|
-
obj = cls(
|
|
343
|
-
key=key,
|
|
344
|
-
data_exists=True,
|
|
345
|
-
tmin=tmin,
|
|
346
|
-
tmax=tmax,
|
|
347
|
-
**aux_data,
|
|
348
|
-
)
|
|
349
|
-
obj.times = times
|
|
350
|
-
obj.curr_time_idx = curr_time_idx
|
|
351
|
-
obj.p_times = p_times
|
|
352
|
-
obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
|
|
353
|
-
obj.rar_iter_nb = rar_iter_nb
|
|
354
|
-
return obj
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
##########################################
|
|
358
|
-
# Data Generator for PDE in stationnary
|
|
359
|
-
# and non-stationnary cases
|
|
360
|
-
##########################################
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
class DataGeneratorPDEAbstract:
|
|
364
|
-
"""generic skeleton class for a PDE data generator"""
|
|
309
|
+
new, temporal_batch = self.temporal_batch()
|
|
310
|
+
return new, ODEBatch(temporal_batch=temporal_batch)
|
|
365
311
|
|
|
366
|
-
def __init__(self, data_exists=False) -> None:
|
|
367
|
-
# /!\ WARNING /!\: an-end user should never create an object
|
|
368
|
-
# with data_exists=True. Or else generate_data() won't be called.
|
|
369
|
-
# Useful when using a lax.scan with a DataGenerator in the carry
|
|
370
|
-
# It tells JAX not to re-generate data in the __init__()
|
|
371
|
-
self.data_exists = data_exists
|
|
372
312
|
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
376
|
-
"""
|
|
313
|
+
class CubicMeshPDEStatio(eqx.Module):
|
|
314
|
+
r"""
|
|
377
315
|
A class implementing data generator object for stationary partial
|
|
378
316
|
differential equations.
|
|
379
317
|
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
318
|
+
Parameters
|
|
319
|
+
----------
|
|
320
|
+
key : Key
|
|
321
|
+
Jax random key to sample new time points and to shuffle batches
|
|
322
|
+
n : Int
|
|
323
|
+
The number of total $\Omega$ points that will be divided in
|
|
324
|
+
batches. Batches are made so that each data point is seen only
|
|
325
|
+
once during 1 epoch.
|
|
326
|
+
nb : Int | None
|
|
327
|
+
The total number of points in $\partial\Omega$.
|
|
328
|
+
Can be `None` not to lose performance generating the border
|
|
329
|
+
batch if they are not used
|
|
330
|
+
omega_batch_size : Int
|
|
331
|
+
The size of the batch of randomly selected points among
|
|
332
|
+
the `n` points.
|
|
333
|
+
omega_border_batch_size : Int | None
|
|
334
|
+
The size of the batch of points randomly selected
|
|
335
|
+
among the `nb` points.
|
|
336
|
+
Can be `None` not to lose performance generating the border
|
|
337
|
+
batch if they are not used
|
|
338
|
+
dim : Int
|
|
339
|
+
Dimension of $\Omega$ domain
|
|
340
|
+
min_pts : tuple[tuple[Float, Float], ...]
|
|
341
|
+
A tuple of minimum values of the domain along each dimension. For a sampling
|
|
342
|
+
in `n` dimension, this represents $(x_{1, min}, x_{2,min}, ...,
|
|
343
|
+
x_{n, min})$
|
|
344
|
+
max_pts : tuple[tuple[Float, Float], ...]
|
|
345
|
+
A tuple of maximum values of the domain along each dimension. For a sampling
|
|
346
|
+
in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
|
|
347
|
+
x_{n,max})$
|
|
348
|
+
method : str, default="uniform"
|
|
349
|
+
Either `grid` or `uniform`, default is `uniform`.
|
|
350
|
+
The method that generates the `nt` time points. `grid` means
|
|
351
|
+
regularly spaced points over the domain. `uniform` means uniformly
|
|
352
|
+
sampled points over the domain
|
|
353
|
+
rar_parameters : Dict[str, Int], default=None
|
|
354
|
+
Default to None: do not use Residual Adaptative Resampling.
|
|
355
|
+
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
356
|
+
which we start the RAR sampling scheme (we first have a burn in
|
|
357
|
+
period). `update_every`: the number of gradient steps taken between
|
|
358
|
+
each appending of collocation points in the RAR algo.
|
|
359
|
+
`sample_size_omega`: the size of the sample from which we will select new
|
|
360
|
+
collocation points. `selected_sample_size_omega`: the number of selected
|
|
361
|
+
points from the sample to be added to the current collocation
|
|
362
|
+
points
|
|
363
|
+
n_start : Int, default=None
|
|
364
|
+
Defaults to None. The effective size of n used at start time.
|
|
365
|
+
This value must be
|
|
366
|
+
provided when rar_parameters is not None. Otherwise we set internally
|
|
367
|
+
n_start = n and this is hidden from the user.
|
|
368
|
+
In RAR, n_start
|
|
369
|
+
then corresponds to the initial number of points we train the PINN.
|
|
383
370
|
"""
|
|
384
371
|
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
)
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
dim
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
A tuple of minimum values of the domain along each dimension. For a sampling
|
|
425
|
-
in `n` dimension, this represents :math:`(x_{1, min}, x_{2,min}, ...,
|
|
426
|
-
x_{n, min})`
|
|
427
|
-
max_pts
|
|
428
|
-
A tuple of maximum values of the domain along each dimension. For a sampling
|
|
429
|
-
in `n` dimension, this represents :math:`(x_{1, max}, x_{2,max}, ...,
|
|
430
|
-
x_{n,max})`
|
|
431
|
-
method
|
|
432
|
-
Either `grid` or `uniform`, default is `grid`.
|
|
433
|
-
The method that generates the `nt` time points. `grid` means
|
|
434
|
-
regularly spaced points over the domain. `uniform` means uniformly
|
|
435
|
-
sampled points over the domain
|
|
436
|
-
rar_parameters
|
|
437
|
-
Default to None: do not use Residual Adaptative Resampling.
|
|
438
|
-
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
439
|
-
which we start the RAR sampling scheme (we first have a burn in
|
|
440
|
-
period). `update_every`: the number of gradient steps taken between
|
|
441
|
-
each appending of collocation points in the RAR algo.
|
|
442
|
-
`sample_size_omega`: the size of the sample from which we will select new
|
|
443
|
-
collocation points. `selected_sample_size_omega`: the number of selected
|
|
444
|
-
points from the sample to be added to the current collocation
|
|
445
|
-
points
|
|
446
|
-
"DeepXDE: A deep learning library for solving differential
|
|
447
|
-
equations", L. Lu, SIAM Review, 2021
|
|
448
|
-
n_start
|
|
449
|
-
Defaults to None. The effective size of n used at start time.
|
|
450
|
-
This value must be
|
|
451
|
-
provided when rar_parameters is not None. Otherwise we set internally
|
|
452
|
-
n_start = n and this is hidden from the user.
|
|
453
|
-
In RAR, n_start
|
|
454
|
-
then corresponds to the initial number of points we train the PINN.
|
|
455
|
-
data_exists
|
|
456
|
-
Must be left to `False` when created by the user. Avoids the
|
|
457
|
-
regeneration of :math:`\Omega`, :math:`\partial\Omega` and
|
|
458
|
-
time points at each pytree flattening and unflattening.
|
|
459
|
-
"""
|
|
460
|
-
super().__init__(data_exists=data_exists)
|
|
461
|
-
self.method = method
|
|
462
|
-
self._key = key
|
|
463
|
-
self.dim = dim
|
|
464
|
-
self.min_pts = min_pts
|
|
465
|
-
self.max_pts = max_pts
|
|
466
|
-
assert dim == len(min_pts) and isinstance(min_pts, tuple)
|
|
467
|
-
assert dim == len(max_pts) and isinstance(max_pts, tuple)
|
|
468
|
-
self.n = n
|
|
469
|
-
self.rar_parameters = rar_parameters
|
|
372
|
+
# kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
|
|
373
|
+
key: Key = eqx.field(kw_only=True)
|
|
374
|
+
n: Int = eqx.field(kw_only=True)
|
|
375
|
+
nb: Int | None = eqx.field(kw_only=True)
|
|
376
|
+
omega_batch_size: Int = eqx.field(
|
|
377
|
+
kw_only=True, static=True
|
|
378
|
+
) # static cause used as a
|
|
379
|
+
# shape in jax.lax.dynamic_slice
|
|
380
|
+
omega_border_batch_size: Int | None = eqx.field(
|
|
381
|
+
kw_only=True, static=True
|
|
382
|
+
) # static cause used as a
|
|
383
|
+
# shape in jax.lax.dynamic_slice
|
|
384
|
+
dim: Int = eqx.field(kw_only=True, static=True) # static cause used as a
|
|
385
|
+
# shape in jax.lax.dynamic_slice
|
|
386
|
+
min_pts: tuple[tuple[Float, Float], ...] = eqx.field(kw_only=True)
|
|
387
|
+
max_pts: tuple[tuple[Float, Float], ...] = eqx.field(kw_only=True)
|
|
388
|
+
method: str = eqx.field(
|
|
389
|
+
kw_only=True, static=True, default_factory=lambda: "uniform"
|
|
390
|
+
)
|
|
391
|
+
rar_parameters: Dict[str, Int] = eqx.field(kw_only=True, default=None)
|
|
392
|
+
n_start: Int = eqx.field(kw_only=True, default=None, static=True)
|
|
393
|
+
|
|
394
|
+
# all the init=False fields are set in __post_init__, even after a _replace
|
|
395
|
+
# or eqx.tree_at __post_init__ is called
|
|
396
|
+
p_omega: Float[Array, "n"] = eqx.field(init=False)
|
|
397
|
+
p_border: None = eqx.field(init=False)
|
|
398
|
+
rar_iter_from_last_sampling: Int = eqx.field(init=False)
|
|
399
|
+
rar_iter_nb: Int = eqx.field(init=False)
|
|
400
|
+
curr_omega_idx: Int = eqx.field(init=False)
|
|
401
|
+
curr_omega_border_idx: Int = eqx.field(init=False)
|
|
402
|
+
omega: Float[Array, "n dim"] = eqx.field(init=False)
|
|
403
|
+
omega_border: Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None = eqx.field(
|
|
404
|
+
init=False
|
|
405
|
+
)
|
|
406
|
+
|
|
407
|
+
def __post_init__(self):
|
|
408
|
+
assert self.dim == len(self.min_pts) and isinstance(self.min_pts, tuple)
|
|
409
|
+
assert self.dim == len(self.max_pts) and isinstance(self.max_pts, tuple)
|
|
410
|
+
|
|
470
411
|
(
|
|
471
412
|
self.n_start,
|
|
472
413
|
self.p_omega,
|
|
473
414
|
self.rar_iter_from_last_sampling,
|
|
474
415
|
self.rar_iter_nb,
|
|
475
|
-
) = _check_and_set_rar_parameters(rar_parameters, n
|
|
416
|
+
) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)
|
|
476
417
|
|
|
477
418
|
self.p_border = None # no RAR sampling for border for now
|
|
478
419
|
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
if omega_border_batch_size is None:
|
|
420
|
+
# Special handling for the border batch
|
|
421
|
+
if self.omega_border_batch_size is None:
|
|
482
422
|
self.nb = None
|
|
483
423
|
self.omega_border_batch_size = None
|
|
484
424
|
elif self.dim == 1:
|
|
@@ -491,48 +431,46 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
491
431
|
# ignored since borders of Omega are singletons.
|
|
492
432
|
# self.border_batch() will return [xmin, xmax]
|
|
493
433
|
else:
|
|
494
|
-
if nb % (2 * self.dim) != 0 or nb < 2 * self.dim:
|
|
434
|
+
if self.nb % (2 * self.dim) != 0 or self.nb < 2 * self.dim:
|
|
495
435
|
raise ValueError(
|
|
496
436
|
"number of border point must be"
|
|
497
437
|
" a multiple of 2xd (the # of faces of a d-dimensional cube)"
|
|
498
438
|
)
|
|
499
|
-
if nb // (2 * self.dim) < omega_border_batch_size:
|
|
439
|
+
if self.nb // (2 * self.dim) < self.omega_border_batch_size:
|
|
500
440
|
raise ValueError(
|
|
501
441
|
"number of points per facets (nb//2*self.dim)"
|
|
502
442
|
" cannot be lower than border batch size"
|
|
503
443
|
)
|
|
504
|
-
self.nb = int((2 * self.dim) * (nb // (2 * self.dim)))
|
|
505
|
-
self.omega_border_batch_size = omega_border_batch_size
|
|
506
|
-
|
|
507
|
-
if not self.data_exists:
|
|
508
|
-
# Useful when using a lax.scan with pytree
|
|
509
|
-
# Optionally tells JAX not to re-generate data when re-building the
|
|
510
|
-
# object
|
|
511
|
-
self.curr_omega_idx = 0
|
|
512
|
-
self.curr_omega_border_idx = 0
|
|
513
|
-
self.generate_data()
|
|
514
|
-
self._key, self.omega, _ = _reset_batch_idx_and_permute(
|
|
515
|
-
self._get_omega_operands()
|
|
516
|
-
)
|
|
517
|
-
if self.omega_border is not None and self.dim > 1:
|
|
518
|
-
self._key, self.omega_border, _ = _reset_batch_idx_and_permute(
|
|
519
|
-
self._get_omega_border_operands()
|
|
520
|
-
)
|
|
444
|
+
self.nb = int((2 * self.dim) * (self.nb // (2 * self.dim)))
|
|
521
445
|
|
|
522
|
-
|
|
446
|
+
self.curr_omega_idx = jnp.iinfo(jnp.int32).max - self.omega_batch_size - 1
|
|
447
|
+
# see explaination in DataGeneratorODE
|
|
448
|
+
if self.omega_border_batch_size is None:
|
|
449
|
+
self.curr_omega_border_idx = None
|
|
450
|
+
else:
|
|
451
|
+
self.curr_omega_border_idx = (
|
|
452
|
+
jnp.iinfo(jnp.int32).max - self.omega_border_batch_size - 1
|
|
453
|
+
)
|
|
454
|
+
# key, subkey = jax.random.split(self.key)
|
|
455
|
+
# self.key = key
|
|
456
|
+
self.key, self.omega, self.omega_border = self.generate_data(self.key)
|
|
457
|
+
# see explaination in DataGeneratorODE for the key
|
|
458
|
+
|
|
459
|
+
def sample_in_omega_domain(
|
|
460
|
+
self, keys: Key, sample_size: Int = None
|
|
461
|
+
) -> Float[Array, "n dim"]:
|
|
462
|
+
sample_size = self.n if sample_size is None else sample_size
|
|
523
463
|
if self.dim == 1:
|
|
524
464
|
xmin, xmax = self.min_pts[0], self.max_pts[0]
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
subkey, shape=(n_samples, 1), minval=xmin, maxval=xmax
|
|
465
|
+
return jax.random.uniform(
|
|
466
|
+
keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
|
|
528
467
|
)
|
|
529
|
-
keys = random.split(
|
|
530
|
-
self._key = keys[0]
|
|
468
|
+
# keys = jax.random.split(key, self.dim)
|
|
531
469
|
return jnp.concatenate(
|
|
532
470
|
[
|
|
533
|
-
random.uniform(
|
|
534
|
-
keys[i
|
|
535
|
-
(
|
|
471
|
+
jax.random.uniform(
|
|
472
|
+
keys[i],
|
|
473
|
+
(sample_size, 1),
|
|
536
474
|
minval=self.min_pts[i],
|
|
537
475
|
maxval=self.max_pts[i],
|
|
538
476
|
)
|
|
@@ -541,7 +479,9 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
541
479
|
axis=-1,
|
|
542
480
|
)
|
|
543
481
|
|
|
544
|
-
def sample_in_omega_border_domain(
|
|
482
|
+
def sample_in_omega_border_domain(
|
|
483
|
+
self, keys: Key
|
|
484
|
+
) -> Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None:
|
|
545
485
|
if self.omega_border_batch_size is None:
|
|
546
486
|
return None
|
|
547
487
|
if self.dim == 1:
|
|
@@ -553,15 +493,12 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
553
493
|
# TODO : find a general & efficient way to sample from the border
|
|
554
494
|
# (facets) of the hypercube in general dim.
|
|
555
495
|
|
|
556
|
-
facet_n =
|
|
557
|
-
keys = random.split(self._key, 5)
|
|
558
|
-
self._key = keys[0]
|
|
559
|
-
subkeys = keys[1:]
|
|
496
|
+
facet_n = self.nb // (2 * self.dim)
|
|
560
497
|
xmin = jnp.hstack(
|
|
561
498
|
[
|
|
562
499
|
self.min_pts[0] * jnp.ones((facet_n, 1)),
|
|
563
|
-
random.uniform(
|
|
564
|
-
|
|
500
|
+
jax.random.uniform(
|
|
501
|
+
keys[0],
|
|
565
502
|
(facet_n, 1),
|
|
566
503
|
minval=self.min_pts[1],
|
|
567
504
|
maxval=self.max_pts[1],
|
|
@@ -571,8 +508,8 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
571
508
|
xmax = jnp.hstack(
|
|
572
509
|
[
|
|
573
510
|
self.max_pts[0] * jnp.ones((facet_n, 1)),
|
|
574
|
-
random.uniform(
|
|
575
|
-
|
|
511
|
+
jax.random.uniform(
|
|
512
|
+
keys[1],
|
|
576
513
|
(facet_n, 1),
|
|
577
514
|
minval=self.min_pts[1],
|
|
578
515
|
maxval=self.max_pts[1],
|
|
@@ -581,8 +518,8 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
581
518
|
)
|
|
582
519
|
ymin = jnp.hstack(
|
|
583
520
|
[
|
|
584
|
-
random.uniform(
|
|
585
|
-
|
|
521
|
+
jax.random.uniform(
|
|
522
|
+
keys[2],
|
|
586
523
|
(facet_n, 1),
|
|
587
524
|
minval=self.min_pts[0],
|
|
588
525
|
maxval=self.max_pts[0],
|
|
@@ -592,8 +529,8 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
592
529
|
)
|
|
593
530
|
ymax = jnp.hstack(
|
|
594
531
|
[
|
|
595
|
-
random.uniform(
|
|
596
|
-
|
|
532
|
+
jax.random.uniform(
|
|
533
|
+
keys[3],
|
|
597
534
|
(facet_n, 1),
|
|
598
535
|
minval=self.min_pts[0],
|
|
599
536
|
maxval=self.max_pts[0],
|
|
@@ -607,54 +544,71 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
607
544
|
+ f"implemented yet. You are asking for generation in dimension d={self.dim}."
|
|
608
545
|
)
|
|
609
546
|
|
|
610
|
-
def generate_data(self)
|
|
547
|
+
def generate_data(self, key: Key) -> tuple[
|
|
548
|
+
Key,
|
|
549
|
+
Float[Array, "n dim"],
|
|
550
|
+
Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None,
|
|
551
|
+
]:
|
|
611
552
|
r"""
|
|
612
|
-
Construct a complete set of `self.n`
|
|
553
|
+
Construct a complete set of `self.n` $\Omega$ points according to the
|
|
613
554
|
specified `self.method`. Also constructs a complete set of `self.nb`
|
|
614
|
-
|
|
555
|
+
$\partial\Omega$ points if `self.omega_border_batch_size` is not
|
|
615
556
|
`None`. If the latter is `None` we set `self.omega_border` to `None`.
|
|
616
557
|
"""
|
|
617
|
-
|
|
618
558
|
# Generate Omega
|
|
619
559
|
if self.method == "grid":
|
|
620
560
|
if self.dim == 1:
|
|
621
561
|
xmin, xmax = self.min_pts[0], self.max_pts[0]
|
|
622
|
-
|
|
562
|
+
partial = (xmax - xmin) / self.n
|
|
623
563
|
# shape (n, 1)
|
|
624
|
-
|
|
564
|
+
omega = jnp.arange(xmin, xmax, partial)[:, None]
|
|
625
565
|
else:
|
|
626
|
-
|
|
566
|
+
partials = [
|
|
627
567
|
(self.max_pts[i] - self.min_pts[i]) / jnp.sqrt(self.n)
|
|
628
568
|
for i in range(self.dim)
|
|
629
569
|
]
|
|
630
570
|
xyz_ = jnp.meshgrid(
|
|
631
571
|
*[
|
|
632
|
-
jnp.arange(self.min_pts[i], self.max_pts[i],
|
|
572
|
+
jnp.arange(self.min_pts[i], self.max_pts[i], partials[i])
|
|
633
573
|
for i in range(self.dim)
|
|
634
574
|
]
|
|
635
575
|
)
|
|
636
576
|
xyz_ = [a.reshape((self.n, 1)) for a in xyz_]
|
|
637
|
-
|
|
577
|
+
omega = jnp.concatenate(xyz_, axis=-1)
|
|
638
578
|
elif self.method == "uniform":
|
|
639
|
-
self.
|
|
579
|
+
if self.dim == 1:
|
|
580
|
+
key, subkeys = jax.random.split(key, 2)
|
|
581
|
+
else:
|
|
582
|
+
key, *subkeys = jax.random.split(key, self.dim + 1)
|
|
583
|
+
omega = self.sample_in_omega_domain(subkeys)
|
|
640
584
|
else:
|
|
641
585
|
raise ValueError("Method " + self.method + " is not implemented.")
|
|
642
586
|
|
|
643
587
|
# Generate border of omega
|
|
644
|
-
self.
|
|
588
|
+
if self.dim == 2 and self.omega_border_batch_size is not None:
|
|
589
|
+
key, *subkeys = jax.random.split(key, 5)
|
|
590
|
+
else:
|
|
591
|
+
subkeys = None
|
|
592
|
+
omega_border = self.sample_in_omega_border_domain(subkeys)
|
|
645
593
|
|
|
646
|
-
|
|
594
|
+
return key, omega, omega_border
|
|
595
|
+
|
|
596
|
+
def _get_omega_operands(
|
|
597
|
+
self,
|
|
598
|
+
) -> tuple[Key, Float[Array, "n dim"], Int, Int, Float[Array, "n"]]:
|
|
647
599
|
return (
|
|
648
|
-
self.
|
|
600
|
+
self.key,
|
|
649
601
|
self.omega,
|
|
650
602
|
self.curr_omega_idx,
|
|
651
603
|
self.omega_batch_size,
|
|
652
604
|
self.p_omega,
|
|
653
605
|
)
|
|
654
606
|
|
|
655
|
-
def inside_batch(
|
|
607
|
+
def inside_batch(
|
|
608
|
+
self,
|
|
609
|
+
) -> tuple["CubicMeshPDEStatio", Float[Array, "omega_batch_size dim"]]:
|
|
656
610
|
r"""
|
|
657
|
-
Return a batch of points in
|
|
611
|
+
Return a batch of points in $\Omega$.
|
|
658
612
|
If all the batches have been seen, we reshuffle them,
|
|
659
613
|
otherwise we just return the next unseen batch.
|
|
660
614
|
"""
|
|
@@ -670,38 +624,46 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
670
624
|
bstart = self.curr_omega_idx
|
|
671
625
|
bend = bstart + self.omega_batch_size
|
|
672
626
|
|
|
673
|
-
(
|
|
674
|
-
|
|
627
|
+
new_attributes = _reset_or_increment(bend, n_eff, self._get_omega_operands())
|
|
628
|
+
new = eqx.tree_at(
|
|
629
|
+
lambda m: (m.key, m.omega, m.curr_omega_idx), self, new_attributes
|
|
675
630
|
)
|
|
676
631
|
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
start_indices=(self.curr_omega_idx, 0),
|
|
682
|
-
slice_sizes=(self.omega_batch_size, self.dim),
|
|
632
|
+
return new, jax.lax.dynamic_slice(
|
|
633
|
+
new.omega,
|
|
634
|
+
start_indices=(new.curr_omega_idx, 0),
|
|
635
|
+
slice_sizes=(new.omega_batch_size, new.dim),
|
|
683
636
|
)
|
|
684
637
|
|
|
685
|
-
def _get_omega_border_operands(
|
|
638
|
+
def _get_omega_border_operands(
|
|
639
|
+
self,
|
|
640
|
+
) -> tuple[
|
|
641
|
+
Key, Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None, Int, Int, None
|
|
642
|
+
]:
|
|
686
643
|
return (
|
|
687
|
-
self.
|
|
644
|
+
self.key,
|
|
688
645
|
self.omega_border,
|
|
689
646
|
self.curr_omega_border_idx,
|
|
690
647
|
self.omega_border_batch_size,
|
|
691
648
|
self.p_border,
|
|
692
649
|
)
|
|
693
650
|
|
|
694
|
-
def border_batch(
|
|
651
|
+
def border_batch(
|
|
652
|
+
self,
|
|
653
|
+
) -> tuple[
|
|
654
|
+
"CubicMeshPDEStatio",
|
|
655
|
+
Float[Array, "1 1 2"] | Float[Array, "omega_border_batch_size 2 4"] | None,
|
|
656
|
+
]:
|
|
695
657
|
r"""
|
|
696
658
|
Return
|
|
697
659
|
|
|
698
660
|
- The value `None` if `self.omega_border_batch_size` is `None`.
|
|
699
661
|
|
|
700
|
-
- a jnp array with two fixed values
|
|
662
|
+
- a jnp array with two fixed values $(x_{min}, x_{max})$ if
|
|
701
663
|
`self.dim` = 1. There is no sampling here, we return the entire
|
|
702
|
-
|
|
664
|
+
$\partial\Omega$
|
|
703
665
|
|
|
704
|
-
- a batch of points in
|
|
666
|
+
- a batch of points in $\partial\Omega$ otherwise, stacked by
|
|
705
667
|
facet on the last axis.
|
|
706
668
|
If all the batches have been seen, we reshuffle them,
|
|
707
669
|
otherwise we just return the next unseen batch.
|
|
@@ -709,236 +671,138 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
709
671
|
|
|
710
672
|
"""
|
|
711
673
|
if self.omega_border_batch_size is None:
|
|
712
|
-
return None
|
|
674
|
+
return self, None
|
|
713
675
|
if self.dim == 1:
|
|
714
676
|
# 1-D case, no randomness : we always return the whole omega border,
|
|
715
677
|
# i.e. (1, 1, 2) shape jnp.array([[[xmin], [xmax]]]).
|
|
716
|
-
return self.omega_border[None, None] # shape is (1, 1, 2)
|
|
678
|
+
return self, self.omega_border[None, None] # shape is (1, 1, 2)
|
|
717
679
|
bstart = self.curr_omega_border_idx
|
|
718
680
|
bend = bstart + self.omega_border_batch_size
|
|
719
681
|
|
|
720
|
-
(
|
|
721
|
-
self.
|
|
722
|
-
self.omega_border,
|
|
723
|
-
self.curr_omega_border_idx,
|
|
724
|
-
) = _reset_or_increment(bend, self.nb, self._get_omega_border_operands())
|
|
725
|
-
|
|
726
|
-
# commands below are equivalent to
|
|
727
|
-
# return self.omega[i:(i+batch_size), 0:dim, 0:nb_facets]
|
|
728
|
-
# and nb_facets = 2 * dimension
|
|
729
|
-
# but JAX prefer the latter
|
|
730
|
-
return jax.lax.dynamic_slice(
|
|
731
|
-
self.omega_border,
|
|
732
|
-
start_indices=(self.curr_omega_border_idx, 0, 0),
|
|
733
|
-
slice_sizes=(self.omega_border_batch_size, self.dim, 2 * self.dim),
|
|
682
|
+
new_attributes = _reset_or_increment(
|
|
683
|
+
bend, self.nb, self._get_omega_border_operands()
|
|
734
684
|
)
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
and `self.border_batch()`
|
|
740
|
-
"""
|
|
741
|
-
return PDEStatioBatch(
|
|
742
|
-
inside_batch=self.inside_batch(), border_batch=self.border_batch()
|
|
685
|
+
new = eqx.tree_at(
|
|
686
|
+
lambda m: (m.key, m.omega_border, m.curr_omega_border_idx),
|
|
687
|
+
self,
|
|
688
|
+
new_attributes,
|
|
743
689
|
)
|
|
744
690
|
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
self.omega_border,
|
|
750
|
-
self.curr_omega_idx,
|
|
751
|
-
self.curr_omega_border_idx,
|
|
752
|
-
self.min_pts,
|
|
753
|
-
self.max_pts,
|
|
754
|
-
self.p_omega,
|
|
755
|
-
self.rar_iter_from_last_sampling,
|
|
756
|
-
self.rar_iter_nb,
|
|
691
|
+
return new, jax.lax.dynamic_slice(
|
|
692
|
+
new.omega_border,
|
|
693
|
+
start_indices=(new.curr_omega_border_idx, 0, 0),
|
|
694
|
+
slice_sizes=(new.omega_border_batch_size, new.dim, 2 * new.dim),
|
|
757
695
|
)
|
|
758
|
-
aux_data = {
|
|
759
|
-
k: vars(self)[k]
|
|
760
|
-
for k in [
|
|
761
|
-
"n",
|
|
762
|
-
"nb",
|
|
763
|
-
"omega_batch_size",
|
|
764
|
-
"omega_border_batch_size",
|
|
765
|
-
"method",
|
|
766
|
-
"dim",
|
|
767
|
-
"rar_parameters",
|
|
768
|
-
"n_start",
|
|
769
|
-
]
|
|
770
|
-
}
|
|
771
|
-
return (children, aux_data)
|
|
772
696
|
|
|
773
|
-
|
|
774
|
-
def tree_unflatten(cls, aux_data, children):
|
|
697
|
+
def get_batch(self) -> tuple["CubicMeshPDEStatio", PDEStatioBatch]:
|
|
775
698
|
"""
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
unflattening that happens e.g. during the gradient descent in the
|
|
779
|
-
optimization process
|
|
699
|
+
Generic method to return a batch. Here we call `self.inside_batch()`
|
|
700
|
+
and `self.border_batch()`
|
|
780
701
|
"""
|
|
781
|
-
(
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
omega_border,
|
|
785
|
-
curr_omega_idx,
|
|
786
|
-
curr_omega_border_idx,
|
|
787
|
-
min_pts,
|
|
788
|
-
max_pts,
|
|
789
|
-
p_omega,
|
|
790
|
-
rar_iter_from_last_sampling,
|
|
791
|
-
rar_iter_nb,
|
|
792
|
-
) = children
|
|
793
|
-
# force data_exists=True here in order not to re-generate the data
|
|
794
|
-
# at each iteration of lax.scan
|
|
795
|
-
obj = cls(
|
|
796
|
-
key=key,
|
|
797
|
-
data_exists=True,
|
|
798
|
-
min_pts=min_pts,
|
|
799
|
-
max_pts=max_pts,
|
|
800
|
-
**aux_data,
|
|
801
|
-
)
|
|
802
|
-
obj.omega = omega
|
|
803
|
-
obj.omega_border = omega_border
|
|
804
|
-
obj.curr_omega_idx = curr_omega_idx
|
|
805
|
-
obj.curr_omega_border_idx = curr_omega_border_idx
|
|
806
|
-
obj.p_omega = p_omega
|
|
807
|
-
obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
|
|
808
|
-
obj.rar_iter_nb = rar_iter_nb
|
|
809
|
-
return obj
|
|
702
|
+
new, inside_batch = self.inside_batch()
|
|
703
|
+
new, border_batch = new.border_batch()
|
|
704
|
+
return new, PDEStatioBatch(inside_batch=inside_batch, border_batch=border_batch)
|
|
810
705
|
|
|
811
706
|
|
|
812
|
-
@register_pytree_node_class
|
|
813
707
|
class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
814
|
-
"""
|
|
708
|
+
r"""
|
|
815
709
|
A class implementing data generator object for non stationary partial
|
|
816
710
|
differential equations. Formally, it extends `CubicMeshPDEStatio`
|
|
817
711
|
to include a temporal batch.
|
|
818
712
|
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
713
|
+
Parameters
|
|
714
|
+
----------
|
|
715
|
+
key : Key
|
|
716
|
+
Jax random key to sample new time points and to shuffle batches
|
|
717
|
+
n : Int
|
|
718
|
+
The number of total $\Omega$ points that will be divided in
|
|
719
|
+
batches. Batches are made so that each data point is seen only
|
|
720
|
+
once during 1 epoch.
|
|
721
|
+
nb : Int | None
|
|
722
|
+
The total number of points in $\partial\Omega$.
|
|
723
|
+
Can be `None` not to lose performance generating the border
|
|
724
|
+
batch if they are not used
|
|
725
|
+
nt : Int
|
|
726
|
+
The number of total time points that will be divided in
|
|
727
|
+
batches. Batches are made so that each data point is seen only
|
|
728
|
+
once during 1 epoch.
|
|
729
|
+
omega_batch_size : Int
|
|
730
|
+
The size of the batch of randomly selected points among
|
|
731
|
+
the `n` points.
|
|
732
|
+
omega_border_batch_size : Int | None
|
|
733
|
+
The size of the batch of points randomly selected
|
|
734
|
+
among the `nb` points.
|
|
735
|
+
Can be `None` not to lose performance generating the border
|
|
736
|
+
batch if they are not used
|
|
737
|
+
temporal_batch_size : Int
|
|
738
|
+
The size of the batch of randomly selected points among
|
|
739
|
+
the `nt` points.
|
|
740
|
+
dim : Int
|
|
741
|
+
An integer. dimension of $\Omega$ domain
|
|
742
|
+
min_pts : tuple[tuple[Float, Float], ...]
|
|
743
|
+
A tuple of minimum values of the domain along each dimension. For a sampling
|
|
744
|
+
in `n` dimension, this represents $(x_{1, min}, x_{2,min}, ...,
|
|
745
|
+
x_{n, min})$
|
|
746
|
+
max_pts : tuple[tuple[Float, Float], ...]
|
|
747
|
+
A tuple of maximum values of the domain along each dimension. For a sampling
|
|
748
|
+
in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
|
|
749
|
+
x_{n,max})$
|
|
750
|
+
tmin : float
|
|
751
|
+
The minimum value of the time domain to consider
|
|
752
|
+
tmax : float
|
|
753
|
+
The maximum value of the time domain to consider
|
|
754
|
+
method : str, default="uniform"
|
|
755
|
+
Either `grid` or `uniform`, default is `uniform`.
|
|
756
|
+
The method that generates the `nt` time points. `grid` means
|
|
757
|
+
regularly spaced points over the domain. `uniform` means uniformly
|
|
758
|
+
sampled points over the domain
|
|
759
|
+
rar_parameters : Dict[str, Int], default=None
|
|
760
|
+
Default to None: do not use Residual Adaptative Resampling.
|
|
761
|
+
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
762
|
+
which we start the RAR sampling scheme (we first have a burn in
|
|
763
|
+
period). `update_every`: the number of gradient steps taken between
|
|
764
|
+
each appending of collocation points in the RAR algo.
|
|
765
|
+
`sample_size_omega`: the size of the sample from which we will select new
|
|
766
|
+
collocation points. `selected_sample_size_omega`: the number of selected
|
|
767
|
+
points from the sample to be added to the current collocation
|
|
768
|
+
points.
|
|
769
|
+
n_start : Int, default=None
|
|
770
|
+
Defaults to None. The effective size of n used at start time.
|
|
771
|
+
This value must be
|
|
772
|
+
provided when rar_parameters is not None. Otherwise we set internally
|
|
773
|
+
n_start = n and this is hidden from the user.
|
|
774
|
+
In RAR, n_start
|
|
775
|
+
then corresponds to the initial number of omega points we train the PINN.
|
|
776
|
+
nt_start : Int, default=None
|
|
777
|
+
Defaults to None. A RAR hyper-parameter. Same as ``n_start`` but
|
|
778
|
+
for times collocation point. See also ``DataGeneratorODE``
|
|
779
|
+
documentation.
|
|
780
|
+
cartesian_product : Bool, default=True
|
|
781
|
+
Defaults to True. Whether we return the cartesian product of the
|
|
782
|
+
temporal batch with the inside and border batches. If False we just
|
|
783
|
+
return their concatenation.
|
|
822
784
|
"""
|
|
823
785
|
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
tmax,
|
|
838
|
-
method="grid",
|
|
839
|
-
rar_parameters=None,
|
|
840
|
-
n_start=None,
|
|
841
|
-
nt_start=None,
|
|
842
|
-
cartesian_product=True,
|
|
843
|
-
data_exists=False,
|
|
844
|
-
):
|
|
845
|
-
r"""
|
|
846
|
-
Parameters
|
|
847
|
-
----------
|
|
848
|
-
key
|
|
849
|
-
Jax random key to sample new time points and to shuffle batches
|
|
850
|
-
n
|
|
851
|
-
An integer. The number of total :math:`\Omega` points that will be divided in
|
|
852
|
-
batches. Batches are made so that each data point is seen only
|
|
853
|
-
once during 1 epoch.
|
|
854
|
-
nb
|
|
855
|
-
An integer. The total number of points in :math:`\partial\Omega`.
|
|
856
|
-
Can be `None` not to lose performance generating the border
|
|
857
|
-
batch if they are not used
|
|
858
|
-
nt
|
|
859
|
-
An integer. The number of total time points that will be divided in
|
|
860
|
-
batches. Batches are made so that each data point is seen only
|
|
861
|
-
once during 1 epoch.
|
|
862
|
-
omega_batch_size
|
|
863
|
-
An integer. The size of the batch of randomly selected points among
|
|
864
|
-
the `n` points.
|
|
865
|
-
omega_border_batch_size
|
|
866
|
-
An integer. The size of the batch of points randomly selected
|
|
867
|
-
among the `nb` points.
|
|
868
|
-
Can be `None` not to lose performance generating the border
|
|
869
|
-
batch if they are not used
|
|
870
|
-
temporal_batch_size
|
|
871
|
-
An integer. The size of the batch of randomly selected points among
|
|
872
|
-
the `nt` points.
|
|
873
|
-
dim
|
|
874
|
-
An integer. dimension of :math:`\Omega` domain
|
|
875
|
-
min_pts
|
|
876
|
-
A tuple of minimum values of the domain along each dimension. For a sampling
|
|
877
|
-
in `n` dimension, this represents :math:`(x_{1, min}, x_{2,min}, ...,
|
|
878
|
-
x_{n, min})`
|
|
879
|
-
max_pts
|
|
880
|
-
A tuple of maximum values of the domain along each dimension. For a sampling
|
|
881
|
-
in `n` dimension, this represents :math:`(x_{1, max}, x_{2,max}, ...,
|
|
882
|
-
x_{n,max})`
|
|
883
|
-
tmin
|
|
884
|
-
A float. The minimum value of the time domain to consider
|
|
885
|
-
tmax
|
|
886
|
-
A float. The maximum value of the time domain to consider
|
|
887
|
-
method
|
|
888
|
-
Either `grid` or `uniform`, default is `grid`.
|
|
889
|
-
The method that generates the `nt` time points. `grid` means
|
|
890
|
-
regularly spaced points over the domain. `uniform` means uniformly
|
|
891
|
-
sampled points over the domain
|
|
892
|
-
rar_parameters
|
|
893
|
-
Default to None: do not use Residual Adaptative Resampling.
|
|
894
|
-
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
895
|
-
which we start the RAR sampling scheme (we first have a burn in
|
|
896
|
-
period). `update_every`: the number of gradient steps taken between
|
|
897
|
-
each appending of collocation points in the RAR algo.
|
|
898
|
-
`sample_size_omega`: the size of the sample from which we will select new
|
|
899
|
-
collocation points. `selected_sample_size_omega`: the number of selected
|
|
900
|
-
points from the sample to be added to the current collocation
|
|
901
|
-
points.
|
|
902
|
-
n_start
|
|
903
|
-
Defaults to None. The effective size of n used at start time.
|
|
904
|
-
This value must be
|
|
905
|
-
provided when rar_parameters is not None. Otherwise we set internally
|
|
906
|
-
n_start = n and this is hidden from the user.
|
|
907
|
-
In RAR, n_start
|
|
908
|
-
then corresponds to the initial number of omega points we train the PINN.
|
|
909
|
-
nt_start
|
|
910
|
-
Defaults to None. A RAR hyper-parameter. Same as ``n_start`` but
|
|
911
|
-
for times collocation point. See also ``DataGeneratorODE``
|
|
912
|
-
documentation.
|
|
913
|
-
cartesian_product
|
|
914
|
-
Defaults to True. Whether we return the cartesian product of the
|
|
915
|
-
temporal batch with the inside and border batches. If False we just
|
|
916
|
-
return their concatenation.
|
|
917
|
-
data_exists
|
|
918
|
-
Must be left to `False` when created by the user. Avoids the
|
|
919
|
-
regeneration of :math:`\Omega`, :math:`\partial\Omega` and
|
|
920
|
-
time points at each pytree flattening and unflattening.
|
|
786
|
+
temporal_batch_size: Int = eqx.field(kw_only=True)
|
|
787
|
+
tmin: Float = eqx.field(kw_only=True)
|
|
788
|
+
tmax: Float = eqx.field(kw_only=True)
|
|
789
|
+
nt: Int = eqx.field(kw_only=True)
|
|
790
|
+
temporal_batch_size: Int = eqx.field(kw_only=True, static=True)
|
|
791
|
+
cartesian_product: Bool = eqx.field(kw_only=True, default=True, static=True)
|
|
792
|
+
nt_start: int = eqx.field(kw_only=True, default=None, static=True)
|
|
793
|
+
|
|
794
|
+
p_times: Array = eqx.field(init=False)
|
|
795
|
+
curr_time_idx: Int = eqx.field(init=False)
|
|
796
|
+
times: Array = eqx.field(init=False)
|
|
797
|
+
|
|
798
|
+
def __post_init__(self):
|
|
921
799
|
"""
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
omega_border_batch_size,
|
|
928
|
-
dim,
|
|
929
|
-
min_pts,
|
|
930
|
-
max_pts,
|
|
931
|
-
method,
|
|
932
|
-
rar_parameters,
|
|
933
|
-
n_start,
|
|
934
|
-
data_exists,
|
|
935
|
-
)
|
|
936
|
-
self.temporal_batch_size = temporal_batch_size
|
|
937
|
-
self.tmin = tmin
|
|
938
|
-
self.tmax = tmax
|
|
939
|
-
self.nt = nt
|
|
800
|
+
Note that neither __init__ or __post_init__ are called when udating a
|
|
801
|
+
Module with eqx.tree_at!
|
|
802
|
+
"""
|
|
803
|
+
super().__post_init__() # because __init__ or __post_init__ of Base
|
|
804
|
+
# class is not automatically called
|
|
940
805
|
|
|
941
|
-
self.cartesian_product = cartesian_product
|
|
942
806
|
if not self.cartesian_product:
|
|
943
807
|
if self.temporal_batch_size != self.omega_batch_size:
|
|
944
808
|
raise ValueError(
|
|
@@ -968,46 +832,54 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
             self.p_times,
             _,
             _,
-        ) = _check_and_set_rar_parameters(rar_parameters,
-
-
-
-
-
-
-
-
-
-
-
-
-
+        ) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.nt_start)
+
+        self.curr_time_idx = jnp.iinfo(jnp.int32).max - self.temporal_batch_size - 1
+        self.key, _ = jax.random.split(self.key, 2)  # to make it equivalent to
+        # the call to _reset_batch_idx_and_permute in legacy DG
+        self.key, self.times = self.generate_time_data(self.key)
+        # see explanation in DataGeneratorODE for the key
+
+    def sample_in_time_domain(
+        self, key: Key, sample_size: Int = None
+    ) -> Float[Array, "nt"]:
+        return jax.random.uniform(
+            key,
+            (self.nt if sample_size is None else sample_size,),
+            minval=self.tmin,
+            maxval=self.tmax,
+        )
 
-    def _get_time_operands(
+    def _get_time_operands(
+        self,
+    ) -> tuple[Key, Float[Array, "nt"], Int, Int, Float[Array, "nt"]]:
         return (
-            self.
+            self.key,
             self.times,
             self.curr_time_idx,
             self.temporal_batch_size,
             self.p_times,
         )
 
-    def
-
+    def generate_time_data(self, key: Key) -> tuple[Key, Float[Array, "nt"]]:
+        """
         Construct a complete set of `self.nt` time points according to the
-        specified `self.method
-
-
+        specified `self.method`
+
+        Note that self.times always has size self.nt and not self.nt_start: even
+        in the RAR scheme, we must allocate all the collocation points
         """
+        key, subkey = jax.random.split(key, 2)
         if self.method == "grid":
-
-
-
-
-
-        raise ValueError("Method " + self.method + " is not implemented.")
+            partial_times = (self.tmax - self.tmin) / self.nt
+            return key, jnp.arange(self.tmin, self.tmax, partial_times)
+        if self.method == "uniform":
+            return key, self.sample_in_time_domain(subkey)
+        raise ValueError("Method " + self.method + " is not implemented.")
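The two sampling modes differ only in how the `nt` points fill `[tmin, tmax)`; the same logic stripped of the class machinery (illustrative bounds, plain JAX):

```python
import jax
import jax.numpy as jnp

tmin, tmax, nt = 0.0, 1.0, 8  # illustrative values

# "grid": nt regularly spaced points over [tmin, tmax)
grid_times = jnp.arange(tmin, tmax, (tmax - tmin) / nt)

# "uniform": nt i.i.d. uniform draws over the same interval
uniform_times = jax.random.uniform(
    jax.random.PRNGKey(0), (nt,), minval=tmin, maxval=tmax
)

assert grid_times.shape == uniform_times.shape == (nt,)
```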
 
-    def temporal_batch(
+    def temporal_batch(
+        self,
+    ) -> tuple["CubicMeshPDENonStatio", Float[Array, "temporal_batch_size"]]:
         """
         Return a batch of time points. If all the batches have been seen, we
         reshuffle them, otherwise we just return the next unseen batch.
@@ -1018,254 +890,344 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
         # Compute the effective number of used collocation points
         if self.rar_parameters is not None:
             nt_eff = (
-                self.
+                self.nt_start
                 + self.rar_iter_nb * self.rar_parameters["selected_sample_size_times"]
             )
         else:
             nt_eff = self.nt
 
-        (
-
+        new_attributes = _reset_or_increment(bend, nt_eff, self._get_time_operands())
+        new = eqx.tree_at(
+            lambda m: (m.key, m.times, m.curr_time_idx), self, new_attributes
         )
 
-
-
-
-
-            self.times,
-            start_indices=(self.curr_time_idx,),
-            slice_sizes=(self.temporal_batch_size,),
+        return new, jax.lax.dynamic_slice(
+            new.times,
+            start_indices=(new.curr_time_idx,),
+            slice_sizes=(new.temporal_batch_size,),
         )
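`_reset_or_increment` is internal to jinns, but the buffer-cycling idea it implements can be shown in isolation: advance a cursor through a pre-shuffled buffer, and reshuffle plus reset the cursor once the next slice would overrun the points currently in use. The following is a toy re-implementation under that assumption (the real helper also threads `p_times` so RAR can bias the resampling); it is not the jinns function itself:

```python
import jax
import jax.numpy as jnp

def reset_or_increment_sketch(key, buffer, curr_idx, batch_size, n_eff):
    def reset(_):
        new_key, subkey = jax.random.split(key)
        return new_key, jax.random.permutation(subkey, buffer), 0

    def increment(_):
        return key, buffer, curr_idx + batch_size

    # Reset when the upcoming slice would run past the n_eff usable points
    return jax.lax.cond(
        curr_idx + 2 * batch_size > n_eff, reset, increment, None
    )

key = jax.random.PRNGKey(0)
times = jnp.linspace(0.0, 1.0, 100)
key, times, idx = reset_or_increment_sketch(key, times, 96, 8, 100)
batch = jax.lax.dynamic_slice(times, (idx,), (8,))  # shape (8,)
```

Initialising `curr_time_idx` to `jnp.iinfo(jnp.int32).max - temporal_batch_size - 1`, as done above, simply guarantees that the very first call takes the reset branch.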
 
-    def get_batch(self):
+    def get_batch(self) -> tuple["CubicMeshPDENonStatio", PDENonStatioBatch]:
         """
         Generic method to return a batch. Here we call `self.inside_batch()`,
         `self.border_batch()` and `self.temporal_batch()`
         """
-        x = self.inside_batch()
-        dx =
-        t =
+        new, x = self.inside_batch()
+        new, dx = new.border_batch()
+        new, t = new.temporal_batch()
+        t = t.reshape(new.temporal_batch_size, 1)
 
-        if
+        if new.cartesian_product:
             t_x = make_cartesian_product(t, x)
         else:
             t_x = jnp.concatenate([t, x], axis=1)
 
         if dx is not None:
-            t_ = t.reshape(
+            t_ = t.reshape(new.temporal_batch_size, 1, 1)
             t_ = jnp.repeat(t_, dx.shape[-1], axis=2)
-            if
+            if new.cartesian_product or new.dim == 1:
                 t_dx = make_cartesian_product(t_, dx)
             else:
                 t_dx = jnp.concatenate([t_, dx], axis=1)
         else:
             t_dx = None
 
-        return PDENonStatioBatch(
+        return new, PDENonStatioBatch(
+            times_x_inside_batch=t_x, times_x_border_batch=t_dx
+        )
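The `cartesian_product` switch decides whether every time point is paired with every space point, or the two batches are zipped one-to-one (which is why the constructor insists on `temporal_batch_size == omega_batch_size` in the latter case). A shape-level sketch with my own stand-in for jinns' `make_cartesian_product`:

```python
import jax.numpy as jnp

def cartesian_product_sketch(t, x):
    # Pair every row of t (nt, 1) with every row of x (n, d)
    t_rep = jnp.repeat(t, x.shape[0], axis=0)       # (nt * n, 1)
    x_til = jnp.tile(x, (t.shape[0], 1))            # (nt * n, d)
    return jnp.concatenate([t_rep, x_til], axis=1)  # (nt * n, 1 + d)

t = jnp.zeros((4, 1))   # temporal batch
x = jnp.zeros((3, 2))   # spatial batch
assert cartesian_product_sketch(t, x).shape == (12, 3)

# cartesian_product=False: one-to-one concatenation instead
assert jnp.concatenate([jnp.zeros((3, 1)), x], axis=1).shape == (3, 3)
```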
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+class DataGeneratorObservations(eqx.Module):
+    r"""
+    Despite the class name, it is rather a dataloader from user provided
+    observations that will be used for the observations loss
+
+    Parameters
+    ----------
+    key : Key
+        Jax random key to shuffle batches
+    obs_batch_size : Int
+        The size of the batch of randomly selected points among
+        the `n` points. `obs_batch_size` will be the same for all
+        elements of the returned observation dict batch.
+        NOTE: no check is done BUT users should be careful that
+        `obs_batch_size` must be equal to `temporal_batch_size` or
+        `omega_batch_size` or the product of both. In the first case, the
+        present DataGeneratorObservations instance complements an ODEBatch,
+        PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
+        = False). In the second case, `obs_batch_size` =
+        `temporal_batch_size * omega_batch_size` if the present
+        DataGeneratorParameter complements a PDENonStatioBatch
+        with self.cartesian_product = True
+    observed_pinn_in : Float[Array, "n_obs nb_pinn_in"]
+        Observed values corresponding to the input of the PINN
+        (eg. the time at which we recorded the observations). The first
+        dimension must correspond to the number of observed_values.
+        The second dimension depends on the input dimension of the PINN,
+        that is `1` for ODE, `n_dim_x` for stationary PDE and `n_dim_x + 1`
+        for non-stationary PDE.
+    observed_values : Float[Array, "n_obs nb_pinn_out"]
+        Observed values that the PINN should learn to fit. The first
+        dimension must be aligned with observed_pinn_in.
+    observed_eq_params : Dict[str, Float[Array, "n_obs 1"]], default={}
+        A dict with keys corresponding to
+        the parameter name. The keys must match the keys in
+        `params["eq_params"]`. The values are jnp.array with 2 dimensions
+        with values corresponding to the parameter value for which we also
+        have observed_pinn_in and observed_values. Hence the first
+        dimension must be aligned with observed_pinn_in and observed_values.
+        Optional argument.
+    sharding_device : jax.sharding.Sharding, default=None
+        An optional sharding object to constrain the storage
+        of observed inputs, values and parameters. Typically, a
+        SingleDeviceSharding(cpu_device) to avoid loading on GPU huge
+        datasets of observations. Note that computations for **batches**
+        can still be performed on other devices (*e.g.* GPU, TPU or
+        any pre-defined Sharding) thanks to the `obs_batch_sharding`
+        argument of `jinns.solve()`. Read the docs for more info.
+    """
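To make the `sharding_device` remark concrete: a large observation set can be pinned to host memory while batches are still dispatched to an accelerator by `jinns.solve()`. A minimal sketch, where the constructor call is hypothetical usage of the fields declared just below and the file names are made up:

```python
import jax
import numpy as np

# Keep the (potentially huge) observation arrays on CPU
cpu_sharding = jax.sharding.SingleDeviceSharding(jax.devices("cpu")[0])

# obs_dg = DataGeneratorObservations(
#     key=jax.random.PRNGKey(0),
#     obs_batch_size=32,
#     observed_pinn_in=np.load("t_obs.npy"),   # (n_obs, 1) for an ODE
#     observed_values=np.load("u_obs.npy"),    # (n_obs, nb_pinn_out)
#     sharding_device=cpu_sharding,
# )
```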
+
+    key: Key
+    obs_batch_size: Int = eqx.field(static=True)
+    observed_pinn_in: Float[Array, "n_obs nb_pinn_in"]
+    observed_values: Float[Array, "n_obs nb_pinn_out"]
+    observed_eq_params: Dict[str, Float[Array, "n_obs 1"]] = eqx.field(
+        static=True, default_factory=lambda: {}
+    )
+    sharding_device: jax.sharding.Sharding = eqx.field(static=True, default=None)
+
+    n: Int = eqx.field(init=False)
+    curr_idx: Int = eqx.field(init=False)
+    indices: Array = eqx.field(init=False)
+
+    def __post_init__(self):
+        if self.observed_pinn_in.shape[0] != self.observed_values.shape[0]:
+            raise ValueError(
+                "self.observed_pinn_in and self.observed_values must have same first axis"
+            )
+        for _, v in self.observed_eq_params.items():
+            if v.shape[0] != self.observed_pinn_in.shape[0]:
+                raise ValueError(
+                    "self.observed_pinn_in and the values of"
+                    " self.observed_eq_params must have the same first axis"
+                )
+        if len(self.observed_pinn_in.shape) == 1:
+            self.observed_pinn_in = self.observed_pinn_in[:, None]
+        if len(self.observed_pinn_in.shape) > 2:
+            raise ValueError("self.observed_pinn_in must have 2 dimensions")
+        if len(self.observed_values.shape) == 1:
+            self.observed_values = self.observed_values[:, None]
+        if len(self.observed_values.shape) > 2:
+            raise ValueError("self.observed_values must have 2 dimensions")
+        for k, v in self.observed_eq_params.items():
+            if len(v.shape) == 1:
+                self.observed_eq_params[k] = v[:, None]
+            if len(v.shape) > 2:
+                raise ValueError(
+                    "Each value of observed_eq_params must have 2 dimensions"
+                )
+
+        self.n = self.observed_pinn_in.shape[0]
+
+        if self.sharding_device is not None:
+            self.observed_pinn_in = jax.lax.with_sharding_constraint(
+                self.observed_pinn_in, self.sharding_device
+            )
+            self.observed_values = jax.lax.with_sharding_constraint(
+                self.observed_values, self.sharding_device
+            )
+            self.observed_eq_params = jax.lax.with_sharding_constraint(
+                self.observed_eq_params, self.sharding_device
+            )
+
+        self.curr_idx = jnp.iinfo(jnp.int32).max - self.obs_batch_size - 1
+        # For speed and to avoid duplicating data what is really
+        # shuffled is a vector of indices
+        if self.sharding_device is not None:
+            self.indices = jax.lax.with_sharding_constraint(
+                jnp.arange(self.n), self.sharding_device
+            )
+        else:
+            self.indices = jnp.arange(self.n)
+
+        # recall __post_init__ is the only place besides __init__ where we can
+        # set self attributes in an in-place way
+        self.key, _ = jax.random.split(self.key, 2)  # to make it equivalent to
+        # the call to _reset_batch_idx_and_permute in legacy DG
+
+    def _get_operands(self) -> tuple[Key, Int[Array, "n"], Int, Int, None]:
+        return (
+            self.key,
+            self.indices,
+            self.curr_idx,
+            self.obs_batch_size,
+            None,
         )
-        aux_data = {
-            k: vars(self)[k]
-            for k in [
-                "n",
-                "nb",
-                "nt",
-                "omega_batch_size",
-                "omega_border_batch_size",
-                "temporal_batch_size",
-                "method",
-                "dim",
-                "rar_parameters",
-                "n_start",
-                "nt_start",
-                "cartesian_product",
-            ]
-        }
-        return (children, aux_data)
 
-
-
+    def obs_batch(
+        self,
+    ) -> tuple[
+        "DataGeneratorObservations", Dict[str, Float[Array, "obs_batch_size dim"]]
+    ]:
         """
-
-
-
-
+        Return a dictionary with (keys, values): (pinn_in, a mini batch of pinn
+        inputs), (obs, a mini batch of corresponding observations), (eq_params,
+        a dictionary with entry names found in `params["eq_params"]` and values
+        giving the corresponding parameter value for the couple
+        (input, observation) mentioned before).
+        It can also be a dictionary of dictionaries as described above if
+        observed_pinn_in, observed_values, etc. are dictionaries with keys
+        representing the PINNs.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-            p_times,
-            p_omega,
-            rar_iter_from_last_sampling,
-            rar_iter_nb,
-        ) = children
-        obj = cls(
-            key=key,
-            data_exists=True,
-            min_pts=min_pts,
-            max_pts=max_pts,
-            tmin=tmin,
-            tmax=tmax,
-            **aux_data,
+
+        new_attributes = _reset_or_increment(
+            self.curr_idx + self.obs_batch_size, self.n, self._get_operands()
+        )
+        new = eqx.tree_at(
+            lambda m: (m.key, m.indices, m.curr_idx), self, new_attributes
+        )
+
+        minib_indices = jax.lax.dynamic_slice(
+            new.indices,
+            start_indices=(new.curr_idx,),
+            slice_sizes=(new.obs_batch_size,),
         )
-        obj.omega = omega
-        obj.omega_border = omega_border
-        obj.times = times
-        obj.curr_omega_idx = curr_omega_idx
-        obj.curr_omega_border_idx = curr_omega_border_idx
-        obj.curr_time_idx = curr_time_idx
-        obj.p_times = p_times
-        obj.p_omega = p_omega
-        obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
-        obj.rar_iter_nb = rar_iter_nb
-
-        return obj
-
-
-@register_pytree_node_class
-class DataGeneratorParameter:
-    """
-    A data generator for additional unidimensional parameter(s)
-    """
 
-
+        obs_batch = {
+            "pinn_in": jnp.take(
+                new.observed_pinn_in, minib_indices, unique_indices=True, axis=0
+            ),
+            "val": jnp.take(
+                new.observed_values, minib_indices, unique_indices=True, axis=0
+            ),
+            "eq_params": jax.tree_util.tree_map(
+                lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),
+                new.observed_eq_params,
+            ),
+        }
+        return new, obs_batch
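A hypothetical end-to-end use of this class, to highlight the functional contract of the new API (fake data, field names as in this diff):

```python
import jax
import jax.numpy as jnp

t_obs = jnp.linspace(0.0, 1.0, 200)[:, None]  # (n_obs, 1) PINN inputs
u_obs = jnp.exp(-t_obs)                       # (n_obs, 1) observed outputs

dg = DataGeneratorObservations(
    key=jax.random.PRNGKey(0),
    obs_batch_size=32,
    observed_pinn_in=t_obs,
    observed_values=u_obs,
)
# The generator is an immutable eqx.Module: get_batch() returns the updated
# generator together with the batch, and forgetting to rebind `dg` would
# replay the same mini-batch forever.
dg, batch = dg.get_batch()
assert batch["pinn_in"].shape == (32, 1) and batch["val"].shape == (32, 1)
```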
+
+    def get_batch(
         self,
-
-
-
-        param_ranges=None,
-        method="grid",
-        user_data=None,
-        data_exists=False,
-    ):
-        r"""
-        Parameters
-        ----------
-        key
-            Jax random key to sample new time points and to shuffle batches
-            or a dict of Jax random keys with key entries from param_ranges
-        n
-            An integer. The number of total points that will be divided in
-            batches. Batches are made so that each data point is seen only
-            once during 1 epoch.
-        param_batch_size
-            An integer. The size of the batch of randomly selected points among
-            the `n` points. `param_batch_size` will be the same for all
-            additional batch of parameter.
-            NOTE: no check is done BUT users should be careful that
-            `param_batch_size` must be equal to `temporal_batch_size` or
-            `omega_batch_size` or the product of both. In the first case, the
-            present DataGeneratorParameter instance complements an ODEBatch, a
-            PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
-            = False). In the second case, `param_batch_size` =
-            `temporal_batch_size * omega_batch_size` if the present
-            DataGeneratorParameter complements a PDENonStatioBatch
-            with self.cartesian_product = True
-        param_ranges
-            A dict. A dict of tuples (min, max), which
-            reprensents the range of real numbers where to sample batches (of
-            length `param_batch_size` among `n` points).
-            The key corresponds to the parameter name. The keys must match the
-            keys in `params["eq_params"]`.
-            By providing several entries in this dictionary we can sample
-            an arbitrary number of parameters.
-            __Note__ that we currently only support unidimensional parameters
-        method
-            Either `grid` or `uniform`, default is `grid`. `grid` means
-            regularly spaced points over the domain. `uniform` means uniformly
-            sampled points over the domain
-        data_exists
-            Must be left to `False` when created by the user. Avoids the
-            regeneration of :math:`\Omega`, :math:`\partial\Omega` and
-            time points at each pytree flattening and unflattening.
-        user_data
-            A dictionary containing user-provided data for parameters.
-            As for `param_ranges`, the key corresponds to the parameter name,
-            the keys must match the keys in `params["eq_params"]` and only
-            unidimensional arrays are supported. Therefore, the jnp arrays
-            found at `user_data[k]` must have shape `(n, 1)` or `(n,)`.
-            Note that if the same key appears in `param_ranges` and `user_data`
-            priority goes for the content in `user_data`.
-            Defaults to None.
+    ) -> tuple[
+        "DataGeneratorObservations", Dict[str, Float[Array, "obs_batch_size dim"]]
+    ]:
         """
-
-
+        Generic method to return a batch
+        """
+        return self.obs_batch()
+
+
+class DataGeneratorParameter(eqx.Module):
+    r"""
+    A data generator for additional unidimensional parameter(s)
+
+    Parameters
+    ----------
+    keys : Key | Dict[str, Key]
+        Jax random key to sample new time points and to shuffle batches
+        or a dict of Jax random keys with key entries from param_ranges
+    n : Int
+        The number of total points that will be divided in
+        batches. Batches are made so that each data point is seen only
+        once during 1 epoch.
+    param_batch_size : Int
+        The size of the batch of randomly selected points among
+        the `n` points. `param_batch_size` will be the same for all
+        additional batch of parameter.
+        NOTE: no check is done BUT users should be careful that
+        `param_batch_size` must be equal to `temporal_batch_size` or
+        `omega_batch_size` or the product of both. In the first case, the
+        present DataGeneratorParameter instance complements an ODEBatch, a
+        PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
+        = False). In the second case, `param_batch_size` =
+        `temporal_batch_size * omega_batch_size` if the present
+        DataGeneratorParameter complements a PDENonStatioBatch
+        with self.cartesian_product = True
+    param_ranges : Dict[str, tuple[Float, Float]] | None, default={}
+        A dict of tuples (min, max), which
+        represents the range of real numbers where to sample batches (of
+        length `param_batch_size` among `n` points).
+        The key corresponds to the parameter name. The keys must match the
+        keys in `params["eq_params"]`.
+        By providing several entries in this dictionary we can sample
+        an arbitrary number of parameters.
+        **Note** that we currently only support unidimensional parameters.
+        This argument can be omitted if we only use `user_data`.
+    method : str, default="uniform"
+        Either `grid` or `uniform`. `grid` means
+        regularly spaced points over the domain. `uniform` means uniformly
+        sampled points over the domain
+    user_data : Dict[str, Float[Array, "n"]] | None, default={}
+        A dictionary containing user-provided data for parameters.
+        As for `param_ranges`, the key corresponds to the parameter name,
+        the keys must match the keys in `params["eq_params"]` and only
+        unidimensional arrays are supported. Therefore, the jnp arrays
+        found at `user_data[k]` must have shape `(n, 1)` or `(n,)`.
+        Note that if the same key appears in `param_ranges` and `user_data`
+        priority goes for the content in `user_data`.
+        Defaults to {}.
+    """
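A hypothetical instantiation, matching the field declarations below; the sampled values would then feed the `"nu"` entry of `params["eq_params"]` during training:

```python
import jax

param_dg = DataGeneratorParameter(
    keys=jax.random.PRNGKey(1),
    n=1000,
    param_batch_size=32,
    param_ranges={"nu": (1e-4, 1e-2)},  # sample a viscosity-like parameter
)
```

Passing a single key is enough: as `__post_init__` below shows, it is split into one key per parameter name found in `param_ranges` and `user_data`.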
+
+    keys: Key | Dict[str, Key]
+    n: Int
+    param_batch_size: Int = eqx.field(static=True)
+    param_ranges: Dict[str, tuple[Float, Float]] = eqx.field(
+        static=True, default_factory=lambda: {}
+    )
+    method: str = eqx.field(static=True, default="uniform")
+    user_data: Dict[str, Float[Array, "n"]] | None = eqx.field(
+        static=True, default_factory=lambda: {}
+    )
+
+    curr_param_idx: Dict[str, Int] = eqx.field(init=False)
+    param_n_samples: Dict[str, Array] = eqx.field(init=False)
 
-
+    def __post_init__(self):
+        if self.user_data is None:
+            self.user_data = {}
+        if self.param_ranges is None:
+            self.param_ranges = {}
+        if self.n < self.param_batch_size:
             raise ValueError(
-                f"Number of data points ({n}) is smaller than the"
-                f"number of batch points ({param_batch_size})."
+                f"Number of data points ({self.n}) is smaller than the"
+                f" number of batch points ({self.param_batch_size})."
+            )
+        if not isinstance(self.keys, dict):
+            all_keys = set().union(self.param_ranges, self.user_data)
+            self.keys = dict(zip(all_keys, jax.random.split(self.keys, len(all_keys))))
+
+        self.curr_param_idx = {}
+        for k in self.keys.keys():
+            self.curr_param_idx[k] = (
+                jnp.iinfo(jnp.int32).max - self.param_batch_size - 1
             )
 
-
-
-
-
-
-
-
-
-
-        self.n = n
-        self.param_batch_size = param_batch_size
-        self.param_ranges = param_ranges
-        self.user_data = user_data
-
-        if not self.data_exists:
-            self.generate_data()
-            # The previous call to self.generate_data() has created
-            # the dict self.param_n_samples and then we will only use this one
-            # because it has merged the scattered data between `user_data` and
-            # `param_ranges`
-            self.curr_param_idx = {}
-            for k in self.param_n_samples.keys():
-                self.curr_param_idx[k] = 0
-                (
-                    self._keys[k],
-                    self.param_n_samples[k],
-                    _,
-                ) = _reset_batch_idx_and_permute(self._get_param_operands(k))
-
-    def generate_data(self):
+        # The call to self.generate_data() creates
+        # the dict self.param_n_samples and then we will only use this one
+        # because it merges the scattered data between `user_data` and
+        # `param_ranges`
+        self.keys, self.param_n_samples = self.generate_data(self.keys)
+
+    def generate_data(
+        self, keys: Dict[str, Key]
+    ) -> tuple[Dict[str, Key], Dict[str, Float[Array, "n"]]]:
         """
         Generate parameter samples, either through generation
         or using user-provided data.
         """
-
+        param_n_samples = {}
 
         all_keys = set().union(self.param_ranges, self.user_data)
         for k in all_keys:
-            if
+            if (
+                self.user_data
+                and k in self.user_data.keys()  # pylint: disable=no-member
+            ):
                 if self.user_data[k].shape == (self.n, 1):
-
+                    param_n_samples[k] = self.user_data[k]
                 if self.user_data[k].shape == (self.n,):
-
+                    param_n_samples[k] = self.user_data[k][:, None]
                 else:
                     raise ValueError(
                         "Wrong shape for user provided parameters"
@@ -1274,23 +1236,25 @@ class DataGeneratorParameter:
             else:
                 if self.method == "grid":
                     xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
-
+                    partial = (xmax - xmin) / self.n
                     # shape (n, 1)
-
-                        :, None
-                    ]
+                    param_n_samples[k] = jnp.arange(xmin, xmax, partial)[:, None]
                 elif self.method == "uniform":
                     xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
-
-
+                    keys[k], subkey = jax.random.split(keys[k], 2)
+                    param_n_samples[k] = jax.random.uniform(
                         subkey, shape=(self.n, 1), minval=xmin, maxval=xmax
                     )
                 else:
                     raise ValueError("Method " + self.method + " is not implemented.")
 
-
+        return keys, param_n_samples
+
+    def _get_param_operands(
+        self, k: str
+    ) -> tuple[Key, Float[Array, "n"], Int, Int, None]:
         return (
-            self.
+            self.keys[k],
             self.param_n_samples[k],
             self.curr_param_idx[k],
             self.param_batch_size,
@@ -1315,26 +1279,28 @@ class DataGeneratorParameter:
             _reset_or_increment_wrapper,
             self.param_n_samples,
             self.curr_param_idx,
-            self.
+            self.keys,
         )
         # we must transpose the pytrees because keys are merged in res
         # https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#transposing-trees
-        (
-            self.
-            self.param_n_samples,
-            self.curr_param_idx,
-        ) = jax.tree_util.tree_transpose(
-            jax.tree_util.tree_structure(self._keys),
+        new_attributes = jax.tree_util.tree_transpose(
+            jax.tree_util.tree_structure(self.keys),
             jax.tree_util.tree_structure([0, 0, 0]),
             res,
         )
 
-
+        new = eqx.tree_at(
+            lambda m: (m.keys, m.param_n_samples, m.curr_param_idx),
+            self,
+            new_attributes,
+        )
+
+        return new, jax.tree_util.tree_map(
             lambda p, q: jax.lax.dynamic_slice(
-                p, start_indices=(q, 0), slice_sizes=(
+                p, start_indices=(q, 0), slice_sizes=(new.param_batch_size, 1)
             ),
-
-
+            new.param_n_samples,
+            new.curr_param_idx,
         )
 
     def get_batch(self):
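The transposition commented above, in isolation: `res` carries one `(key, samples, index)` triple per parameter, and `tree_transpose` turns that dict-of-triples into a triple of dicts which can then be written back onto the module in a single `eqx.tree_at` call.

```python
import jax

keys = {"nu": 0, "rho": 1}  # stand-in for the dict of PRNG keys
res = {"nu": ["k_nu", "s_nu", 0], "rho": ["k_rho", "s_rho", 0]}

new_keys, new_samples, new_idx = jax.tree_util.tree_transpose(
    jax.tree_util.tree_structure(keys),       # outer treedef: the dict
    jax.tree_util.tree_structure([0, 0, 0]),  # inner treedef: the triple
    res,
)
assert new_keys == {"nu": "k_nu", "rho": "k_rho"}
```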
@@ -1343,251 +1309,9 @@ class DataGeneratorParameter:
         """
         return self.param_batch()
 
-    def tree_flatten(self):
-        children = (
-            self._keys,
-            self.param_n_samples,
-            self.curr_param_idx,
-        )
-        aux_data = {
-            k: vars(self)[k]
-            for k in ["n", "param_batch_size", "method", "param_ranges", "user_data"]
-        }
-        return (children, aux_data)
-
-    @classmethod
-    def tree_unflatten(cls, aux_data, children):
-        (
-            keys,
-            param_n_samples,
-            curr_param_idx,
-        ) = children
-        obj = cls(
-            key=keys,
-            data_exists=True,
-            **aux_data,
-        )
-        obj.param_n_samples = param_n_samples
-        obj.curr_param_idx = curr_param_idx
-        return obj
-
-
-@register_pytree_node_class
-class DataGeneratorObservations:
-    """
-    Despite the class name, it is rather a dataloader from user provided
-    observations that will be used for the observations loss
-    """
-
-    def __init__(
-        self,
-        key,
-        obs_batch_size,
-        observed_pinn_in,
-        observed_values,
-        observed_eq_params=None,
-        data_exists=False,
-        sharding_device=None,
-    ):
-        r"""
-        Parameters
-        ----------
-        key
-            Jax random key to sample new time points and to shuffle batches
-        obs_batch_size
-            An integer. The size of the batch of randomly selected points among
-            the `n` points. `obs_batch_size` will be the same for all
-            elements of the obs dict.
-            NOTE: no check is done BUT users should be careful that
-            `obs_batch_size` must be equal to `temporal_batch_size` or
-            `omega_batch_size` or the product of both. In the first case, the
-            present DataGeneratorObservations instance complements an ODEBatch,
-            PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
-            = False). In the second case, `obs_batch_size` =
-            `temporal_batch_size * omega_batch_size` if the present
-            DataGeneratorParameter complements a PDENonStatioBatch
-            with self.cartesian_product = True
-        observed_pinn_in
-            A jnp.array with 2 dimensions.
-            Observed values corresponding to the input of the PINN
-            (eg. the time at which we recorded the observations). The first
-            dimension must corresponds to the number of observed_values and
-            observed_eq_params. The second dimension depends on the input dimension of the PINN, that is `1` for ODE, `n_dim_x` for stationnary PDE and `n_dim_x + 1` for non-stationnary PDE.
-        observed_values
-            A jnp.array with 2 dimensions.
-            Observed values that the PINN should learn to fit. The first dimension must be aligned with observed_pinn_in and the values of observed_eq_params.
-        observed_eq_params
-            Optional. Default is None. A dict with keys corresponding to the
-            parameter name. The keys must match the keys in
-            `params["eq_params"]`. The values are jnp.array with 2 dimensions
-            with values corresponding to the parameter value for which we also
-            have observed_pinn_in and observed_values. Hence the first
-            dimension must be aligned with observed_pinn_in and observed_values.
-        data_exists
-            Must be left to `False` when created by the user. Avoids the
-            resetting of curr_idx at each pytree flattening and unflattening.
-        sharding_device
-            Default None. An optional sharding object to constraint the storage
-            of observed inputs, values and parameters. Typically, a
-            SingleDeviceSharding(cpu_device) to avoid loading on GPU huge
-            datasets of observations. Note that computations for **batches**
-            can still be performed on other devices (*e.g.* GPU, TPU or
-            any pre-defined Sharding) thanks to the `obs_batch_sharding`
-            arguments of `jinns.solve()`. Read the docs for more info.
-
-        """
-        if observed_eq_params is None:
-            observed_eq_params = {}
-
-        if not data_exists:
-            self.observed_eq_params = observed_eq_params.copy()
-        else:
-            # avoid copying when in flatten/unflatten
-            self.observed_eq_params = observed_eq_params
-
-        if observed_pinn_in.shape[0] != observed_values.shape[0]:
-            raise ValueError(
-                "observed_pinn_in and observed_values must have same first axis"
-            )
-        for _, v in self.observed_eq_params.items():
-            if v.shape[0] != observed_pinn_in.shape[0]:
-                raise ValueError(
-                    "observed_pinn_in and the values of"
-                    " observed_eq_params must have the same first axis"
-                )
-        if len(observed_pinn_in.shape) == 1:
-            observed_pinn_in = observed_pinn_in[:, None]
-        if len(observed_pinn_in.shape) > 2:
-            raise ValueError("observed_pinn_in must have 2 dimensions")
-        if len(observed_values.shape) == 1:
-            observed_values = observed_values[:, None]
-        if len(observed_values.shape) > 2:
-            raise ValueError("observed_values must have 2 dimensions")
-        for k, v in self.observed_eq_params.items():
-            if len(v.shape) == 1:
-                self.observed_eq_params[k] = v[:, None]
-            if len(v.shape) > 2:
-                raise ValueError(
-                    "Each value of observed_eq_params must have 2 dimensions"
-                )
-
-        self.n = observed_pinn_in.shape[0]
-        self._key = key
-        self.obs_batch_size = obs_batch_size
-
-        self.data_exists = data_exists
-        if not self.data_exists and sharding_device is not None:
-            self.observed_pinn_in = jax.lax.with_sharding_constraint(
-                observed_pinn_in, sharding_device
-            )
-            self.observed_values = jax.lax.with_sharding_constraint(
-                observed_values, sharding_device
-            )
-            self.observed_eq_params = jax.lax.with_sharding_constraint(
-                self.observed_eq_params, sharding_device
-            )
-        else:
-            self.observed_pinn_in = observed_pinn_in
-            self.observed_values = observed_values
-
-        if not self.data_exists:
-            self.curr_idx = 0
-            # NOTE for speed and to avoid duplicating data what is really
-            # shuffled is a vector of indices
-            indices = jnp.arange(self.n)
-            if sharding_device is not None:
-                self.indices = jax.lax.with_sharding_constraint(
-                    indices, sharding_device
-                )
-            else:
-                self.indices = indices
-            self._key, self.indices, _ = _reset_batch_idx_and_permute(
-                self._get_operands()
-            )
-
-    def _get_operands(self):
-        return (
-            self._key,
-            self.indices,
-            self.curr_idx,
-            self.obs_batch_size,
-            None,
-        )
 
-
-
-        Return a dictionary with (keys, values): (pinn_in, a mini batch of pinn
-        inputs), (obs, a mini batch of corresponding observations), (eq_params,
-        a dictionary with entry names found in `params["eq_params"]` and values
-        giving the correspond parameter value for the couple
-        (input, observation) mentioned before).
-        It can also be a dictionary of dictionaries as described above if
-        observed_pinn_in, observed_values, etc. are dictionaries with keys
-        representing the PINNs.
-        """
-
-        (self._key, self.indices, self.curr_idx) = _reset_or_increment(
-            self.curr_idx + self.obs_batch_size, self.n, self._get_operands()
-        )
-
-        minib_indices = jax.lax.dynamic_slice(
-            self.indices,
-            start_indices=(self.curr_idx,),
-            slice_sizes=(self.obs_batch_size,),
-        )
-
-        obs_batch = {
-            "pinn_in": jnp.take(
-                self.observed_pinn_in, minib_indices, unique_indices=True, axis=0
-            ),
-            "val": jnp.take(
-                self.observed_values, minib_indices, unique_indices=True, axis=0
-            ),
-            "eq_params": jax.tree_util.tree_map(
-                lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),
-                self.observed_eq_params,
-            ),
-        }
-        return obs_batch
-
-    def get_batch(self):
-        """
-        Generic method to return a batch
-        """
-        return self.obs_batch()
-
-    def tree_flatten(self):
-        children = (self._key, self.curr_idx, self.indices)
-        aux_data = {
-            k: vars(self)[k]
-            for k in [
-                "obs_batch_size",
-                "observed_pinn_in",
-                "observed_values",
-                "observed_eq_params",
-            ]
-        }
-        return (children, aux_data)
-
-    @classmethod
-    def tree_unflatten(cls, aux_data, children):
-        (key, curr_idx, indices) = children
-        obj = cls(
-            key=key,
-            data_exists=True,
-            obs_batch_size=aux_data["obs_batch_size"],
-            observed_pinn_in=aux_data["observed_pinn_in"],
-            observed_values=aux_data["observed_values"],
-            observed_eq_params=aux_data["observed_eq_params"],
-        )
-        obj.curr_idx = curr_idx
-        obj.indices = indices
-        return obj
-
-
-@register_pytree_node_class
-class DataGeneratorObservationsMultiPINNs:
-    """
+class DataGeneratorObservationsMultiPINNs(eqx.Module):
+    r"""
     Despite the class name, it is rather a dataloader from user provided
     observations that will be used for the observations loss.
     This is the DataGenerator to use when dealing with multiple PINNs
@@ -1597,146 +1321,123 @@ class DataGeneratorObservationsMultiPINNs:
     applied in `constraints_system_loss_apply` and in this case the
     batch.obs_batch_dict is a dict of obs_batch_dict over which the tree_map
     applies (we select the obs_batch_dict corresponding to its `u_dict` entry)
+
+    Parameters
+    ----------
+    obs_batch_size : Int
+        The size of the batch of randomly selected observations
+        `obs_batch_size` will be the same for all the
+        elements of the obs dict.
+        NOTE: no check is done BUT users should be careful that
+        `obs_batch_size` must be equal to `temporal_batch_size` or
+        `omega_batch_size` or the product of both. In the first case, the
+        present DataGeneratorObservations instance complements an ODEBatch,
+        PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
+        = False). In the second case, `obs_batch_size` =
+        `temporal_batch_size * omega_batch_size` if the present
+        DataGeneratorParameter complements a PDENonStatioBatch
+        with self.cartesian_product = True
+    observed_pinn_in_dict : Dict[str, Float[Array, "n_obs nb_pinn_in"] | None]
+        A dict of observed_pinn_in as defined in DataGeneratorObservations.
+        Keys must be that of `u_dict`.
+        If no observation exists for a particular entry of `u_dict` the
+        corresponding key must still exist in observed_pinn_in_dict with
+        value None
+    observed_values_dict : Dict[str, Float[Array, "n_obs nb_pinn_out"] | None]
+        A dict of observed_values as defined in DataGeneratorObservations.
+        Keys must be that of `u_dict`.
+        If no observation exists for a particular entry of `u_dict` the
+        corresponding key must still exist in observed_values_dict with
+        value None
+    observed_eq_params_dict : Dict[str, Dict[str, Float[Array, "n_obs 1"]]]
+        A dict of observed_eq_params as defined in DataGeneratorObservations.
+        Keys must be that of `u_dict`.
+        **Note**: if no observation exists for a particular entry of `u_dict` the
+        corresponding key must still exist in observed_eq_params_dict with
+        value `{}` (empty dictionary).
+    key
+        Jax random key to shuffle batches.
     """
 
-
-
-
-
-
-
-
-
-    )
-
-
-
-
-
-
-
-        NOTE: no check is done BUT users should be careful that
-        `obs_batch_size` must be equal to `temporal_batch_size` or
-        `omega_batch_size` or the product of both. In the first case, the
-        present DataGeneratorObservations instance complements an ODEBatch,
-        PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
-        = False). In the second case, `obs_batch_size` =
-        `temporal_batch_size * omega_batch_size` if the present
-        DataGeneratorParameter complements a PDENonStatioBatch
-        with self.cartesian_product = True
-    observed_pinn_in_dict
-        A dict of observed_pinn_in as defined in DataGeneratorObservations.
-        Keys must be that of `u_dict`.
-        If no observation exists for a particular entry of `u_dict` the
-        corresponding key must still exist in observed_pinn_in_dict with
-        value None
-    observed_values_dict
-        A dict of observed_values as defined in DataGeneratorObservations.
-        Keys must be that of `u_dict`.
-        If no observation exists for a particular entry of `u_dict` the
-        corresponding key must still exist in observed_values_dict with
-        value None
-    observed_eq_params_dict
-        A dict of observed_eq_params as defined in DataGeneratorObservations.
-        Keys must be that of `u_dict`.
-        If no observation exists for a particular entry of `u_dict` the
-        corresponding key must still exist in observed_eq_params_dict with
-        value None
-    data_gen_obs_exists
-        Must be left to `False` when created by the user. Avoids the
-        regeneration the subclasses DataGeneratorObservations
-        at each pytree flattening and unflattening.
-    key
-        Jax random key to sample new time points and to shuffle batches.
-        Optional if data_gen_obs_exists is True
-    """
-    if (
-        observed_pinn_in_dict is None or observed_values_dict is None
-    ) and not data_gen_obs_exists:
+    obs_batch_size: Int
+    observed_pinn_in_dict: Dict[str, Float[Array, "n_obs nb_pinn_in"] | None]
+    observed_values_dict: Dict[str, Float[Array, "n_obs nb_pinn_out"] | None]
+    observed_eq_params_dict: Dict[str, Dict[str, Float[Array, "n_obs 1"]]] = eqx.field(
+        default=None, kw_only=True
+    )
+    key: InitVar[Key]
+
+    data_gen_obs: Dict[str, "DataGeneratorObservations"] = eqx.field(init=False)
+
+    def __post_init__(self, key):
+        if self.observed_pinn_in_dict is None or self.observed_values_dict is None:
+            raise ValueError(
+                "observed_pinn_in_dict and observed_values_dict must be provided"
+            )
+        if self.observed_pinn_in_dict.keys() != self.observed_values_dict.keys():
             raise ValueError(
-                "
-                "
+                "Keys must be the same in observed_pinn_in_dict"
+                " and observed_values_dict"
             )
-        self.obs_batch_size = obs_batch_size
-        self.data_gen_obs_exists = data_gen_obs_exists
 
-        if
-
-
-
-
-
-
-
-            and observed_pinn_in_dict.keys() != observed_eq_params_dict.keys()
-        ):
-            raise ValueError(
-                "Keys must be the same in observed_eq_params_dict"
-                " and observed_pinn_in_dict and observed_values_dict"
-            )
-        if observed_eq_params_dict is None:
-            observed_eq_params_dict = {
-                k: None for k in observed_pinn_in_dict.keys()
-            }
-
-        keys = dict(
-            zip(
-                observed_pinn_in_dict.keys(),
-                jax.random.split(key, len(observed_pinn_in_dict)),
-            )
+        if self.observed_eq_params_dict is None:
+            self.observed_eq_params_dict = {
+                k: {} for k in self.observed_pinn_in_dict.keys()
+            }
+        elif self.observed_pinn_in_dict.keys() != self.observed_eq_params_dict.keys():
+            raise ValueError(
+                "Keys must be the same in observed_eq_params_dict"
+                " and observed_pinn_in_dict and observed_values_dict"
             )
-
-
-
-
-            if pinn_in is not None
-            else None
-            ),
-            keys,
-            observed_pinn_in_dict,
-            observed_values_dict,
-            observed_eq_params_dict,
+
+        keys = dict(
+            zip(
+                self.observed_pinn_in_dict.keys(),
+                jax.random.split(key, len(self.observed_pinn_in_dict)),
             )
+        )
+        self.data_gen_obs = jax.tree_util.tree_map(
+            lambda k, pinn_in, val, eq_params: (
+                DataGeneratorObservations(
+                    k, self.obs_batch_size, pinn_in, val, eq_params
+                )
+                if pinn_in is not None
+                else None
+            ),
+            keys,
+            self.observed_pinn_in_dict,
+            self.observed_values_dict,
+            self.observed_eq_params_dict,
+        )
 
-    def obs_batch(self):
+    def obs_batch(self) -> tuple["DataGeneratorObservationsMultiPINNs", PyTree]:
         """
         Returns a dictionary of DataGeneratorObservations.obs_batch with keys
         from `u_dict`
         """
-
+        data_gen_and_batch_pytree = jax.tree_util.tree_map(
            lambda a: a.get_batch() if a is not None else {},
            self.data_gen_obs,
            is_leaf=lambda x: isinstance(x, DataGeneratorObservations),
        )  # note the is_leaf note to traverse the DataGeneratorObservations and
        # thus to be able to call the method on the element(s) of
        # self.data_gen_obs which are not None
+        new_attribute = jax.tree_util.tree_map(
+            lambda a: a[0],
+            data_gen_and_batch_pytree,
+            is_leaf=lambda x: isinstance(x, tuple),
+        )
+        new = eqx.tree_at(lambda m: m.data_gen_obs, self, new_attribute)
+        batches = jax.tree_util.tree_map(
+            lambda a: a[1],
+            data_gen_and_batch_pytree,
+            is_leaf=lambda x: isinstance(x, tuple),
+        )
 
-
+        return new, batches
+
+    def get_batch(self) -> tuple["DataGeneratorObservationsMultiPINNs", PyTree]:
         """
         Generic method to return a batch
         """
         return self.obs_batch()
-
-    def tree_flatten(self):
-        # because a dict with "str" keys cannot go in the children (jittable)
-        # attributes, we need to separate it in two and recreate the zip in the
-        # tree_unflatten
-        children = self.data_gen_obs.values()
-        aux_data = {
-            "obs_batch_size": self.obs_batch_size,
-            "data_gen_obs_keys": self.data_gen_obs.keys(),
-        }
-        return (children, aux_data)
-
-    @classmethod
-    def tree_unflatten(cls, aux_data, children):
-        (data_gen_obs_values) = children
-        obj = cls(
-            observed_pinn_in_dict=None,
-            observed_values_dict=None,
-            data_gen_obs_exists=True,
-            obs_batch_size=aux_data["obs_batch_size"],
-        )
-        obj.data_gen_obs = dict(zip(aux_data["data_gen_obs_keys"], data_gen_obs_values))
-        return obj
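A closing sketch of how the multi-PINN dataloader is meant to be driven after this refactor. The constructor call is hypothetical usage of the field declarations above, and the rebinding pattern applies to every generator in this module:

```python
import jax
import jax.numpy as jnp

t_obs = jnp.linspace(0.0, 1.0, 50)[:, None]

# Two PINNs "u" and "v"; "v" has no observations but must keep its key
multi_dg = DataGeneratorObservationsMultiPINNs(
    obs_batch_size=16,
    observed_pinn_in_dict={"u": t_obs, "v": None},
    observed_values_dict={"u": jnp.sin(t_obs), "v": None},
    key=jax.random.PRNGKey(2),
)
# batches is keyed like u_dict; rebind multi_dg, otherwise the internal
# cursors never advance and the same batches are replayed.
multi_dg, batches = multi_dg.get_batch()
```

The design choice is consistent across the file: since the old `register_pytree_node_class` generators mutated themselves, while `eqx.Module` instances are immutable pytrees, every `get_batch()` now returns `(updated_generator, batch)` and the caller is responsible for threading the updated generator through the training loop.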
|