jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +953 -1182
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +321 -168
- jinns/loss/_LossODE.py +290 -307
- jinns/loss/_LossPDE.py +628 -1040
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +95 -96
- jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +94 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +193 -45
- jinns/solver/_solve.py +199 -144
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -43
- jinns/utils/_hyperpinn.py +226 -127
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +117 -84
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +52 -144
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
- jinns-1.0.0.dist-info/RECORD +38 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.8.10.dist-info/RECORD +0 -36
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
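The headline change for 1.0.0, visible throughout `jinns/data/_DataGenerators.py` below, is the move from mutable data generators registered with `register_pytree_node_class` to immutable `equinox` modules: `get_batch()` now returns the updated generator together with the batch instead of mutating `self`. A minimal sketch of the new calling convention, assuming `DataGeneratorODE` remains importable from `jinns.data` as in 0.8.x (the surrounding loop is illustrative, not jinns code):

```python
import jax
from jinns.data import DataGeneratorODE  # assumed re-export, as in 0.8.x

train_data = DataGeneratorODE(
    key=jax.random.PRNGKey(0),
    nt=1000,
    tmin=0.0,
    tmax=1.0,
    temporal_batch_size=32,
    method="uniform",
)

for _ in range(5):
    # 0.8.x style `batch = train_data.get_batch()` mutated the generator;
    # in 1.0.0 the generator is an immutable eqx.Module, so keep the
    # returned copy to advance the internal batch counter / reshuffling.
    train_data, batch = train_data.get_batch()
    # batch.temporal_batch has shape (temporal_batch_size,)
```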
jinns/data/_DataGenerators.py
CHANGED
|
@@ -1,75 +1,93 @@
|
|
|
1
|
+
# pylint: disable=unsubscriptable-object
|
|
1
2
|
"""
|
|
2
|
-
|
|
3
|
+
Define the DataGeneratorODE equinox module
|
|
3
4
|
"""
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
5
|
+
from __future__ import (
|
|
6
|
+
annotations,
|
|
7
|
+
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
|
+
|
|
9
|
+
from typing import TYPE_CHECKING, Dict
|
|
10
|
+
from dataclasses import InitVar
|
|
11
|
+
import equinox as eqx
|
|
12
|
+
import jax
|
|
7
13
|
import jax.numpy as jnp
|
|
8
|
-
from
|
|
9
|
-
from
|
|
10
|
-
import jax.lax
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class ODEBatch(NamedTuple):
|
|
14
|
-
temporal_batch: ArrayLike
|
|
15
|
-
param_batch_dict: dict = None
|
|
16
|
-
obs_batch_dict: dict = None
|
|
17
|
-
|
|
14
|
+
from jaxtyping import Key, Int, PyTree, Array, Float, Bool
|
|
15
|
+
from jinns.data._Batchs import *
|
|
18
16
|
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
border_batch: ArrayLike
|
|
22
|
-
temporal_batch: ArrayLike
|
|
23
|
-
param_batch_dict: dict = None
|
|
24
|
-
obs_batch_dict: dict = None
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from jinns.utils._types import *
|
|
25
19
|
|
|
26
20
|
|
|
27
|
-
|
|
28
|
-
inside_batch: ArrayLike
|
|
29
|
-
border_batch: ArrayLike
|
|
30
|
-
param_batch_dict: dict = None
|
|
31
|
-
obs_batch_dict: dict = None
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
def append_param_batch(batch, param_batch_dict):
|
|
21
|
+
def append_param_batch(batch: AnyBatch, param_batch_dict: dict) -> AnyBatch:
|
|
35
22
|
"""
|
|
36
23
|
Utility function that fill the param_batch_dict of a batch object with a
|
|
37
24
|
param_batch_dict
|
|
38
25
|
"""
|
|
39
|
-
return
|
|
26
|
+
return eqx.tree_at(
|
|
27
|
+
lambda m: m.param_batch_dict,
|
|
28
|
+
batch,
|
|
29
|
+
param_batch_dict,
|
|
30
|
+
is_leaf=lambda x: x is None,
|
|
31
|
+
)
|
|
40
32
|
|
|
41
33
|
|
|
42
|
-
def append_obs_batch(batch, obs_batch_dict):
|
|
34
|
+
def append_obs_batch(batch: AnyBatch, obs_batch_dict: dict) -> AnyBatch:
|
|
43
35
|
"""
|
|
44
36
|
Utility function that fill the obs_batch_dict of a batch object with a
|
|
45
37
|
obs_batch_dict
|
|
46
38
|
"""
|
|
47
|
-
return
|
|
39
|
+
return eqx.tree_at(
|
|
40
|
+
lambda m: m.obs_batch_dict, batch, obs_batch_dict, is_leaf=lambda x: x is None
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def make_cartesian_product(
|
|
45
|
+
b1: Float[Array, "batch_size dim1"], b2: Float[Array, "batch_size dim2"]
|
|
46
|
+
) -> Float[Array, "(batch_size*batch_size) (dim1+dim2)"]:
|
|
47
|
+
"""
|
|
48
|
+
Create the cartesian product of a time and a border omega batches
|
|
49
|
+
by tiling and repeating
|
|
50
|
+
"""
|
|
51
|
+
n1 = b1.shape[0]
|
|
52
|
+
n2 = b2.shape[0]
|
|
53
|
+
b1 = jnp.repeat(b1, n2, axis=0)
|
|
54
|
+
b2 = jnp.tile(b2, reps=(n1,) + tuple(1 for i in b2.shape[1:]))
|
|
55
|
+
return jnp.concatenate([b1, b2], axis=1)
|
|
48
56
|
|
|
49
57
|
|
|
50
|
-
def _reset_batch_idx_and_permute(
|
|
58
|
+
def _reset_batch_idx_and_permute(
|
|
59
|
+
operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]]
|
|
60
|
+
) -> tuple[Key, Float[Array, "n dimension"], Int]:
|
|
51
61
|
key, domain, curr_idx, _, p = operands
|
|
52
62
|
# resetting counter
|
|
53
63
|
curr_idx = 0
|
|
54
64
|
# reshuffling
|
|
55
|
-
key, subkey = random.split(key)
|
|
65
|
+
key, subkey = jax.random.split(key)
|
|
56
66
|
# domain = random.permutation(subkey, domain, axis=0, independent=False)
|
|
57
67
|
# we want that permutation = choice when p=None
|
|
58
68
|
# otherwise p is used to avoid collocation points not in nt_start
|
|
59
|
-
domain = random.choice(
|
|
69
|
+
domain = jax.random.choice(
|
|
70
|
+
subkey, domain, shape=(domain.shape[0],), replace=False, p=p
|
|
71
|
+
)
|
|
60
72
|
|
|
61
73
|
# return updated
|
|
62
74
|
return (key, domain, curr_idx)
|
|
63
75
|
|
|
64
76
|
|
|
65
|
-
def _increment_batch_idx(
|
|
77
|
+
def _increment_batch_idx(
|
|
78
|
+
operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]]
|
|
79
|
+
) -> tuple[Key, Float[Array, "n dimension"], Int]:
|
|
66
80
|
key, domain, curr_idx, batch_size, _ = operands
|
|
67
81
|
# simply increases counter and get the batch
|
|
68
82
|
curr_idx += batch_size
|
|
69
83
|
return (key, domain, curr_idx)
|
|
70
84
|
|
|
71
85
|
|
|
72
|
-
def _reset_or_increment(
|
|
86
|
+
def _reset_or_increment(
|
|
87
|
+
bend: Int,
|
|
88
|
+
n_eff: Int,
|
|
89
|
+
operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]],
|
|
90
|
+
) -> tuple[Key, Float[Array, "n dimension"], Int]:
|
|
73
91
|
"""
|
|
74
92
|
Factorize the code of the jax.lax.cond which checks if we have seen all the
|
|
75
93
|
batches in an epoch
|
|
@@ -98,15 +116,18 @@ def _reset_or_increment(bend, n_eff, operands):
|
|
|
98
116
|
)
|
|
99
117
|
|
|
100
118
|
|
|
101
|
-
def _check_and_set_rar_parameters(
|
|
119
|
+
def _check_and_set_rar_parameters(
|
|
120
|
+
rar_parameters: dict, n: Int, n_start: Int
|
|
121
|
+
) -> tuple[Int, Float[Array, "n"], Int, Int]:
|
|
102
122
|
if rar_parameters is not None and n_start is None:
|
|
103
123
|
raise ValueError(
|
|
104
|
-
|
|
124
|
+
"nt_start must be provided in the context of RAR sampling scheme"
|
|
105
125
|
)
|
|
126
|
+
|
|
106
127
|
if rar_parameters is not None:
|
|
107
128
|
# Default p is None. However, in the RAR sampling scheme we use 0
|
|
108
129
|
# probability to specify non-used collocation points (i.e. points
|
|
109
|
-
# above
|
|
130
|
+
# above nt_start). Thus, p is a vector of probability of shape (nt, 1).
|
|
110
131
|
p = jnp.zeros((n,))
|
|
111
132
|
p = p.at[:n_start].set(1 / n_start)
|
|
112
133
|
# set internal counter for the number of gradient steps since the
|
|
@@ -118,6 +139,7 @@ def _check_and_set_rar_parameters(rar_parameters, n, n_start):
|
|
|
118
139
|
# have been added
|
|
119
140
|
rar_iter_nb = 0
|
|
120
141
|
else:
|
|
142
|
+
n_start = n
|
|
121
143
|
p = None
|
|
122
144
|
rar_iter_from_last_sampling = None
|
|
123
145
|
rar_iter_nb = None
|
|
@@ -125,109 +147,102 @@ def _check_and_set_rar_parameters(rar_parameters, n, n_start):
|
|
|
125
147
|
return n_start, p, rar_iter_from_last_sampling, rar_iter_nb
|
|
126
148
|
|
|
127
149
|
|
|
128
|
-
|
|
129
|
-
# DataGenerator for ODE : only returns time_batches
|
|
130
|
-
#####################################################
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
@register_pytree_node_class
|
|
134
|
-
class DataGeneratorODE:
|
|
150
|
+
class DataGeneratorODE(eqx.Module):
|
|
135
151
|
"""
|
|
136
152
|
A class implementing data generator object for ordinary differential equations.
|
|
137
153
|
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
154
|
+
Parameters
|
|
155
|
+
----------
|
|
156
|
+
key : Key
|
|
157
|
+
Jax random key to sample new time points and to shuffle batches
|
|
158
|
+
nt : Int
|
|
159
|
+
The number of total time points that will be divided in
|
|
160
|
+
batches. Batches are made so that each data point is seen only
|
|
161
|
+
once during 1 epoch.
|
|
162
|
+
tmin : float
|
|
163
|
+
The minimum value of the time domain to consider
|
|
164
|
+
tmax : float
|
|
165
|
+
The maximum value of the time domain to consider
|
|
166
|
+
temporal_batch_size : int
|
|
167
|
+
The size of the batch of randomly selected points among
|
|
168
|
+
the `nt` points.
|
|
169
|
+
method : str, default="uniform"
|
|
170
|
+
Either `grid` or `uniform`, default is `uniform`.
|
|
171
|
+
The method that generates the `nt` time points. `grid` means
|
|
172
|
+
regularly spaced points over the domain. `uniform` means uniformly
|
|
173
|
+
sampled points over the domain
|
|
174
|
+
rar_parameters : Dict[str, Int], default=None
|
|
175
|
+
Default to None: do not use Residual Adaptative Resampling.
|
|
176
|
+
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
177
|
+
which we start the RAR sampling scheme (we first have a burn in
|
|
178
|
+
period). `update_rate`: the number of gradient steps taken between
|
|
179
|
+
each appending of collocation points in the RAR algo.
|
|
180
|
+
`sample_size`: the size of the sample from which we will select new
|
|
181
|
+
collocation points. `selected_sample_size_times`: the number of selected
|
|
182
|
+
points from the sample to be added to the current collocation
|
|
183
|
+
points
|
|
184
|
+
nt_start : Int, default=None
|
|
185
|
+
Defaults to None. The effective size of nt used at start time.
|
|
186
|
+
This value must be
|
|
187
|
+
provided when rar_parameters is not None. Otherwise we set internally
|
|
188
|
+
nt_start = nt and this is hidden from the user.
|
|
189
|
+
In RAR, nt_start
|
|
190
|
+
then corresponds to the initial number of points we train the PINN.
|
|
141
191
|
"""
|
|
142
192
|
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
batches. Batches are made so that each data point is seen only
|
|
163
|
-
once during 1 epoch.
|
|
164
|
-
tmin
|
|
165
|
-
A float. The minimum value of the time domain to consider
|
|
166
|
-
tmax
|
|
167
|
-
A float. The maximum value of the time domain to consider
|
|
168
|
-
temporal_batch_size
|
|
169
|
-
An integer. The size of the batch of randomly selected points among
|
|
170
|
-
the `nt` points.
|
|
171
|
-
method
|
|
172
|
-
Either `grid` or `uniform`, default is `uniform`.
|
|
173
|
-
The method that generates the `nt` time points. `grid` means
|
|
174
|
-
regularly spaced points over the domain. `uniform` means uniformly
|
|
175
|
-
sampled points over the domain
|
|
176
|
-
rar_parameters
|
|
177
|
-
Default to None: do not use Residual Adaptative Resampling.
|
|
178
|
-
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
179
|
-
which we start the RAR sampling scheme (we first have a burn in
|
|
180
|
-
period). `update_every`: the number of gradient steps taken between
|
|
181
|
-
each appending of collocation points in the RAR algo.
|
|
182
|
-
`sample_size_times`: the size of the sample from which we will select new
|
|
183
|
-
collocation points. `selected_sample_size_times`: the number of selected
|
|
184
|
-
points from the sample to be added to the current collocation
|
|
185
|
-
points
|
|
186
|
-
"DeepXDE: A deep learning library for solving differential
|
|
187
|
-
equations", L. Lu, SIAM Review, 2021
|
|
188
|
-
nt_start
|
|
189
|
-
Defaults to None. The effective size of nt used at start time.
|
|
190
|
-
This value must be
|
|
191
|
-
provided when rar_parameters is not None. Otherwise we set internally
|
|
192
|
-
nt_start = nt and this is hidden from the user.
|
|
193
|
-
In RAR, nt_start
|
|
194
|
-
then corresponds to the initial number of points we train the PINN.
|
|
195
|
-
data_exists
|
|
196
|
-
Must be left to `False` when created by the user. Avoids the
|
|
197
|
-
regeneration of the `nt` time points at each pytree flattening and
|
|
198
|
-
unflattening.
|
|
199
|
-
"""
|
|
200
|
-
self.data_exists = data_exists
|
|
201
|
-
self._key = key
|
|
202
|
-
self.nt = nt
|
|
203
|
-
self.tmin = tmin
|
|
204
|
-
self.tmax = tmax
|
|
205
|
-
self.temporal_batch_size = temporal_batch_size
|
|
206
|
-
self.method = method
|
|
207
|
-
self.rar_parameters = rar_parameters
|
|
208
|
-
|
|
209
|
-
# Set-up for RAR (if used)
|
|
193
|
+
key: Key
|
|
194
|
+
nt: Int
|
|
195
|
+
tmin: Float
|
|
196
|
+
tmax: Float
|
|
197
|
+
temporal_batch_size: Int = eqx.field(static=True) # static cause used as a
|
|
198
|
+
# shape in jax.lax.dynamic_slice
|
|
199
|
+
method: str = eqx.field(static=True, default_factory=lambda: "uniform")
|
|
200
|
+
rar_parameters: Dict[str, Int] = None
|
|
201
|
+
nt_start: Int = eqx.field(static=True, default=None)
|
|
202
|
+
|
|
203
|
+
# all the init=False fields are set in __post_init__, even after a _replace
|
|
204
|
+
# or eqx.tree_at __post_init__ is called
|
|
205
|
+
p_times: Float[Array, "nt"] = eqx.field(init=False)
|
|
206
|
+
rar_iter_from_last_sampling: Int = eqx.field(init=False)
|
|
207
|
+
rar_iter_nb: Int = eqx.field(init=False)
|
|
208
|
+
curr_time_idx: Int = eqx.field(init=False)
|
|
209
|
+
times: Float[Array, "nt"] = eqx.field(init=False)
|
|
210
|
+
|
|
211
|
+
def __post_init__(self):
|
|
210
212
|
(
|
|
211
213
|
self.nt_start,
|
|
212
214
|
self.p_times,
|
|
213
215
|
self.rar_iter_from_last_sampling,
|
|
214
216
|
self.rar_iter_nb,
|
|
215
|
-
) = _check_and_set_rar_parameters(rar_parameters,
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
217
|
+
) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.nt_start)
|
|
218
|
+
|
|
219
|
+
self.curr_time_idx = jnp.iinfo(jnp.int32).max - self.temporal_batch_size - 1
|
|
220
|
+
# to be sure there is a
|
|
221
|
+
# shuffling at first get_batch() we do not call
|
|
222
|
+
# _reset_batch_idx_and_permute in __init__ or __post_init__ because it
|
|
223
|
+
# would return a copy of self and we have not investigate what would
|
|
224
|
+
# happen
|
|
225
|
+
# NOTE the (- self.temporal_batch_size - 1) because otherwise when computing
|
|
226
|
+
# `bend` we overflow the max int32 with unwanted behaviour
|
|
227
|
+
|
|
228
|
+
self.key, self.times = self.generate_time_data(self.key)
|
|
229
|
+
# Note that, here, in __init__ (and __post_init__), this is the
|
|
230
|
+
# only place where self assignment are authorized so we do the
|
|
231
|
+
# above way for the key. Note that one of the motivation to return the
|
|
232
|
+
# key from generate_*_data is to easily align key with legacy
|
|
233
|
+
# DataGenerators to use same unit tests
|
|
234
|
+
|
|
235
|
+
def sample_in_time_domain(
|
|
236
|
+
self, key: Key, sample_size: Int = None
|
|
237
|
+
) -> Float[Array, "nt"]:
|
|
238
|
+
return jax.random.uniform(
|
|
239
|
+
key,
|
|
240
|
+
(self.nt if sample_size is None else sample_size,),
|
|
241
|
+
minval=self.tmin,
|
|
242
|
+
maxval=self.tmax,
|
|
243
|
+
)
|
|
229
244
|
|
|
230
|
-
def generate_time_data(self):
|
|
245
|
+
def generate_time_data(self, key: Key) -> tuple[Key, Float[Array, "nt"]]:
|
|
231
246
|
"""
|
|
232
247
|
Construct a complete set of `self.nt` time points according to the
|
|
233
248
|
specified `self.method`
|
|
@@ -235,24 +250,28 @@ class DataGeneratorODE:
|
|
|
235
250
|
Note that self.times has always size self.nt and not self.nt_start, even
|
|
236
251
|
in RAR scheme, we must allocate all the collocation points
|
|
237
252
|
"""
|
|
253
|
+
key, subkey = jax.random.split(self.key)
|
|
238
254
|
if self.method == "grid":
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
raise ValueError("Method " + self.method + " is not implemented.")
|
|
255
|
+
partial_times = (self.tmax - self.tmin) / self.nt
|
|
256
|
+
return key, jnp.arange(self.tmin, self.tmax, partial_times)
|
|
257
|
+
if self.method == "uniform":
|
|
258
|
+
return key, self.sample_in_time_domain(subkey)
|
|
259
|
+
raise ValueError("Method " + self.method + " is not implemented.")
|
|
245
260
|
|
|
246
|
-
def _get_time_operands(
|
|
261
|
+
def _get_time_operands(
|
|
262
|
+
self,
|
|
263
|
+
) -> tuple[Key, Float[Array, "nt"], Int, Int, Float[Array, "nt"]]:
|
|
247
264
|
return (
|
|
248
|
-
self.
|
|
265
|
+
self.key,
|
|
249
266
|
self.times,
|
|
250
267
|
self.curr_time_idx,
|
|
251
268
|
self.temporal_batch_size,
|
|
252
269
|
self.p_times,
|
|
253
270
|
)
|
|
254
271
|
|
|
255
|
-
def temporal_batch(
|
|
272
|
+
def temporal_batch(
|
|
273
|
+
self,
|
|
274
|
+
) -> tuple["DataGeneratorODE", Float[Array, "temporal_batch_size"]]:
|
|
256
275
|
"""
|
|
257
276
|
Return a batch of time points. If all the batches have been seen, we
|
|
258
277
|
reshuffle them, otherwise we just return the next unseen batch.
|
|
@@ -264,210 +283,142 @@ class DataGeneratorODE:
|
|
|
264
283
|
if self.rar_parameters is not None:
|
|
265
284
|
nt_eff = (
|
|
266
285
|
self.nt_start
|
|
267
|
-
+ self.rar_iter_nb * self.rar_parameters["
|
|
286
|
+
+ self.rar_iter_nb * self.rar_parameters["selected_sample_size_times"]
|
|
268
287
|
)
|
|
269
288
|
else:
|
|
270
289
|
nt_eff = self.nt
|
|
271
|
-
|
|
272
|
-
|
|
290
|
+
|
|
291
|
+
new_attributes = _reset_or_increment(bend, nt_eff, self._get_time_operands())
|
|
292
|
+
new = eqx.tree_at(
|
|
293
|
+
lambda m: (m.key, m.times, m.curr_time_idx), self, new_attributes
|
|
273
294
|
)
|
|
274
295
|
|
|
275
296
|
# commands below are equivalent to
|
|
276
297
|
# return self.times[i:(i+t_batch_size)]
|
|
277
298
|
# start indices can be dynamic be the slice shape is fixed
|
|
278
|
-
return jax.lax.dynamic_slice(
|
|
279
|
-
|
|
280
|
-
start_indices=(
|
|
281
|
-
slice_sizes=(
|
|
299
|
+
return new, jax.lax.dynamic_slice(
|
|
300
|
+
new.times,
|
|
301
|
+
start_indices=(new.curr_time_idx,),
|
|
302
|
+
slice_sizes=(new.temporal_batch_size,),
|
|
282
303
|
)
|
|
283
304
|
|
|
284
|
-
def get_batch(self):
|
|
305
|
+
def get_batch(self) -> tuple["DataGeneratorODE", ODEBatch]:
|
|
285
306
|
"""
|
|
286
307
|
Generic method to return a batch. Here we call `self.temporal_batch()`
|
|
287
308
|
"""
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
def tree_flatten(self):
|
|
291
|
-
children = (
|
|
292
|
-
self._key,
|
|
293
|
-
self.times,
|
|
294
|
-
self.curr_time_idx,
|
|
295
|
-
self.tmin,
|
|
296
|
-
self.tmax,
|
|
297
|
-
self.p_times,
|
|
298
|
-
self.rar_iter_from_last_sampling,
|
|
299
|
-
self.rar_iter_nb,
|
|
300
|
-
) # arrays / dynamic values
|
|
301
|
-
aux_data = {
|
|
302
|
-
k: vars(self)[k]
|
|
303
|
-
for k in [
|
|
304
|
-
"temporal_batch_size",
|
|
305
|
-
"method",
|
|
306
|
-
"nt",
|
|
307
|
-
"rar_parameters",
|
|
308
|
-
"nt_start",
|
|
309
|
-
]
|
|
310
|
-
} # static values
|
|
311
|
-
return (children, aux_data)
|
|
312
|
-
|
|
313
|
-
@classmethod
|
|
314
|
-
def tree_unflatten(cls, aux_data, children):
|
|
315
|
-
"""
|
|
316
|
-
**Note:** When reconstructing the class, we force ``data_exists=True``
|
|
317
|
-
in order not to re-generate the data at each flattening and
|
|
318
|
-
unflattening that happens e.g. during the gradient descent in the
|
|
319
|
-
optimization process
|
|
320
|
-
"""
|
|
321
|
-
(
|
|
322
|
-
key,
|
|
323
|
-
times,
|
|
324
|
-
curr_time_idx,
|
|
325
|
-
tmin,
|
|
326
|
-
tmax,
|
|
327
|
-
p_times,
|
|
328
|
-
rar_iter_from_last_sampling,
|
|
329
|
-
rar_iter_nb,
|
|
330
|
-
) = children
|
|
331
|
-
obj = cls(
|
|
332
|
-
key=key,
|
|
333
|
-
data_exists=True,
|
|
334
|
-
tmin=tmin,
|
|
335
|
-
tmax=tmax,
|
|
336
|
-
**aux_data,
|
|
337
|
-
)
|
|
338
|
-
obj.times = times
|
|
339
|
-
obj.curr_time_idx = curr_time_idx
|
|
340
|
-
obj.p_times = p_times
|
|
341
|
-
obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
|
|
342
|
-
obj.rar_iter_nb = rar_iter_nb
|
|
343
|
-
return obj
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
##########################################
|
|
347
|
-
# Data Generator for PDE in stationnary
|
|
348
|
-
# and non-stationnary cases
|
|
349
|
-
##########################################
|
|
309
|
+
new, temporal_batch = self.temporal_batch()
|
|
310
|
+
return new, ODEBatch(temporal_batch=temporal_batch)
|
|
350
311
|
|
|
351
312
|
|
|
352
|
-
class
|
|
353
|
-
"""
|
|
354
|
-
|
|
355
|
-
def __init__(self, data_exists=False) -> None:
|
|
356
|
-
# /!\ WARNING /!\: an-end user should never create an object
|
|
357
|
-
# with data_exists=True. Or else generate_data() won't be called.
|
|
358
|
-
# Useful when using a lax.scan with a DataGenerator in the carry
|
|
359
|
-
# It tells JAX not to re-generate data in the __init__()
|
|
360
|
-
self.data_exists = data_exists
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
@register_pytree_node_class
|
|
364
|
-
class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
365
|
-
"""
|
|
313
|
+
class CubicMeshPDEStatio(eqx.Module):
|
|
314
|
+
r"""
|
|
366
315
|
A class implementing data generator object for stationary partial
|
|
367
316
|
differential equations.
|
|
368
317
|
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
318
|
+
Parameters
|
|
319
|
+
----------
|
|
320
|
+
key : Key
|
|
321
|
+
Jax random key to sample new time points and to shuffle batches
|
|
322
|
+
n : Int
|
|
323
|
+
The number of total $\Omega$ points that will be divided in
|
|
324
|
+
batches. Batches are made so that each data point is seen only
|
|
325
|
+
once during 1 epoch.
|
|
326
|
+
nb : Int | None
|
|
327
|
+
The total number of points in $\partial\Omega$.
|
|
328
|
+
Can be `None` not to lose performance generating the border
|
|
329
|
+
batch if they are not used
|
|
330
|
+
omega_batch_size : Int
|
|
331
|
+
The size of the batch of randomly selected points among
|
|
332
|
+
the `n` points.
|
|
333
|
+
omega_border_batch_size : Int | None
|
|
334
|
+
The size of the batch of points randomly selected
|
|
335
|
+
among the `nb` points.
|
|
336
|
+
Can be `None` not to lose performance generating the border
|
|
337
|
+
batch if they are not used
|
|
338
|
+
dim : Int
|
|
339
|
+
Dimension of $\Omega$ domain
|
|
340
|
+
min_pts : tuple[tuple[Float, Float], ...]
|
|
341
|
+
A tuple of minimum values of the domain along each dimension. For a sampling
|
|
342
|
+
in `n` dimension, this represents $(x_{1, min}, x_{2,min}, ...,
|
|
343
|
+
x_{n, min})$
|
|
344
|
+
max_pts : tuple[tuple[Float, Float], ...]
|
|
345
|
+
A tuple of maximum values of the domain along each dimension. For a sampling
|
|
346
|
+
in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
|
|
347
|
+
x_{n,max})$
|
|
348
|
+
method : str, default="uniform"
|
|
349
|
+
Either `grid` or `uniform`, default is `uniform`.
|
|
350
|
+
The method that generates the `nt` time points. `grid` means
|
|
351
|
+
regularly spaced points over the domain. `uniform` means uniformly
|
|
352
|
+
sampled points over the domain
|
|
353
|
+
rar_parameters : Dict[str, Int], default=None
|
|
354
|
+
Default to None: do not use Residual Adaptative Resampling.
|
|
355
|
+
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
356
|
+
which we start the RAR sampling scheme (we first have a burn in
|
|
357
|
+
period). `update_every`: the number of gradient steps taken between
|
|
358
|
+
each appending of collocation points in the RAR algo.
|
|
359
|
+
`sample_size_omega`: the size of the sample from which we will select new
|
|
360
|
+
collocation points. `selected_sample_size_omega`: the number of selected
|
|
361
|
+
points from the sample to be added to the current collocation
|
|
362
|
+
points
|
|
363
|
+
n_start : Int, default=None
|
|
364
|
+
Defaults to None. The effective size of n used at start time.
|
|
365
|
+
This value must be
|
|
366
|
+
provided when rar_parameters is not None. Otherwise we set internally
|
|
367
|
+
n_start = n and this is hidden from the user.
|
|
368
|
+
In RAR, n_start
|
|
369
|
+
then corresponds to the initial number of points we train the PINN.
|
|
372
370
|
"""
|
|
373
371
|
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
)
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
dim
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
A tuple of minimum values of the domain along each dimension. For a sampling
|
|
414
|
-
in `n` dimension, this represents :math:`(x_{1, min}, x_{2,min}, ...,
|
|
415
|
-
x_{n, min})`
|
|
416
|
-
max_pts
|
|
417
|
-
A tuple of maximum values of the domain along each dimension. For a sampling
|
|
418
|
-
in `n` dimension, this represents :math:`(x_{1, max}, x_{2,max}, ...,
|
|
419
|
-
x_{n,max})`
|
|
420
|
-
method
|
|
421
|
-
Either `grid` or `uniform`, default is `grid`.
|
|
422
|
-
The method that generates the `nt` time points. `grid` means
|
|
423
|
-
regularly spaced points over the domain. `uniform` means uniformly
|
|
424
|
-
sampled points over the domain
|
|
425
|
-
rar_parameters
|
|
426
|
-
Default to None: do not use Residual Adaptative Resampling.
|
|
427
|
-
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
428
|
-
which we start the RAR sampling scheme (we first have a burn in
|
|
429
|
-
period). `update_every`: the number of gradient steps taken between
|
|
430
|
-
each appending of collocation points in the RAR algo.
|
|
431
|
-
`sample_size_omega`: the size of the sample from which we will select new
|
|
432
|
-
collocation points. `selected_sample_size_omega`: the number of selected
|
|
433
|
-
points from the sample to be added to the current collocation
|
|
434
|
-
points
|
|
435
|
-
"DeepXDE: A deep learning library for solving differential
|
|
436
|
-
equations", L. Lu, SIAM Review, 2021
|
|
437
|
-
n_start
|
|
438
|
-
Defaults to None. The effective size of n used at start time.
|
|
439
|
-
This value must be
|
|
440
|
-
provided when rar_parameters is not None. Otherwise we set internally
|
|
441
|
-
n_start = n and this is hidden from the user.
|
|
442
|
-
In RAR, n_start
|
|
443
|
-
then corresponds to the initial number of points we train the PINN.
|
|
444
|
-
data_exists
|
|
445
|
-
Must be left to `False` when created by the user. Avoids the
|
|
446
|
-
regeneration of :math:`\Omega`, :math:`\partial\Omega` and
|
|
447
|
-
time points at each pytree flattening and unflattening.
|
|
448
|
-
"""
|
|
449
|
-
super().__init__(data_exists=data_exists)
|
|
450
|
-
self.method = method
|
|
451
|
-
self._key = key
|
|
452
|
-
self.dim = dim
|
|
453
|
-
self.min_pts = min_pts
|
|
454
|
-
self.max_pts = max_pts
|
|
455
|
-
assert dim == len(min_pts) and isinstance(min_pts, tuple)
|
|
456
|
-
assert dim == len(max_pts) and isinstance(max_pts, tuple)
|
|
457
|
-
self.n = n
|
|
458
|
-
self.rar_parameters = rar_parameters
|
|
372
|
+
# kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
|
|
373
|
+
key: Key = eqx.field(kw_only=True)
|
|
374
|
+
n: Int = eqx.field(kw_only=True)
|
|
375
|
+
nb: Int | None = eqx.field(kw_only=True)
|
|
376
|
+
omega_batch_size: Int = eqx.field(
|
|
377
|
+
kw_only=True, static=True
|
|
378
|
+
) # static cause used as a
|
|
379
|
+
# shape in jax.lax.dynamic_slice
|
|
380
|
+
omega_border_batch_size: Int | None = eqx.field(
|
|
381
|
+
kw_only=True, static=True
|
|
382
|
+
) # static cause used as a
|
|
383
|
+
# shape in jax.lax.dynamic_slice
|
|
384
|
+
dim: Int = eqx.field(kw_only=True, static=True) # static cause used as a
|
|
385
|
+
# shape in jax.lax.dynamic_slice
|
|
386
|
+
min_pts: tuple[tuple[Float, Float], ...] = eqx.field(kw_only=True)
|
|
387
|
+
max_pts: tuple[tuple[Float, Float], ...] = eqx.field(kw_only=True)
|
|
388
|
+
method: str = eqx.field(
|
|
389
|
+
kw_only=True, static=True, default_factory=lambda: "uniform"
|
|
390
|
+
)
|
|
391
|
+
rar_parameters: Dict[str, Int] = eqx.field(kw_only=True, default=None)
|
|
392
|
+
n_start: Int = eqx.field(kw_only=True, default=None, static=True)
|
|
393
|
+
|
|
394
|
+
# all the init=False fields are set in __post_init__, even after a _replace
|
|
395
|
+
# or eqx.tree_at __post_init__ is called
|
|
396
|
+
p_omega: Float[Array, "n"] = eqx.field(init=False)
|
|
397
|
+
p_border: None = eqx.field(init=False)
|
|
398
|
+
rar_iter_from_last_sampling: Int = eqx.field(init=False)
|
|
399
|
+
rar_iter_nb: Int = eqx.field(init=False)
|
|
400
|
+
curr_omega_idx: Int = eqx.field(init=False)
|
|
401
|
+
curr_omega_border_idx: Int = eqx.field(init=False)
|
|
402
|
+
omega: Float[Array, "n dim"] = eqx.field(init=False)
|
|
403
|
+
omega_border: Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None = eqx.field(
|
|
404
|
+
init=False
|
|
405
|
+
)
|
|
406
|
+
|
|
407
|
+
def __post_init__(self):
|
|
408
|
+
assert self.dim == len(self.min_pts) and isinstance(self.min_pts, tuple)
|
|
409
|
+
assert self.dim == len(self.max_pts) and isinstance(self.max_pts, tuple)
|
|
410
|
+
|
|
459
411
|
(
|
|
460
412
|
self.n_start,
|
|
461
413
|
self.p_omega,
|
|
462
414
|
self.rar_iter_from_last_sampling,
|
|
463
415
|
self.rar_iter_nb,
|
|
464
|
-
) = _check_and_set_rar_parameters(rar_parameters, n
|
|
416
|
+
) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)
|
|
465
417
|
|
|
466
418
|
self.p_border = None # no RAR sampling for border for now
|
|
467
419
|
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
if omega_border_batch_size is None:
|
|
420
|
+
# Special handling for the border batch
|
|
421
|
+
if self.omega_border_batch_size is None:
|
|
471
422
|
self.nb = None
|
|
472
423
|
self.omega_border_batch_size = None
|
|
473
424
|
elif self.dim == 1:
|
|
@@ -476,53 +427,50 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
476
427
|
# always set to 2.
|
|
477
428
|
self.nb = 2
|
|
478
429
|
self.omega_border_batch_size = 2
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
# )
|
|
430
|
+
# We are in 1-D case => omega_border_batch_size is
|
|
431
|
+
# ignored since borders of Omega are singletons.
|
|
432
|
+
# self.border_batch() will return [xmin, xmax]
|
|
483
433
|
else:
|
|
484
|
-
if nb % (2 * self.dim) != 0 or nb < 2 * self.dim:
|
|
434
|
+
if self.nb % (2 * self.dim) != 0 or self.nb < 2 * self.dim:
|
|
485
435
|
raise ValueError(
|
|
486
436
|
"number of border point must be"
|
|
487
437
|
" a multiple of 2xd (the # of faces of a d-dimensional cube)"
|
|
488
438
|
)
|
|
489
|
-
if nb // (2 * self.dim) < omega_border_batch_size:
|
|
439
|
+
if self.nb // (2 * self.dim) < self.omega_border_batch_size:
|
|
490
440
|
raise ValueError(
|
|
491
441
|
"number of points per facets (nb//2*self.dim)"
|
|
492
442
|
" cannot be lower than border batch size"
|
|
493
443
|
)
|
|
494
|
-
self.nb = int((2 * self.dim) * (nb // (2 * self.dim)))
|
|
495
|
-
self.omega_border_batch_size = omega_border_batch_size
|
|
496
|
-
|
|
497
|
-
if not self.data_exists:
|
|
498
|
-
# Useful when using a lax.scan with pytree
|
|
499
|
-
# Optionally tells JAX not to re-generate data when re-building the
|
|
500
|
-
# object
|
|
501
|
-
self.curr_omega_idx = 0
|
|
502
|
-
self.curr_omega_border_idx = 0
|
|
503
|
-
self.generate_data()
|
|
504
|
-
self._key, self.omega, _ = _reset_batch_idx_and_permute(
|
|
505
|
-
self._get_omega_operands()
|
|
506
|
-
)
|
|
507
|
-
if self.omega_border is not None and self.dim > 1:
|
|
508
|
-
self._key, self.omega_border, _ = _reset_batch_idx_and_permute(
|
|
509
|
-
self._get_omega_border_operands()
|
|
510
|
-
)
|
|
444
|
+
self.nb = int((2 * self.dim) * (self.nb // (2 * self.dim)))
|
|
511
445
|
|
|
512
|
-
|
|
446
|
+
self.curr_omega_idx = jnp.iinfo(jnp.int32).max - self.omega_batch_size - 1
|
|
447
|
+
# see explaination in DataGeneratorODE
|
|
448
|
+
if self.omega_border_batch_size is None:
|
|
449
|
+
self.curr_omega_border_idx = None
|
|
450
|
+
else:
|
|
451
|
+
self.curr_omega_border_idx = (
|
|
452
|
+
jnp.iinfo(jnp.int32).max - self.omega_border_batch_size - 1
|
|
453
|
+
)
|
|
454
|
+
# key, subkey = jax.random.split(self.key)
|
|
455
|
+
# self.key = key
|
|
456
|
+
self.key, self.omega, self.omega_border = self.generate_data(self.key)
|
|
457
|
+
# see explaination in DataGeneratorODE for the key
|
|
458
|
+
|
|
459
|
+
def sample_in_omega_domain(
|
|
460
|
+
self, keys: Key, sample_size: Int = None
|
|
461
|
+
) -> Float[Array, "n dim"]:
|
|
462
|
+
sample_size = self.n if sample_size is None else sample_size
|
|
513
463
|
if self.dim == 1:
|
|
514
464
|
xmin, xmax = self.min_pts[0], self.max_pts[0]
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
subkey, shape=(n_samples, 1), minval=xmin, maxval=xmax
|
|
465
|
+
return jax.random.uniform(
|
|
466
|
+
keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
|
|
518
467
|
)
|
|
519
|
-
keys = random.split(
|
|
520
|
-
self._key = keys[0]
|
|
468
|
+
# keys = jax.random.split(key, self.dim)
|
|
521
469
|
return jnp.concatenate(
|
|
522
470
|
[
|
|
523
|
-
random.uniform(
|
|
524
|
-
keys[i
|
|
525
|
-
(
|
|
471
|
+
jax.random.uniform(
|
|
472
|
+
keys[i],
|
|
473
|
+
(sample_size, 1),
|
|
526
474
|
minval=self.min_pts[i],
|
|
527
475
|
maxval=self.max_pts[i],
|
|
528
476
|
)
|
|
@@ -531,7 +479,9 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
531
479
|
axis=-1,
|
|
532
480
|
)
|
|
533
481
|
|
|
534
|
-
def sample_in_omega_border_domain(
|
|
482
|
+
def sample_in_omega_border_domain(
|
|
483
|
+
self, keys: Key
|
|
484
|
+
) -> Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None:
|
|
535
485
|
if self.omega_border_batch_size is None:
|
|
536
486
|
return None
|
|
537
487
|
if self.dim == 1:
|
|
@@ -543,15 +493,12 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
543
493
|
# TODO : find a general & efficient way to sample from the border
|
|
544
494
|
# (facets) of the hypercube in general dim.
|
|
545
495
|
|
|
546
|
-
facet_n =
|
|
547
|
-
keys = random.split(self._key, 5)
|
|
548
|
-
self._key = keys[0]
|
|
549
|
-
subkeys = keys[1:]
|
|
496
|
+
facet_n = self.nb // (2 * self.dim)
|
|
550
497
|
xmin = jnp.hstack(
|
|
551
498
|
[
|
|
552
499
|
self.min_pts[0] * jnp.ones((facet_n, 1)),
|
|
553
|
-
random.uniform(
|
|
554
|
-
|
|
500
|
+
jax.random.uniform(
|
|
501
|
+
keys[0],
|
|
555
502
|
(facet_n, 1),
|
|
556
503
|
minval=self.min_pts[1],
|
|
557
504
|
maxval=self.max_pts[1],
|
|
@@ -561,8 +508,8 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
561
508
|
xmax = jnp.hstack(
|
|
562
509
|
[
|
|
563
510
|
self.max_pts[0] * jnp.ones((facet_n, 1)),
|
|
564
|
-
random.uniform(
|
|
565
|
-
|
|
511
|
+
jax.random.uniform(
|
|
512
|
+
keys[1],
|
|
566
513
|
(facet_n, 1),
|
|
567
514
|
minval=self.min_pts[1],
|
|
568
515
|
maxval=self.max_pts[1],
|
|
@@ -571,8 +518,8 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
571
518
|
)
|
|
572
519
|
ymin = jnp.hstack(
|
|
573
520
|
[
|
|
574
|
-
random.uniform(
|
|
575
|
-
|
|
521
|
+
jax.random.uniform(
|
|
522
|
+
keys[2],
|
|
576
523
|
(facet_n, 1),
|
|
577
524
|
minval=self.min_pts[0],
|
|
578
525
|
maxval=self.max_pts[0],
|
|
@@ -582,8 +529,8 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
582
529
|
)
|
|
583
530
|
ymax = jnp.hstack(
|
|
584
531
|
[
|
|
585
|
-
random.uniform(
|
|
586
|
-
|
|
532
|
+
jax.random.uniform(
|
|
533
|
+
keys[3],
|
|
587
534
|
(facet_n, 1),
|
|
588
535
|
minval=self.min_pts[0],
|
|
589
536
|
maxval=self.max_pts[0],
|
|
@@ -597,54 +544,71 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
597
544
|
+ f"implemented yet. You are asking for generation in dimension d={self.dim}."
|
|
598
545
|
)
|
|
599
546
|
|
|
600
|
-
def generate_data(self)
|
|
547
|
+
def generate_data(self, key: Key) -> tuple[
|
|
548
|
+
Key,
|
|
549
|
+
Float[Array, "n dim"],
|
|
550
|
+
Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None,
|
|
551
|
+
]:
|
|
601
552
|
r"""
|
|
602
|
-
Construct a complete set of `self.n`
|
|
553
|
+
Construct a complete set of `self.n` $\Omega$ points according to the
|
|
603
554
|
specified `self.method`. Also constructs a complete set of `self.nb`
|
|
604
|
-
|
|
555
|
+
$\partial\Omega$ points if `self.omega_border_batch_size` is not
|
|
605
556
|
`None`. If the latter is `None` we set `self.omega_border` to `None`.
|
|
606
557
|
"""
|
|
607
|
-
|
|
608
558
|
# Generate Omega
|
|
609
559
|
if self.method == "grid":
|
|
610
560
|
if self.dim == 1:
|
|
611
561
|
xmin, xmax = self.min_pts[0], self.max_pts[0]
|
|
612
|
-
|
|
562
|
+
partial = (xmax - xmin) / self.n
|
|
613
563
|
# shape (n, 1)
|
|
614
|
-
|
|
564
|
+
omega = jnp.arange(xmin, xmax, partial)[:, None]
|
|
615
565
|
else:
|
|
616
|
-
|
|
566
|
+
partials = [
|
|
617
567
|
(self.max_pts[i] - self.min_pts[i]) / jnp.sqrt(self.n)
|
|
618
568
|
for i in range(self.dim)
|
|
619
569
|
]
|
|
620
570
|
xyz_ = jnp.meshgrid(
|
|
621
571
|
*[
|
|
622
|
-
jnp.arange(self.min_pts[i], self.max_pts[i],
|
|
572
|
+
jnp.arange(self.min_pts[i], self.max_pts[i], partials[i])
|
|
623
573
|
for i in range(self.dim)
|
|
624
574
|
]
|
|
625
575
|
)
|
|
626
576
|
xyz_ = [a.reshape((self.n, 1)) for a in xyz_]
|
|
627
|
-
|
|
577
|
+
omega = jnp.concatenate(xyz_, axis=-1)
|
|
628
578
|
elif self.method == "uniform":
|
|
629
|
-
self.
|
|
579
|
+
if self.dim == 1:
|
|
580
|
+
key, subkeys = jax.random.split(key, 2)
|
|
581
|
+
else:
|
|
582
|
+
key, *subkeys = jax.random.split(key, self.dim + 1)
|
|
583
|
+
omega = self.sample_in_omega_domain(subkeys)
|
|
630
584
|
else:
|
|
631
585
|
raise ValueError("Method " + self.method + " is not implemented.")
|
|
632
586
|
|
|
633
587
|
# Generate border of omega
|
|
634
|
-
self.
|
|
588
|
+
if self.dim == 2 and self.omega_border_batch_size is not None:
|
|
589
|
+
key, *subkeys = jax.random.split(key, 5)
|
|
590
|
+
else:
|
|
591
|
+
subkeys = None
|
|
592
|
+
omega_border = self.sample_in_omega_border_domain(subkeys)
|
|
593
|
+
|
|
594
|
+
return key, omega, omega_border
|
|
635
595
|
|
|
636
|
-
def _get_omega_operands(
|
|
596
|
+
def _get_omega_operands(
|
|
597
|
+
self,
|
|
598
|
+
) -> tuple[Key, Float[Array, "n dim"], Int, Int, Float[Array, "n"]]:
|
|
637
599
|
return (
|
|
638
|
-
self.
|
|
600
|
+
self.key,
|
|
639
601
|
self.omega,
|
|
640
602
|
self.curr_omega_idx,
|
|
641
603
|
self.omega_batch_size,
|
|
642
604
|
self.p_omega,
|
|
643
605
|
)
|
|
644
606
|
|
|
645
|
-
def inside_batch(
|
|
607
|
+
def inside_batch(
|
|
608
|
+
self,
|
|
609
|
+
) -> tuple["CubicMeshPDEStatio", Float[Array, "omega_batch_size dim"]]:
|
|
646
610
|
r"""
|
|
647
|
-
Return a batch of points in
|
|
611
|
+
Return a batch of points in $\Omega$.
|
|
648
612
|
If all the batches have been seen, we reshuffle them,
|
|
649
613
|
otherwise we just return the next unseen batch.
|
|
650
614
|
"""
|
|
@@ -660,38 +624,46 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
660
624
|
bstart = self.curr_omega_idx
|
|
661
625
|
bend = bstart + self.omega_batch_size
|
|
662
626
|
|
|
663
|
-
(
|
|
664
|
-
|
|
627
|
+
new_attributes = _reset_or_increment(bend, n_eff, self._get_omega_operands())
|
|
628
|
+
new = eqx.tree_at(
|
|
629
|
+
lambda m: (m.key, m.omega, m.curr_omega_idx), self, new_attributes
|
|
665
630
|
)
|
|
666
631
|
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
start_indices=(self.curr_omega_idx, 0),
|
|
672
|
-
slice_sizes=(self.omega_batch_size, self.dim),
|
|
632
|
+
return new, jax.lax.dynamic_slice(
|
|
633
|
+
new.omega,
|
|
634
|
+
start_indices=(new.curr_omega_idx, 0),
|
|
635
|
+
slice_sizes=(new.omega_batch_size, new.dim),
|
|
673
636
|
)
|
|
674
637
|
|
|
675
|
-
def _get_omega_border_operands(
|
|
638
|
+
def _get_omega_border_operands(
|
|
639
|
+
self,
|
|
640
|
+
) -> tuple[
|
|
641
|
+
Key, Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None, Int, Int, None
|
|
642
|
+
]:
|
|
676
643
|
return (
|
|
677
|
-
self.
|
|
644
|
+
self.key,
|
|
678
645
|
self.omega_border,
|
|
679
646
|
self.curr_omega_border_idx,
|
|
680
647
|
self.omega_border_batch_size,
|
|
681
648
|
self.p_border,
|
|
682
649
|
)
|
|
683
650
|
|
|
684
|
-
def border_batch(
|
|
651
|
+
def border_batch(
|
|
652
|
+
self,
|
|
653
|
+
) -> tuple[
|
|
654
|
+
"CubicMeshPDEStatio",
|
|
655
|
+
Float[Array, "1 1 2"] | Float[Array, "omega_border_batch_size 2 4"] | None,
|
|
656
|
+
]:
|
|
685
657
|
r"""
|
|
686
658
|
Return
|
|
687
659
|
|
|
688
660
|
- The value `None` if `self.omega_border_batch_size` is `None`.
|
|
689
661
|
|
|
690
|
-
- a jnp array with two fixed values
|
|
662
|
+
- a jnp array with two fixed values $(x_{min}, x_{max})$ if
|
|
691
663
|
`self.dim` = 1. There is no sampling here, we return the entire
|
|
692
|
-
|
|
664
|
+
$\partial\Omega$
|
|
693
665
|
|
|
694
|
-
- a batch of points in
|
|
666
|
+
- a batch of points in $\partial\Omega$ otherwise, stacked by
|
|
695
667
|
facet on the last axis.
|
|
696
668
|
If all the batches have been seen, we reshuffle them,
|
|
697
669
|
otherwise we just return the next unseen batch.
|
|
@@ -699,229 +671,160 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
699
671
|
|
|
700
672
|
"""
|
|
701
673
|
if self.omega_border_batch_size is None:
|
|
702
|
-
return None
|
|
674
|
+
return self, None
|
|
703
675
|
if self.dim == 1:
|
|
704
676
|
# 1-D case, no randomness : we always return the whole omega border,
|
|
705
677
|
# i.e. (1, 1, 2) shape jnp.array([[[xmin], [xmax]]]).
|
|
706
|
-
return self.omega_border[None, None] # shape is (1, 1, 2)
|
|
678
|
+
return self, self.omega_border[None, None] # shape is (1, 1, 2)
|
|
707
679
|
bstart = self.curr_omega_border_idx
|
|
708
680
|
bend = bstart + self.omega_border_batch_size
|
|
709
681
|
|
|
710
|
-
(
|
|
711
|
-
self.
|
|
712
|
-
self.omega_border,
|
|
713
|
-
self.curr_omega_border_idx,
|
|
714
|
-
) = _reset_or_increment(bend, self.nb, self._get_omega_border_operands())
|
|
715
|
-
|
|
716
|
-
# commands below are equivalent to
|
|
717
|
-
# return self.omega[i:(i+batch_size), 0:dim, 0:nb_facets]
|
|
718
|
-
# and nb_facets = 2 * dimension
|
|
719
|
-
# but JAX prefer the latter
|
|
720
|
-
return jax.lax.dynamic_slice(
|
|
721
|
-
self.omega_border,
|
|
722
|
-
start_indices=(self.curr_omega_border_idx, 0, 0),
|
|
723
|
-
slice_sizes=(self.omega_border_batch_size, self.dim, 2 * self.dim),
|
|
682
|
+
new_attributes = _reset_or_increment(
|
|
683
|
+
bend, self.nb, self._get_omega_border_operands()
|
|
724
684
|
)
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
and `self.border_batch()`
|
|
730
|
-
"""
|
|
731
|
-
return PDEStatioBatch(
|
|
732
|
-
inside_batch=self.inside_batch(), border_batch=self.border_batch()
|
|
685
|
+
new = eqx.tree_at(
|
|
686
|
+
lambda m: (m.key, m.omega_border, m.curr_omega_border_idx),
|
|
687
|
+
self,
|
|
688
|
+
new_attributes,
|
|
733
689
|
)
|
|
734
690
|
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
self.omega_border,
|
|
740
|
-
self.curr_omega_idx,
|
|
741
|
-
self.curr_omega_border_idx,
|
|
742
|
-
self.min_pts,
|
|
743
|
-
self.max_pts,
|
|
744
|
-
self.p_omega,
|
|
745
|
-
self.rar_iter_from_last_sampling,
|
|
746
|
-
self.rar_iter_nb,
|
|
691
|
+
return new, jax.lax.dynamic_slice(
|
|
692
|
+
new.omega_border,
|
|
693
|
+
start_indices=(new.curr_omega_border_idx, 0, 0),
|
|
694
|
+
slice_sizes=(new.omega_border_batch_size, new.dim, 2 * new.dim),
|
|
747
695
|
)
|
|
748
|
-
aux_data = {
|
|
749
|
-
k: vars(self)[k]
|
|
750
|
-
for k in [
|
|
751
|
-
"n",
|
|
752
|
-
"nb",
|
|
753
|
-
"omega_batch_size",
|
|
754
|
-
"omega_border_batch_size",
|
|
755
|
-
"method",
|
|
756
|
-
"dim",
|
|
757
|
-
"rar_parameters",
|
|
758
|
-
"n_start",
|
|
759
|
-
]
|
|
760
|
-
}
|
|
761
|
-
return (children, aux_data)
|
|
762
696
|
|
|
763
|
-
|
|
764
|
-
def tree_unflatten(cls, aux_data, children):
|
|
697
|
+
def get_batch(self) -> tuple["CubicMeshPDEStatio", PDEStatioBatch]:
|
|
765
698
|
"""
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
unflattening that happens e.g. during the gradient descent in the
|
|
769
|
-
optimization process
|
|
699
|
+
Generic method to return a batch. Here we call `self.inside_batch()`
|
|
700
|
+
and `self.border_batch()`
|
|
770
701
|
"""
|
|
771
|
-
(
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
omega_border,
|
|
775
|
-
curr_omega_idx,
|
|
776
|
-
curr_omega_border_idx,
|
|
777
|
-
min_pts,
|
|
778
|
-
max_pts,
|
|
779
|
-
p_omega,
|
|
780
|
-
rar_iter_from_last_sampling,
|
|
781
|
-
rar_iter_nb,
|
|
782
|
-
) = children
|
|
783
|
-
# force data_exists=True here in order not to re-generate the data
|
|
784
|
-
# at each iteration of lax.scan
|
|
785
|
-
obj = cls(
|
|
786
|
-
key=key,
|
|
787
|
-
data_exists=True,
|
|
788
|
-
min_pts=min_pts,
|
|
789
|
-
max_pts=max_pts,
|
|
790
|
-
**aux_data,
|
|
791
|
-
)
|
|
792
|
-
obj.omega = omega
|
|
793
|
-
obj.omega_border = omega_border
|
|
794
|
-
obj.curr_omega_idx = curr_omega_idx
|
|
795
|
-
obj.curr_omega_border_idx = curr_omega_border_idx
|
|
796
|
-
obj.p_omega = p_omega
|
|
797
|
-
obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
|
|
798
|
-
obj.rar_iter_nb = rar_iter_nb
|
|
799
|
-
return obj
|
|
702
|
+
new, inside_batch = self.inside_batch()
|
|
703
|
+
new, border_batch = new.border_batch()
|
|
704
|
+
return new, PDEStatioBatch(inside_batch=inside_batch, border_batch=border_batch)
|
|
800
705
|
|
|
801
706
|
|
|
802
|
-
@register_pytree_node_class
|
|
803
707
|
class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
804
|
-
"""
|
|
708
|
+
r"""
|
|
805
709
|
A class implementing data generator object for non stationary partial
|
|
806
710
|
differential equations. Formally, it extends `CubicMeshPDEStatio`
|
|
807
711
|
to include a temporal batch.
|
|
808
712
|
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
713
|
+
Parameters
|
|
714
|
+
----------
|
|
715
|
+
key : Key
|
|
716
|
+
Jax random key to sample new time points and to shuffle batches
|
|
717
|
+
n : Int
|
|
718
|
+
The number of total $\Omega$ points that will be divided in
|
|
719
|
+
batches. Batches are made so that each data point is seen only
|
|
720
|
+
once during 1 epoch.
|
|
721
|
+
nb : Int | None
|
|
722
|
+
The total number of points in $\partial\Omega$.
|
|
723
|
+
Can be `None` not to lose performance generating the border
|
|
724
|
+
batch if they are not used
|
|
725
|
+
nt : Int
|
|
726
|
+
The number of total time points that will be divided in
|
|
727
|
+
batches. Batches are made so that each data point is seen only
|
|
728
|
+
once during 1 epoch.
|
|
729
|
+
omega_batch_size : Int
|
|
730
|
+
The size of the batch of randomly selected points among
|
|
731
|
+
the `n` points.
|
|
732
|
+
omega_border_batch_size : Int | None
|
|
733
|
+
The size of the batch of points randomly selected
|
|
734
|
+
among the `nb` points.
|
|
735
|
+
Can be `None` not to lose performance generating the border
|
|
736
|
+
batch if they are not used
|
|
737
|
+
temporal_batch_size : Int
|
|
738
|
+
The size of the batch of randomly selected points among
|
|
739
|
+
the `nt` points.
|
|
740
|
+
dim : Int
|
|
741
|
+
An integer. dimension of $\Omega$ domain
|
|
742
|
+
min_pts : tuple[tuple[Float, Float], ...]
|
|
743
|
+
A tuple of minimum values of the domain along each dimension. For a sampling
|
|
744
|
+
in `n` dimension, this represents $(x_{1, min}, x_{2,min}, ...,
|
|
745
|
+
x_{n, min})$
|
|
746
|
+
max_pts : tuple[tuple[Float, Float], ...]
|
|
747
|
+
A tuple of maximum values of the domain along each dimension. For a sampling
|
|
748
|
+
in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
|
|
749
|
+
x_{n,max})$
|
|
750
|
+
tmin : float
|
|
751
|
+
The minimum value of the time domain to consider
|
|
752
|
+
tmax : float
|
|
753
|
+
The maximum value of the time domain to consider
|
|
754
|
+
method : str, default="uniform"
|
|
755
|
+
Either `grid` or `uniform`, default is `uniform`.
|
|
756
|
+
The method that generates the `nt` time points. `grid` means
|
|
757
|
+
regularly spaced points over the domain. `uniform` means uniformly
|
|
758
|
+
sampled points over the domain
|
|
759
|
+
rar_parameters : Dict[str, Int], default=None
|
|
760
|
+
Default to None: do not use Residual Adaptative Resampling.
|
|
761
|
+
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
762
|
+
which we start the RAR sampling scheme (we first have a burn in
|
|
763
|
+
period). `update_every`: the number of gradient steps taken between
|
|
764
|
+
each appending of collocation points in the RAR algo.
|
|
765
|
+
`sample_size_omega`: the size of the sample from which we will select new
|
|
766
|
+
collocation points. `selected_sample_size_omega`: the number of selected
|
|
767
|
+
points from the sample to be added to the current collocation
|
|
768
|
+
points.
|
|
769
|
+
n_start : Int, default=None
|
|
770
|
+
Defaults to None. The effective size of n used at start time.
|
|
771
|
+
This value must be
|
|
772
|
+
provided when rar_parameters is not None. Otherwise we set internally
|
|
773
|
+
n_start = n and this is hidden from the user.
|
|
774
|
+
In RAR, n_start
|
|
775
|
+
then corresponds to the initial number of omega points we train the PINN.
|
|
776
|
+
nt_start : Int, default=None
|
|
777
|
+
Defaults to None. A RAR hyper-parameter. Same as ``n_start`` but
|
|
778
|
+
for times collocation point. See also ``DataGeneratorODE``
|
|
779
|
+
documentation.
|
|
780
|
+
cartesian_product : Bool, default=True
|
|
781
|
+
Defaults to True. Whether we return the cartesian product of the
|
|
782
|
+
temporal batch with the inside and border batches. If False we just
|
|
783
|
+
return their concatenation.
|
|
812
784
|
"""
|
|
813
785
|
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
tmax,
|
|
828
|
-
method="grid",
|
|
829
|
-
rar_parameters=None,
|
|
830
|
-
n_start=None,
|
|
831
|
-
nt_start=None,
|
|
832
|
-
data_exists=False,
|
|
833
|
-
):
|
|
834
|
-
r"""
|
|
835
|
-
Parameters
|
|
836
|
-
----------
|
|
837
|
-
key
|
|
838
|
-
Jax random key to sample new time points and to shuffle batches
|
|
839
|
-
n
|
|
840
|
-
An integer. The number of total :math:`\Omega` points that will be divided in
|
|
841
|
-
batches. Batches are made so that each data point is seen only
|
|
842
|
-
once during 1 epoch.
|
|
843
|
-
nb
|
|
844
|
-
An integer. The total number of points in :math:`\partial\Omega`.
|
|
845
|
-
Can be `None` not to lose performance generating the border
|
|
846
|
-
batch if they are not used
|
|
847
|
-
nt
|
|
848
|
-
An integer. The number of total time points that will be divided in
|
|
849
|
-
batches. Batches are made so that each data point is seen only
|
|
850
|
-
once during 1 epoch.
|
|
851
|
-
omega_batch_size
|
|
852
|
-
An integer. The size of the batch of randomly selected points among
|
|
853
|
-
the `n` points.
|
|
854
|
-
omega_border_batch_size
|
|
855
|
-
An integer. The size of the batch of points randomly selected
|
|
856
|
-
among the `nb` points.
|
|
857
|
-
Can be `None` not to lose performance generating the border
|
|
858
|
-
batch if they are not used
|
|
859
|
-
temporal_batch_size
|
|
860
|
-
An integer. The size of the batch of randomly selected points among
|
|
861
|
-
the `nt` points.
|
|
862
|
-
dim
|
|
863
|
-
An integer. dimension of :math:`\Omega` domain
|
|
864
|
-
min_pts
|
|
865
|
-
A tuple of minimum values of the domain along each dimension. For a sampling
|
|
866
|
-
in `n` dimension, this represents :math:`(x_{1, min}, x_{2,min}, ...,
|
|
867
|
-
x_{n, min})`
|
|
868
|
-
max_pts
|
|
869
|
-
A tuple of maximum values of the domain along each dimension. For a sampling
|
|
870
|
-
in `n` dimension, this represents :math:`(x_{1, max}, x_{2,max}, ...,
|
|
871
|
-
x_{n,max})`
|
|
872
|
-
tmin
|
|
873
|
-
A float. The minimum value of the time domain to consider
|
|
874
|
-
tmax
|
|
875
|
-
A float. The maximum value of the time domain to consider
|
|
876
|
-
method
|
|
877
|
-
Either `grid` or `uniform`, default is `grid`.
|
|
878
|
-
The method that generates the `nt` time points. `grid` means
|
|
879
|
-
regularly spaced points over the domain. `uniform` means uniformly
|
|
880
|
-
sampled points over the domain
|
|
881
|
-
rar_parameters
|
|
882
|
-
-            Default to None: do not use Residual Adaptative Resampling.
-            Otherwise a dictionary with keys. `start_iter`: the iteration at
-            which we start the RAR sampling scheme (we first have a burn in
-            period). `update_every`: the number of gradient steps taken between
-            each appending of collocation points in the RAR algo.
-            `sample_size_omega`: the size of the sample from which we will select new
-            collocation points. `selected_sample_size_omega`: the number of selected
-            points from the sample to be added to the current collocation
-            points.
-        n_start
-            Defaults to None. The effective size of n used at start time.
-            This value must be
-            provided when rar_parameters is not None. Otherwise we set internally
-            n_start = n and this is hidden from the user.
-            In RAR, n_start
-            then corresponds to the initial number of omega points we train the PINN.
-        nt_start
-            Defaults to None. A RAR hyper-parameter. Same as ``n_start`` but
-            for times collocation point. See also ``DataGeneratorODE``
-            documentation.
-        data_exists
-            Must be left to `False` when created by the user. Avoids the
-            regeneration of :math:`\Omega`, :math:`\partial\Omega` and
-            time points at each pytree flattening and unflattening.
+    temporal_batch_size: Int = eqx.field(kw_only=True)
+    tmin: Float = eqx.field(kw_only=True)
+    tmax: Float = eqx.field(kw_only=True)
+    nt: Int = eqx.field(kw_only=True)
+    temporal_batch_size: Int = eqx.field(kw_only=True, static=True)
+    cartesian_product: Bool = eqx.field(kw_only=True, default=True, static=True)
+    nt_start: int = eqx.field(kw_only=True, default=None, static=True)
+
+    p_times: Array = eqx.field(init=False)
+    curr_time_idx: Int = eqx.field(init=False)
+    times: Array = eqx.field(init=False)
+
+    def __post_init__(self):
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        Note that neither __init__ or __post_init__ are called when udating a
+        Module with eqx.tree_at!
+        """
+        super().__post_init__()  # because __init__ or __post_init__ of Base
+        # class is not automatically called
+
+        if not self.cartesian_product:
+            if self.temporal_batch_size != self.omega_batch_size:
+                raise ValueError(
+                    "If stacking is requested between the time and "
+                    "inside batches of collocation points, self.temporal_batch_size "
+                    "must then be equal to self.omega_batch_size"
+                )
+            if (
+                self.dim > 1
+                and self.omega_border_batch_size is not None
+                and self.temporal_batch_size != self.omega_border_batch_size
+            ):
+                raise ValueError(
+                    "If dim > 1 and stacking is requested between the time and "
+                    "inside batches of collocation points, self.temporal_batch_size "
+                    "must then be equal to self.omega_border_batch_size"
+                )
+            # Note if self.dim == 1:
+            #     print(
+            #         "Cartesian product is not requested but will be "
+            #         "executed anyway since dim=1"
+            #     )

         # Set-up for timewise RAR (some quantity are already set-up by super())
         (
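
The `cartesian_product` flag added above governs how time batches and space batches are combined downstream, and the `__post_init__` checks enforce equal batch sizes whenever plain stacking is requested instead of a cartesian product. Below is a minimal sketch of the two combination modes; the concrete shapes are our reading of the code (the module's own helper is `make_cartesian_product`), not an official example from the package:

import jax.numpy as jnp

t = jnp.ones((4, 1))   # temporal batch, shape (temporal_batch_size, 1)
x = jnp.ones((4, 2))   # omega batch, shape (omega_batch_size, dim)

# stacking (cartesian_product=False): requires equal batch sizes
t_x_stack = jnp.concatenate([t, x], axis=1)  # shape (4, 3)

# cartesian product (cartesian_product=True): every time with every point
t_rep = jnp.repeat(t, x.shape[0], axis=0)            # (16, 1)
x_tile = jnp.tile(x, (t.shape[0], 1))                # (16, 2)
t_x_cart = jnp.concatenate([t_rep, x_tile], axis=1)  # (16, 3)
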
@@ -929,46 +832,54 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
             self.p_times,
             _,
             _,
-        ) = _check_and_set_rar_parameters(rar_parameters,
-
-
-
-
-
-
-
-
-
-
-
-
+        ) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.nt_start)
+
+        self.curr_time_idx = jnp.iinfo(jnp.int32).max - self.temporal_batch_size - 1
+        self.key, _ = jax.random.split(self.key, 2)  # to make it equivalent to
+        # the call to _reset_batch_idx_and_permute in legacy DG
+        self.key, self.times = self.generate_time_data(self.key)
+        # see explaination in DataGeneratorODE for the key
+
+    def sample_in_time_domain(
+        self, key: Key, sample_size: Int = None
+    ) -> Float[Array, "nt"]:
+        return jax.random.uniform(
+            key,
+            (self.nt if sample_size is None else sample_size,),
+            minval=self.tmin,
+            maxval=self.tmax,
+        )

-    def _get_time_operands(
+    def _get_time_operands(
+        self,
+    ) -> tuple[Key, Float[Array, "nt"], Int, Int, Float[Array, "nt"]]:
         return (
-            self.
+            self.key,
             self.times,
             self.curr_time_idx,
             self.temporal_batch_size,
             self.p_times,
         )

-    def
-
+    def generate_time_data(self, key: Key) -> tuple[Key, Float[Array, "nt"]]:
+        """
         Construct a complete set of `self.nt` time points according to the
-        specified `self.method
-
-
+        specified `self.method`
+
+        Note that self.times has always size self.nt and not self.nt_start, even
+        in RAR scheme, we must allocate all the collocation points
         """
+        key, subkey = jax.random.split(key, 2)
         if self.method == "grid":
-
-
-
-
-
-            raise ValueError("Method " + self.method + " is not implemented.")
+            partial_times = (self.tmax - self.tmin) / self.nt
+            return key, jnp.arange(self.tmin, self.tmax, partial_times)
+        if self.method == "uniform":
+            return key, self.sample_in_time_domain(subkey)
+        raise ValueError("Method " + self.method + " is not implemented.")

-    def temporal_batch(
+    def temporal_batch(
+        self,
+    ) -> tuple["CubicMeshPDENonStatio", Float[Array, "temporal_batch_size"]]:
         """
         Return a batch of time points. If all the batches have been seen, we
         reshuffle them, otherwise we just return the next unseen batch.
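
The `generate_time_data` rewritten in the hunk above draws the full set of `nt` time points once and returns the split key alongside them. A standalone sketch of what the two sampling methods produce (values invented for illustration, not taken from the package tests):

import jax
import jax.numpy as jnp

tmin, tmax, nt = 0.0, 1.0, 5

# method="grid": regularly spaced points with step (tmax - tmin) / nt
grid_times = jnp.arange(tmin, tmax, (tmax - tmin) / nt)  # [0., 0.2, 0.4, 0.6, 0.8]

# method="uniform": nt i.i.d. uniform draws on [tmin, tmax)
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
uniform_times = jax.random.uniform(subkey, (nt,), minval=tmin, maxval=tmax)
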
@@ -979,233 +890,344 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
         # Compute the effective number of used collocation points
         if self.rar_parameters is not None:
             nt_eff = (
-                self.
+                self.nt_start
                 + self.rar_iter_nb * self.rar_parameters["selected_sample_size_times"]
             )
         else:
             nt_eff = self.nt

-        (
-
+        new_attributes = _reset_or_increment(bend, nt_eff, self._get_time_operands())
+        new = eqx.tree_at(
+            lambda m: (m.key, m.times, m.curr_time_idx), self, new_attributes
         )

-
-
-
-
-            self.times,
-            start_indices=(self.curr_time_idx,),
-            slice_sizes=(self.temporal_batch_size,),
+        return new, jax.lax.dynamic_slice(
+            new.times,
+            start_indices=(new.curr_time_idx,),
+            slice_sizes=(new.temporal_batch_size,),
         )

-    def get_batch(self):
+    def get_batch(self) -> tuple["CubicMeshPDENonStatio", PDENonStatioBatch]:
         """
         Generic method to return a batch. Here we call `self.inside_batch()`,
         `self.border_batch()` and `self.temporal_batch()`
         """
-
-
-
-
+        new, x = self.inside_batch()
+        new, dx = new.border_batch()
+        new, t = new.temporal_batch()
+        t = t.reshape(new.temporal_batch_size, 1)
+
+        if new.cartesian_product:
+            t_x = make_cartesian_product(t, x)
+        else:
+            t_x = jnp.concatenate([t, x], axis=1)
+
+        if dx is not None:
+            t_ = t.reshape(new.temporal_batch_size, 1, 1)
+            t_ = jnp.repeat(t_, dx.shape[-1], axis=2)
+            if new.cartesian_product or new.dim == 1:
+                t_dx = make_cartesian_product(t_, dx)
+            else:
+                t_dx = jnp.concatenate([t_, dx], axis=1)
+        else:
+            t_dx = None
+
+        return new, PDENonStatioBatch(
+            times_x_inside_batch=t_x, times_x_border_batch=t_dx
         )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+class DataGeneratorObservations(eqx.Module):
+    r"""
+    Despite the class name, it is rather a dataloader from user provided
+    observations that will be used for the observations loss
+
+    Parameters
+    ----------
+    key : Key
+        Jax random key to shuffle batches
+    obs_batch_size : Int
+        The size of the batch of randomly selected points among
+        the `n` points. `obs_batch_size` will be the same for all
+        elements of the return observation dict batch.
+        NOTE: no check is done BUT users should be careful that
+        `obs_batch_size` must be equal to `temporal_batch_size` or
+        `omega_batch_size` or the product of both. In the first case, the
+        present DataGeneratorObservations instance complements an ODEBatch,
+        PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
+        = False). In the second case, `obs_batch_size` =
+        `temporal_batch_size * omega_batch_size` if the present
+        DataGeneratorParameter complements a PDENonStatioBatch
+        with self.cartesian_product = True
+    observed_pinn_in : Float[Array, "n_obs nb_pinn_in"]
+        Observed values corresponding to the input of the PINN
+        (eg. the time at which we recorded the observations). The first
+        dimension must corresponds to the number of observed_values.
+        The second dimension depends on the input dimension of the PINN,
+        that is `1` for ODE, `n_dim_x` for stationnary PDE and `n_dim_x + 1`
+        for non-stationnary PDE.
+    observed_values : Float[Array, "n_obs, nb_pinn_out"]
+        Observed values that the PINN should learn to fit. The first
+        dimension must be aligned with observed_pinn_in.
+    observed_eq_params : Dict[str, Float[Array, "n_obs 1"]], default={}
+        A dict with keys corresponding to
+        the parameter name. The keys must match the keys in
+        `params["eq_params"]`. The values are jnp.array with 2 dimensions
+        with values corresponding to the parameter value for which we also
+        have observed_pinn_in and observed_values. Hence the first
+        dimension must be aligned with observed_pinn_in and observed_values.
+        Optional argument.
+    sharding_device : jax.sharding.Sharding, default=None
+        Default None. An optional sharding object to constraint the storage
+        of observed inputs, values and parameters. Typically, a
+        SingleDeviceSharding(cpu_device) to avoid loading on GPU huge
+        datasets of observations. Note that computations for **batches**
+        can still be performed on other devices (*e.g.* GPU, TPU or
+        any pre-defined Sharding) thanks to the `obs_batch_sharding`
+        arguments of `jinns.solve()`. Read the docs for more info.
+    """
+
+    key: Key
+    obs_batch_size: Int = eqx.field(static=True)
+    observed_pinn_in: Float[Array, "n_obs nb_pinn_in"]
+    observed_values: Float[Array, "n_obs nb_pinn_out"]
+    observed_eq_params: Dict[str, Float[Array, "n_obs 1"]] = eqx.field(
+        static=True, default_factory=lambda: {}
+    )
+    sharding_device: jax.sharding.Sharding = eqx.field(static=True, default=None)
+
+    n: Int = eqx.field(init=False)
+    curr_idx: Int = eqx.field(init=False)
+    indices: Array = eqx.field(init=False)
+
+    def __post_init__(self):
+        if self.observed_pinn_in.shape[0] != self.observed_values.shape[0]:
+            raise ValueError(
+                "self.observed_pinn_in and self.observed_values must have same first axis"
+            )
+        for _, v in self.observed_eq_params.items():
+            if v.shape[0] != self.observed_pinn_in.shape[0]:
+                raise ValueError(
+                    "self.observed_pinn_in and the values of"
+                    " self.observed_eq_params must have the same first axis"
+                )
+        if len(self.observed_pinn_in.shape) == 1:
+            self.observed_pinn_in = self.observed_pinn_in[:, None]
+        if len(self.observed_pinn_in.shape) > 2:
+            raise ValueError("self.observed_pinn_in must have 2 dimensions")
+        if len(self.observed_values.shape) == 1:
+            self.observed_values = self.observed_values[:, None]
+        if len(self.observed_values.shape) > 2:
+            raise ValueError("self.observed_values must have 2 dimensions")
+        for k, v in self.observed_eq_params.items():
+            if len(v.shape) == 1:
+                self.observed_eq_params[k] = v[:, None]
+            if len(v.shape) > 2:
+                raise ValueError(
+                    "Each value of observed_eq_params must have 2 dimensions"
+                )
+
+        self.n = self.observed_pinn_in.shape[0]
+
+        if self.sharding_device is not None:
+            self.observed_pinn_in = jax.lax.with_sharding_constraint(
+                self.observed_pinn_in, self.sharding_device
+            )
+            self.observed_values = jax.lax.with_sharding_constraint(
+                self.observed_values, self.sharding_device
+            )
+            self.observed_eq_params = jax.lax.with_sharding_constraint(
+                self.observed_eq_params, self.sharding_device
+            )
+
+        self.curr_idx = jnp.iinfo(jnp.int32).max - self.obs_batch_size - 1
+        # For speed and to avoid duplicating data what is really
+        # shuffled is a vector of indices
+        if self.sharding_device is not None:
+            self.indices = jax.lax.with_sharding_constraint(
+                jnp.arange(self.n), self.sharding_device
+            )
+        else:
+            self.indices = jnp.arange(self.n)
+
+        # recall post_init is the only place with _init_ where we can set
+        # self attribute in a in-place way
+        self.key, _ = jax.random.split(self.key, 2)  # to make it equivalent to
+        # the call to _reset_batch_idx_and_permute in legacy DG
+
+    def _get_operands(self) -> tuple[Key, Int[Array, "n"], Int, Int, None]:
+        return (
+            self.key,
+            self.indices,
+            self.curr_idx,
+            self.obs_batch_size,
+            None,
         )
-        aux_data = {
-            k: vars(self)[k]
-            for k in [
-                "n",
-                "nb",
-                "nt",
-                "omega_batch_size",
-                "omega_border_batch_size",
-                "temporal_batch_size",
-                "method",
-                "dim",
-                "rar_parameters",
-                "n_start",
-                "nt_start",
-            ]
-        }
-        return (children, aux_data)

-
-
+    def obs_batch(
+        self,
+    ) -> tuple[
+        "DataGeneratorObservations", Dict[str, Float[Array, "obs_batch_size dim"]]
+    ]:
         """
-
-
-
-
+        Return a dictionary with (keys, values): (pinn_in, a mini batch of pinn
+        inputs), (obs, a mini batch of corresponding observations), (eq_params,
+        a dictionary with entry names found in `params["eq_params"]` and values
+        giving the correspond parameter value for the couple
+        (input, observation) mentioned before).
+        It can also be a dictionary of dictionaries as described above if
+        observed_pinn_in, observed_values, etc. are dictionaries with keys
+        representing the PINNs.
         """
-
-
-
-            omega_border,
-            times,
-            curr_omega_idx,
-            curr_omega_border_idx,
-            curr_time_idx,
-            min_pts,
-            max_pts,
-            tmin,
-            tmax,
-            p_times,
-            p_omega,
-            rar_iter_from_last_sampling,
-            rar_iter_nb,
-        ) = children
-        obj = cls(
-            key=key,
-            data_exists=True,
-            min_pts=min_pts,
-            max_pts=max_pts,
-            tmin=tmin,
-            tmax=tmax,
-            **aux_data,
+
+        new_attributes = _reset_or_increment(
+            self.curr_idx + self.obs_batch_size, self.n, self._get_operands()
         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        new = eqx.tree_at(
+            lambda m: (m.key, m.indices, m.curr_idx), self, new_attributes
+        )
+
+        minib_indices = jax.lax.dynamic_slice(
+            new.indices,
+            start_indices=(new.curr_idx,),
+            slice_sizes=(new.obs_batch_size,),
+        )
+
+        obs_batch = {
+            "pinn_in": jnp.take(
+                new.observed_pinn_in, minib_indices, unique_indices=True, axis=0
+            ),
+            "val": jnp.take(
+                new.observed_values, minib_indices, unique_indices=True, axis=0
+            ),
+            "eq_params": jax.tree_util.tree_map(
+                lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),
+                new.observed_eq_params,
+            ),
+        }
+        return new, obs_batch

-    def
+    def get_batch(
         self,
-
-
-
-
-        method
-        user_data=None,
-        data_exists=False,
-    ):
-        r"""
-        Parameters
-        ----------
-        key
-            Jax random key to sample new time points and to shuffle batches
-            or a dict of Jax random keys with key entries from param_ranges
-        n
-            An integer. The number of total points that will be divided in
-            batches. Batches are made so that each data point is seen only
-            once during 1 epoch.
-        param_batch_size
-            An integer. The size of the batch of randomly selected points among
-            the `n` points. `param_batch_size` will be the same for all the
-            additional batch(es) of parameter(s). `param_batch_size` must be
-            equal to `temporal_batch_size` or `omega_batch_size` or the product
-            of both whether the present DataGeneratorParameter instance
-            complements and ODEBatch, a PDEStatioBatch or a PDENonStatioBatch,
-            respectively.
-        param_ranges
-            A dict. A dict of tuples (min, max), which
-            reprensents the range of real numbers where to sample batches (of
-            length `param_batch_size` among `n` points).
-            The key corresponds to the parameter name. The keys must match the
-            keys in `params["eq_params"]`.
-            By providing several entries in this dictionary we can sample
-            an arbitrary number of parameters.
-            __Note__ that we currently only support unidimensional parameters
-        method
-            Either `grid` or `uniform`, default is `grid`. `grid` means
-            regularly spaced points over the domain. `uniform` means uniformly
-            sampled points over the domain
-        data_exists
-            Must be left to `False` when created by the user. Avoids the
-            regeneration of :math:`\Omega`, :math:`\partial\Omega` and
-            time points at each pytree flattening and unflattening.
-        user_data
-            A dictionary containing user-provided data for parameters.
-            As for `param_ranges`, the key corresponds to the parameter name,
-            the keys must match the keys in `params["eq_params"]` and only
-            unidimensional arrays are supported. Therefore, the jnp arrays
-            found at `user_data[k]` must have shape `(n, 1)` or `(n,)`.
-            Note that if the same key appears in `param_ranges` and `user_data`
-            priority goes for the content in `user_data`.
-            Defaults to None.
+    ) -> tuple[
+        "DataGeneratorObservations", Dict[str, Float[Array, "obs_batch_size dim"]]
+    ]:
+        """
+        Generic method to return a batch
         """
-        self.
-
+        return self.obs_batch()
+
+
+class DataGeneratorParameter(eqx.Module):
+    r"""
+    A data generator for additional unidimensional parameter(s)

-
+    Parameters
+    ----------
+    keys : Key | Dict[str, Key]
+        Jax random key to sample new time points and to shuffle batches
+        or a dict of Jax random keys with key entries from param_ranges
+    n : Int
+        The number of total points that will be divided in
+        batches. Batches are made so that each data point is seen only
+        once during 1 epoch.
+    param_batch_size : Int
+        The size of the batch of randomly selected points among
+        the `n` points. `param_batch_size` will be the same for all
+        additional batch of parameter.
+        NOTE: no check is done BUT users should be careful that
+        `param_batch_size` must be equal to `temporal_batch_size` or
+        `omega_batch_size` or the product of both. In the first case, the
+        present DataGeneratorParameter instance complements an ODEBatch, a
+        PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
+        = False). In the second case, `param_batch_size` =
+        `temporal_batch_size * omega_batch_size` if the present
+        DataGeneratorParameter complements a PDENonStatioBatch
+        with self.cartesian_product = True
+    param_ranges : Dict[str, tuple[Float, Float] | None, default={}
+        A dict. A dict of tuples (min, max), which
+        reprensents the range of real numbers where to sample batches (of
+        length `param_batch_size` among `n` points).
+        The key corresponds to the parameter name. The keys must match the
+        keys in `params["eq_params"]`.
+        By providing several entries in this dictionary we can sample
+        an arbitrary number of parameters.
+        **Note** that we currently only support unidimensional parameters.
+        This argument can be done if we only use `user_data`.
+    method : str, default="uniform"
+        Either `grid` or `uniform`, default is `uniform`. `grid` means
+        regularly spaced points over the domain. `uniform` means uniformly
+        sampled points over the domain
+    user_data : Dict[str, Float[Array, "n"]] | None, default={}
+        A dictionary containing user-provided data for parameters.
+        As for `param_ranges`, the key corresponds to the parameter name,
+        the keys must match the keys in `params["eq_params"]` and only
+        unidimensional arrays are supported. Therefore, the jnp arrays
+        found at `user_data[k]` must have shape `(n, 1)` or `(n,)`.
+        Note that if the same key appears in `param_ranges` and `user_data`
+        priority goes for the content in `user_data`.
+        Defaults to None.
+    """
+
+    keys: Key | Dict[str, Key]
+    n: Int
+    param_batch_size: Int = eqx.field(static=True)
+    param_ranges: Dict[str, tuple[Float, Float]] = eqx.field(
+        static=True, default_factory=lambda: {}
+    )
+    method: str = eqx.field(static=True, default="uniform")
+    user_data: Dict[str, Float[Array, "n"]] | None = eqx.field(
+        static=True, default_factory=lambda: {}
+    )
+
+    curr_param_idx: Dict[str, Int] = eqx.field(init=False)
+    param_n_samples: Dict[str, Array] = eqx.field(init=False)
+
+    def __post_init__(self):
+        if self.user_data is None:
+            self.user_data = {}
+        if self.param_ranges is None:
+            self.param_ranges = {}
+        if self.n < self.param_batch_size:
             raise ValueError(
-                f"Number of data points ({n}) is smaller than the"
-                f"number of batch points ({param_batch_size})."
+                f"Number of data points ({self.n}) is smaller than the"
+                f"number of batch points ({self.param_batch_size})."
+            )
+        if not isinstance(self.keys, dict):
+            all_keys = set().union(self.param_ranges, self.user_data)
+            self.keys = dict(zip(all_keys, jax.random.split(self.keys, len(all_keys))))
+
+        self.curr_param_idx = {}
+        for k in self.keys.keys():
+            self.curr_param_idx[k] = (
+                jnp.iinfo(jnp.int32).max - self.param_batch_size - 1
             )

-
-
-
-
-
-
-
-
-
-        self.n = n
-        self.param_batch_size = param_batch_size
-        self.param_ranges = param_ranges
-        self.user_data = user_data
-
-        if not self.data_exists:
-            self.generate_data()
-            # The previous call to self.generate_data() has created
-            # the dict self.param_n_samples and then we will only use this one
-            # because it has merged the scattered data between `user_data` and
-            # `param_ranges`
-            self.curr_param_idx = {}
-            for k in self.param_n_samples.keys():
-                self.curr_param_idx[k] = 0
-                (
-                    self._keys[k],
-                    self.param_n_samples[k],
-                    _,
-                ) = _reset_batch_idx_and_permute(self._get_param_operands(k))
-
-    def generate_data(self):
+        # The call to self.generate_data() creates
+        # the dict self.param_n_samples and then we will only use this one
+        # because it merges the scattered data between `user_data` and
+        # `param_ranges`
+        self.keys, self.param_n_samples = self.generate_data(self.keys)
+
+    def generate_data(
+        self, keys: Dict[str, Key]
+    ) -> tuple[Dict[str, Key], Dict[str, Float[Array, "n"]]]:
         """
         Generate parameter samples, either through generation
         or using user-provided data.
         """
-
+        param_n_samples = {}

         all_keys = set().union(self.param_ranges, self.user_data)
         for k in all_keys:
-            if
+            if (
+                self.user_data
+                and k in self.user_data.keys()  # pylint: disable=no-member
+            ):
                 if self.user_data[k].shape == (self.n, 1):
-
+                    param_n_samples[k] = self.user_data[k]
                 if self.user_data[k].shape == (self.n,):
-
+                    param_n_samples[k] = self.user_data[k][:, None]
                 else:
                     raise ValueError(
                         "Wrong shape for user provided parameters"
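
Throughout this hunk the generators become `eqx.Module`s whose batching methods return an updated copy of themselves via `eqx.tree_at` instead of mutating internal state, so every call site must rebind the generator. A hypothetical call site for the new `DataGeneratorObservations` (array shapes and values are our own illustration, not code shipped with jinns):

import jax
import jax.numpy as jnp

# hypothetical observations: 100 scalar inputs and 100 scalar outputs
key = jax.random.PRNGKey(0)
obs_dg = DataGeneratorObservations(
    key=key,
    obs_batch_size=10,
    observed_pinn_in=jnp.linspace(0.0, 1.0, 100)[:, None],
    observed_values=jnp.sin(jnp.linspace(0.0, 1.0, 100))[:, None],
)

# functional update: rebind the generator at every draw
obs_dg, batch = obs_dg.get_batch()
print(batch["pinn_in"].shape, batch["val"].shape)  # (10, 1) (10, 1)
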
@@ -1214,23 +1236,25 @@ class DataGeneratorParameter:
             else:
                 if self.method == "grid":
                     xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
-
+                    partial = (xmax - xmin) / self.n
                     # shape (n, 1)
-
-                        :, None
-                    ]
+                    param_n_samples[k] = jnp.arange(xmin, xmax, partial)[:, None]
                 elif self.method == "uniform":
                     xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
-
-
+                    keys[k], subkey = jax.random.split(keys[k], 2)
+                    param_n_samples[k] = jax.random.uniform(
                         subkey, shape=(self.n, 1), minval=xmin, maxval=xmax
                     )
                 else:
                     raise ValueError("Method " + self.method + " is not implemented.")

-
+        return keys, param_n_samples
+
+    def _get_param_operands(
+        self, k: str
+    ) -> tuple[Key, Float[Array, "n"], Int, Int, None]:
         return (
-            self.
+            self.keys[k],
             self.param_n_samples[k],
             self.curr_param_idx[k],
             self.param_batch_size,
@@ -1255,26 +1279,28 @@ class DataGeneratorParameter:
             _reset_or_increment_wrapper,
             self.param_n_samples,
             self.curr_param_idx,
-            self.
+            self.keys,
         )
         # we must transpose the pytrees because keys are merged in res
         # https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#transposing-trees
-        (
-            self.
-            self.param_n_samples,
-            self.curr_param_idx,
-        ) = jax.tree_util.tree_transpose(
-            jax.tree_util.tree_structure(self._keys),
+        new_attributes = jax.tree_util.tree_transpose(
+            jax.tree_util.tree_structure(self.keys),
             jax.tree_util.tree_structure([0, 0, 0]),
             res,
         )

-
+        new = eqx.tree_at(
+            lambda m: (m.keys, m.param_n_samples, m.curr_param_idx),
+            self,
+            new_attributes,
+        )
+
+        return new, jax.tree_util.tree_map(
             lambda p, q: jax.lax.dynamic_slice(
-                p, start_indices=(q, 0), slice_sizes=(
+                p, start_indices=(q, 0), slice_sizes=(new.param_batch_size, 1)
             ),
-
-
+            new.param_n_samples,
+            new.curr_param_idx,
         )

     def get_batch(self):
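
The transpose-then-`tree_at` dance above lets `param_batch` shuffle every parameter's key, sample store, and index in one pass and still return a fresh module. A hedged usage sketch (the parameter name "nu" and all values are invented for illustration):

import jax

# sample a viscosity-like parameter "nu" uniformly in [0.1, 1.0]
param_dg = DataGeneratorParameter(
    keys=jax.random.PRNGKey(1),
    n=100,
    param_batch_size=10,
    param_ranges={"nu": (0.1, 1.0)},
)
param_dg, param_batch = param_dg.get_batch()
print(param_batch["nu"].shape)  # (10, 1)
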
@@ -1283,246 +1309,9 @@ class DataGeneratorParameter:
         """
         return self.param_batch()

-    def tree_flatten(self):
-        children = (
-            self._keys,
-            self.param_n_samples,
-            self.curr_param_idx,
-        )
-        aux_data = {
-            k: vars(self)[k]
-            for k in ["n", "param_batch_size", "method", "param_ranges", "user_data"]
-        }
-        return (children, aux_data)
-
-    @classmethod
-    def tree_unflatten(cls, aux_data, children):
-        (
-            keys,
-            param_n_samples,
-            curr_param_idx,
-        ) = children
-        obj = cls(
-            key=keys,
-            data_exists=True,
-            **aux_data,
-        )
-        obj.param_n_samples = param_n_samples
-        obj.curr_param_idx = curr_param_idx
-        return obj
-
-
-@register_pytree_node_class
-class DataGeneratorObservations:
-    """
-    Despite the class name, it is rather a dataloader from user provided
-    observations that will be used for the observations loss
-    """
-
-    def __init__(
-        self,
-        key,
-        obs_batch_size,
-        observed_pinn_in,
-        observed_values,
-        observed_eq_params=None,
-        data_exists=False,
-        sharding_device=None,
-    ):
-        r"""
-        Parameters
-        ----------
-        key
-            Jax random key to sample new time points and to shuffle batches
-        obs_batch_size
-            An integer. The size of the batch of randomly selected observations
-            `obs_batch_size` will be the same for all the
-            elements of the obs dict. `obs_batch_size` must be
-            equal to `temporal_batch_size` or `omega_batch_size` or the product
-            of both whether the present DataGeneratorParameter instance
-            complements and ODEBatch, a PDEStatioBatch or a PDENonStatioBatch,
-            respectively.
-        observed_pinn_in
-            A jnp.array with 2 dimensions.
-            Observed values corresponding to the input of the PINN
-            (eg. the time at which we recorded the observations). The first
-            dimension must corresponds to the number of observed_values and
-            observed_eq_params. The second dimension depends on the input dimension of the PINN, that is `1` for ODE, `n_dim_x` for stationnary PDE and `n_dim_x + 1` for non-stationnary PDE.
-        observed_values
-            A jnp.array with 2 dimensions.
-            Observed values that the PINN should learn to fit. The first dimension must be aligned with observed_pinn_in and the values of observed_eq_params.
-        observed_eq_params
-            Optional. Default is None. A dict with keys corresponding to the
-            parameter name. The keys must match the keys in
-            `params["eq_params"]`. The values are jnp.array with 2 dimensions
-            with values corresponding to the parameter value for which we also
-            have observed_pinn_in and observed_values. Hence the first
-            dimension must be aligned with observed_pinn_in and observed_values.
-        data_exists
-            Must be left to `False` when created by the user. Avoids the
-            resetting of curr_idx at each pytree flattening and unflattening.
-        sharding_device
-            Default None. An optional sharding object to constraint the storage
-            of observed inputs, values and parameters. Typically, a
-            SingleDeviceSharding(cpu_device) to avoid loading on GPU huge
-            datasets of observations. Note that computations for **batches**
-            can still be performed on other devices (*e.g.* GPU, TPU or
-            any pre-defined Sharding) thanks to the `obs_batch_sharding`
-            arguments of `jinns.solve()`. Read the docs for more info.
-
-        """
-        if observed_eq_params is None:
-            observed_eq_params = {}
-
-        if not data_exists:
-            self.observed_eq_params = observed_eq_params.copy()
-        else:
-            # avoid copying when in flatten/unflatten
-            self.observed_eq_params = observed_eq_params
-
-        if observed_pinn_in.shape[0] != observed_values.shape[0]:
-            raise ValueError(
-                "observed_pinn_in and observed_values must have same first axis"
-            )
-        for _, v in self.observed_eq_params.items():
-            if v.shape[0] != observed_pinn_in.shape[0]:
-                raise ValueError(
-                    "observed_pinn_in and the values of"
-                    " observed_eq_params must have the same first axis"
-                )
-        if len(observed_pinn_in.shape) == 1:
-            observed_pinn_in = observed_pinn_in[:, None]
-        if len(observed_pinn_in.shape) > 2:
-            raise ValueError("observed_pinn_in must have 2 dimensions")
-        if len(observed_values.shape) == 1:
-            observed_values = observed_values[:, None]
-        if len(observed_values.shape) > 2:
-            raise ValueError("observed_values must have 2 dimensions")
-        for k, v in self.observed_eq_params.items():
-            if len(v.shape) == 1:
-                self.observed_eq_params[k] = v[:, None]
-            if len(v.shape) > 2:
-                raise ValueError(
-                    "Each value of observed_eq_params must have 2 dimensions"
-                )
-
-        self.n = observed_pinn_in.shape[0]
-        self._key = key
-        self.obs_batch_size = obs_batch_size
-
-        self.data_exists = data_exists
-        if not self.data_exists and sharding_device is not None:
-            self.observed_pinn_in = jax.lax.with_sharding_constraint(
-                observed_pinn_in, sharding_device
-            )
-            self.observed_values = jax.lax.with_sharding_constraint(
-                observed_values, sharding_device
-            )
-            self.observed_eq_params = jax.lax.with_sharding_constraint(
-                self.observed_eq_params, sharding_device
-            )
-        else:
-            self.observed_pinn_in = observed_pinn_in
-            self.observed_values = observed_values
-
-        if not self.data_exists:
-            self.curr_idx = 0
-            # NOTE for speed and to avoid duplicating data what is really
-            # shuffled is a vector of indices
-            indices = jnp.arange(self.n)
-            if sharding_device is not None:
-                self.indices = jax.lax.with_sharding_constraint(
-                    indices, sharding_device
-                )
-            else:
-                self.indices = indices
-            self._key, self.indices, _ = _reset_batch_idx_and_permute(
-                self._get_operands()
-            )
-
-    def _get_operands(self):
-        return (
-            self._key,
-            self.indices,
-            self.curr_idx,
-            self.obs_batch_size,
-            None,
-        )

-
-
-        Return a dictionary with (keys, values): (pinn_in, a mini batch of pinn
-        inputs), (obs, a mini batch of corresponding observations), (eq_params,
-        a dictionary with entry names found in `params["eq_params"]` and values
-        giving the correspond parameter value for the couple
-        (input, observation) mentioned before).
-        It can also be a dictionary of dictionaries as described above if
-        observed_pinn_in, observed_values, etc. are dictionaries with keys
-        representing the PINNs.
-        """
-
-        (self._key, self.indices, self.curr_idx) = _reset_or_increment(
-            self.curr_idx + self.obs_batch_size, self.n, self._get_operands()
-        )
-
-        minib_indices = jax.lax.dynamic_slice(
-            self.indices,
-            start_indices=(self.curr_idx,),
-            slice_sizes=(self.obs_batch_size,),
-        )
-
-        obs_batch = {
-            "pinn_in": jnp.take(
-                self.observed_pinn_in, minib_indices, unique_indices=True, axis=0
-            ),
-            "val": jnp.take(
-                self.observed_values, minib_indices, unique_indices=True, axis=0
-            ),
-            "eq_params": jax.tree_util.tree_map(
-                lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),
-                self.observed_eq_params,
-            ),
-        }
-        return obs_batch
-
-    def get_batch(self):
-        """
-        Generic method to return a batch
-        """
-        return self.obs_batch()
-
-    def tree_flatten(self):
-        children = (self._key, self.curr_idx, self.indices)
-        aux_data = {
-            k: vars(self)[k]
-            for k in [
-                "obs_batch_size",
-                "observed_pinn_in",
-                "observed_values",
-                "observed_eq_params",
-            ]
-        }
-        return (children, aux_data)
-
-    @classmethod
-    def tree_unflatten(cls, aux_data, children):
-        (key, curr_idx, indices) = children
-        obj = cls(
-            key=key,
-            data_exists=True,
-            obs_batch_size=aux_data["obs_batch_size"],
-            observed_pinn_in=aux_data["observed_pinn_in"],
-            observed_values=aux_data["observed_values"],
-            observed_eq_params=aux_data["observed_eq_params"],
-        )
-        obj.curr_idx = curr_idx
-        obj.indices = indices
-        return obj
-
-
-@register_pytree_node_class
-class DataGeneratorObservationsMultiPINNs:
-    """
+class DataGeneratorObservationsMultiPINNs(eqx.Module):
+    r"""
     Despite the class name, it is rather a dataloader from user provided
     observations that will be used for the observations loss.
     This is the DataGenerator to use when dealing with multiple PINNs
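
The manual `tree_flatten`/`tree_unflatten` pairs deleted above are exactly what the move to `eqx.Module` makes redundant: an equinox module is automatically a pytree, with `eqx.field(static=True)` playing the role of the old `aux_data` dict. A minimal illustration with a toy class (not from jinns):

import equinox as eqx
import jax

class Counter(eqx.Module):
    # array leaf: traced by JAX and updated functionally
    value: jax.Array
    # static field: hashable auxiliary data, like the old aux_data dict
    name: str = eqx.field(static=True)

c = Counter(value=jax.numpy.zeros(()), name="steps")
# functional "mutation", the same pattern the batching methods in this diff use
c2 = eqx.tree_at(lambda m: m.value, c, c.value + 1)
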
@@ -1532,141 +1321,123 @@ class DataGeneratorObservationsMultiPINNs:
     applied in `constraints_system_loss_apply` and in this case the
     batch.obs_batch_dict is a dict of obs_batch_dict over which the tree_map
     applies (we select the obs_batch_dict corresponding to its `u_dict` entry)
+
+    Parameters
+    ----------
+    obs_batch_size : Int
+        The size of the batch of randomly selected observations
+        `obs_batch_size` will be the same for all the
+        elements of the obs dict.
+        NOTE: no check is done BUT users should be careful that
+        `obs_batch_size` must be equal to `temporal_batch_size` or
+        `omega_batch_size` or the product of both. In the first case, the
+        present DataGeneratorObservations instance complements an ODEBatch,
+        PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
+        = False). In the second case, `obs_batch_size` =
+        `temporal_batch_size * omega_batch_size` if the present
+        DataGeneratorParameter complements a PDENonStatioBatch
+        with self.cartesian_product = True
+    observed_pinn_in_dict : Dict[str, Float[Array, "n_obs nb_pinn_in"] | None]
+        A dict of observed_pinn_in as defined in DataGeneratorObservations.
+        Keys must be that of `u_dict`.
+        If no observation exists for a particular entry of `u_dict` the
+        corresponding key must still exist in observed_pinn_in_dict with
+        value None
+    observed_values_dict : Dict[str, Float[Array, "n_obs, nb_pinn_out"] | None]
+        A dict of observed_values as defined in DataGeneratorObservations.
+        Keys must be that of `u_dict`.
+        If no observation exists for a particular entry of `u_dict` the
+        corresponding key must still exist in observed_values_dict with
+        value None
+    observed_eq_params_dict : Dict[str, Dict[str, Float[Array, "n_obs 1"]]]
+        A dict of observed_eq_params as defined in DataGeneratorObservations.
+        Keys must be that of `u_dict`.
+        **Note**: if no observation exists for a particular entry of `u_dict` the
+        corresponding key must still exist in observed_eq_params_dict with
+        value `{}` (empty dictionnary).
+    key
+        Jax random key to shuffle batches.
     """

-
-
-
-
-
-
-
-
-        )
-
-
-
-        obs_batch_size
-            An integer. The size of the batch of randomly selected observations
-            `obs_batch_size` will be the same for all the
-            elements of the obs dict. `obs_batch_size` must be
-            equal to `temporal_batch_size` or `omega_batch_size` or the product
-            of both whether the present DataGeneratorParameter instance
-            complements and ODEBatch, a PDEStatioBatch or a PDENonStatioBatch,
-            respectively.
-        observed_pinn_in_dict
-            A dict of observed_pinn_in as defined in DataGeneratorObservations.
-            Keys must be that of `u_dict`.
-            If no observation exists for a particular entry of `u_dict` the
-            corresponding key must still exist in observed_pinn_in_dict with
-            value None
-        observed_values_dict
-            A dict of observed_values as defined in DataGeneratorObservations.
-            Keys must be that of `u_dict`.
-            If no observation exists for a particular entry of `u_dict` the
-            corresponding key must still exist in observed_values_dict with
-            value None
-        observed_eq_params_dict
-            A dict of observed_eq_params as defined in DataGeneratorObservations.
-            Keys must be that of `u_dict`.
-            If no observation exists for a particular entry of `u_dict` the
-            corresponding key must still exist in observed_eq_params_dict with
-            value None
-        data_gen_obs_exists
-            Must be left to `False` when created by the user. Avoids the
-            regeneration the subclasses DataGeneratorObservations
-            at each pytree flattening and unflattening.
-        key
-            Jax random key to sample new time points and to shuffle batches.
-            Optional if data_gen_obs_exists is True
-        """
-        if (
-            observed_pinn_in_dict is None or observed_values_dict is None
-        ) and not data_gen_obs_exists:
+    obs_batch_size: Int
+    observed_pinn_in_dict: Dict[str, Float[Array, "n_obs nb_pinn_in"] | None]
+    observed_values_dict: Dict[str, Float[Array, "n_obs nb_pinn_out"] | None]
+    observed_eq_params_dict: Dict[str, Dict[str, Float[Array, "n_obs 1"]]] = eqx.field(
+        default=None, kw_only=True
+    )
+    key: InitVar[Key]
+
+    data_gen_obs: Dict[str, "DataGeneratorObservations"] = eqx.field(init=False)
+
+    def __post_init__(self, key):
+        if self.observed_pinn_in_dict is None or self.observed_values_dict is None:
             raise ValueError(
-                "observed_pinn_in_dict and observed_values_dict "
-
+                "observed_pinn_in_dict and observed_values_dict " "must be provided"
+            )
+        if self.observed_pinn_in_dict.keys() != self.observed_values_dict.keys():
+            raise ValueError(
+                "Keys must be the same in observed_pinn_in_dict"
+                " and observed_values_dict"
             )
-        self.obs_batch_size = obs_batch_size
-        self.data_gen_obs_exists = data_gen_obs_exists

-        if
-
-
-
-
-
-
-
-            and observed_pinn_in_dict.keys() != observed_eq_params_dict.keys()
-        ):
-            raise ValueError(
-                "Keys must be the same in observed_eq_params_dict"
-                " and observed_pinn_in_dict and observed_values_dict"
-            )
-        if observed_eq_params_dict is None:
-            observed_eq_params_dict = {
-                k: None for k in observed_pinn_in_dict.keys()
-            }
-
-        keys = dict(
-            zip(
-                observed_pinn_in_dict.keys(),
-                jax.random.split(key, len(observed_pinn_in_dict)),
-            )
+        if self.observed_eq_params_dict is None:
+            self.observed_eq_params_dict = {
+                k: {} for k in self.observed_pinn_in_dict.keys()
+            }
+        elif self.observed_pinn_in_dict.keys() != self.observed_eq_params_dict.keys():
+            raise ValueError(
+                f"Keys must be the same in observed_eq_params_dict"
+                f" and observed_pinn_in_dict and observed_values_dict"
             )
-
-
-
-
-
-            if pinn_in is not None
-            else None
-        ),
-        keys,
-        observed_pinn_in_dict,
-        observed_values_dict,
-        observed_eq_params_dict,
+
+        keys = dict(
+            zip(
+                self.observed_pinn_in_dict.keys(),
+                jax.random.split(key, len(self.observed_pinn_in_dict)),
             )
+        )
+        self.data_gen_obs = jax.tree_util.tree_map(
+            lambda k, pinn_in, val, eq_params: (
+                DataGeneratorObservations(
+                    k, self.obs_batch_size, pinn_in, val, eq_params
+                )
+                if pinn_in is not None
+                else None
+            ),
+            keys,
+            self.observed_pinn_in_dict,
+            self.observed_values_dict,
+            self.observed_eq_params_dict,
+        )

-    def obs_batch(self):
+    def obs_batch(self) -> tuple["DataGeneratorObservationsMultiPINNs", PyTree]:
         """
         Returns a dictionary of DataGeneratorObservations.obs_batch with keys
         from `u_dict`
         """
-
+        data_gen_and_batch_pytree = jax.tree_util.tree_map(
             lambda a: a.get_batch() if a is not None else {},
             self.data_gen_obs,
             is_leaf=lambda x: isinstance(x, DataGeneratorObservations),
         )  # note the is_leaf note to traverse the DataGeneratorObservations and
         # thus to be able to call the method on the element(s) of
         # self.data_gen_obs which are not None
+        new_attribute = jax.tree_util.tree_map(
+            lambda a: a[0],
+            data_gen_and_batch_pytree,
+            is_leaf=lambda x: isinstance(x, tuple),
+        )
+        new = eqx.tree_at(lambda m: m.data_gen_obs, self, new_attribute)
+        batches = jax.tree_util.tree_map(
+            lambda a: a[1],
+            data_gen_and_batch_pytree,
+            is_leaf=lambda x: isinstance(x, tuple),
+        )

-
+        return new, batches
+
+    def get_batch(self) -> tuple["DataGeneratorObservationsMultiPINNs", PyTree]:
         """
         Generic method to return a batch
         """
         return self.obs_batch()
-
-    def tree_flatten(self):
-        # because a dict with "str" keys cannot go in the children (jittable)
-        # attributes, we need to separate it in two and recreate the zip in the
-        # tree_unflatten
-        children = self.data_gen_obs.values()
-        aux_data = {
-            "obs_batch_size": self.obs_batch_size,
-            "data_gen_obs_keys": self.data_gen_obs.keys(),
-        }
-        return (children, aux_data)
-
-    @classmethod
-    def tree_unflatten(cls, aux_data, children):
-        (data_gen_obs_values) = children
-        obj = cls(
-            observed_pinn_in_dict=None,
-            observed_values_dict=None,
-            data_gen_obs_exists=True,
-            obs_batch_size=aux_data["obs_batch_size"],
-        )
-        obj.data_gen_obs = dict(zip(aux_data["data_gen_obs_keys"], data_gen_obs_values))
-        return obj