jinns 1.7.0-py3-none-any.whl → 1.7.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -4
- jinns/data/_DataGeneratorODE.py +1 -1
- jinns/data/_DataGeneratorObservations.py +498 -90
- jinns/loss/_DynamicLossAbstract.py +3 -1
- jinns/loss/_LossODE.py +103 -65
- jinns/loss/_LossPDE.py +145 -77
- jinns/loss/_abstract_loss.py +64 -6
- jinns/loss/_boundary_conditions.py +6 -6
- jinns/loss/_loss_utils.py +2 -2
- jinns/loss/_loss_weight_updates.py +30 -0
- jinns/loss/_loss_weights.py +4 -0
- jinns/loss/_operators.py +27 -27
- jinns/nn/_abstract_pinn.py +1 -1
- jinns/nn/_hyperpinn.py +6 -6
- jinns/nn/_mlp.py +3 -3
- jinns/nn/_pinn.py +7 -7
- jinns/nn/_ppinn.py +6 -6
- jinns/nn/_spinn.py +4 -4
- jinns/nn/_spinn_mlp.py +7 -7
- jinns/solver/_rar.py +19 -9
- jinns/solver/_solve.py +4 -1
- jinns/solver/_utils.py +17 -11
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/METADATA +14 -4
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/RECORD +28 -28
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/WHEEL +1 -1
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/top_level.txt +0 -0
jinns/data/_DataGeneratorObservations.py

@@ -5,6 +5,8 @@ Define the DataGenerators modules
 from __future__ import (
     annotations,
 ) # https://docs.python.org/3/library/typing.html#constant
+from functools import partial
+from itertools import zip_longest
 import equinox as eqx
 import jax
 import jax.numpy as jnp
@@ -27,6 +29,84 @@ if TYPE_CHECKING:
 # before hand: this is not practical, let us not get mad at this
 
 
+def _merge_dict_arguments(fun, fixed_args):
+    """
+    A decorator function that transforms a tuple dictionary with one key
+    in a function call with a big merged unpacked dict.
+
+    This is used for a dynamic construction of a tree map call, where an arbitrary number of arguments
+    are fixed before the tree map call.
+
+    We need this because the function that needs to be called
+    is kw only, but jax tree map does not support keyword only,
+    so we pass through this decorator
+
+    Example of usage:
+    ```
+    def f(*, a, b):
+        '''
+        kw only function that we would like to call in a jax.tree.map
+        but we do not know which arguments will be fixed before runtime
+        and jax.tree.map does not allow for kw
+        '''
+        # Do whatever you need to
+
+    # Then, by any means (at runtime), we determine the arguments that will be
+    # fixed during the jax.tree.map, we store them as a tuple of strings, as
+    # example:
+    fixed_args = ("a")
+    # and we actually fix them for all this to have some sense:
+    f = partial(f, a=observed_pinn_in[0])
+
+    # We also need to construct a tuple of tuples of dictionaries with only one
+    # key. We want to call the jax.tree.map for each tuple in the tuple. In
+    # each tuple we have dictionaries with one key (one dict for each for the
+    # argument of `f` which serves as key). The dictionaries will be handled by
+    # the decorator:
+    tree_map_args = (  # jax.tree.map over each element of this tuple
+        ({"a": None},  # most expected value since a is fixed
+         {"b": observed_values[0]}),  # any useful value.
+
+        ({"a": None}, {"b":observed_values[1]),
+        ({"a": None}, {"b":observed_values[2]}),
+    )
+    # Then we can call
+    jax.tree.map(
+        _merge_dict_arguments(f, fixed_args),
+        tree_map_args,
+        is_leaf=lambda x: (isinstance(x, tuple) and isinstance(x[0], dict)),  # force iteration over outer tuple only
+    )
+    ```
+    In the code sample above, we see that the gain really lies in the fact that
+    `a=observed_pinn_in[0]` is not duplicated and this saves memory (in most of
+    runtime) while still enabling the jax.tree.map call (which however,
+    duplicates `a`
+    just for the computation time :)). Indeed, the direct way would be to
+    construct something like:
+    ```
+    tree_map_args = (
+        (observed_pinn_in[0], observed_values[0]),
+        (observed_pinn_in[0], observed_values[1]),
+        (observed_pinn_in[0], observed_values[2]),
+    )
+    ```
+    which could be a burden for the whole runtime
+    """
+
+    def wrapper(tuple_of_dict):
+        d = {}
+        # the for loop below is needed because there is no unpack operator
+        # authorized inside a comprehension for now: https://stackoverflow.com/a/37584733
+        for d_ in tuple_of_dict:
+            if len(d_.keys()) != 1:
+                raise ValueError("Problem here, we expect 1-key-dict")
+            if list(d_.keys())[0] not in fixed_args:
+                d.update(d_)
+        return fun(**d)
+
+    return wrapper
+
+
 class DGObservedParams(metaclass=DictToModuleMeta):
     """
     However, static type checkers cannot know that DGObservedParams inherit from
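For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the same trick outside of jinns (the function `f` and the names `scale`/`data` are invented for this illustration): a keyword-only function is driven by `jax.tree.map` while one argument is fixed once with `functools.partial`, so it is never duplicated across the tree leaves. The 1-key-dict check from the real wrapper is omitted for brevity.

```python
from functools import partial
from itertools import zip_longest

import jax


def _merge_dict_arguments(fun, fixed_args):
    # Merge a tuple of 1-key dicts into one kwargs dict, dropping the keys that
    # were already fixed with partial(); same idea as in the hunk above.
    def wrapper(tuple_of_dict):
        d = {}
        for d_ in tuple_of_dict:
            if list(d_.keys())[0] not in fixed_args:
                d.update(d_)
        return fun(**d)

    return wrapper


def f(*, scale, data):  # keyword-only: jax.tree.map cannot call it directly
    return scale * data


scale = 2.0                    # shared by all "datasets": fix it only once
datasets = (1.0, 10.0, 100.0)  # one leaf per dataset

f_fixed = partial(f, scale=scale)
fixed_args = ("scale",)

# zip_longest pads the length-1 side with None; those 1-key dicts are then
# ignored by the wrapper because their key belongs to fixed_args.
tree_map_args = tuple(
    ({"scale": s}, {"data": d}) for s, d in zip_longest((None,), datasets)
)

out = jax.tree.map(
    _merge_dict_arguments(f_fixed, fixed_args),
    tree_map_args,
    is_leaf=lambda x: isinstance(x, tuple) and isinstance(x[0], dict),
)
print(out)  # (2.0, 20.0, 200.0)
```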
@@ -45,20 +125,22 @@ class DataGeneratorObservations(AbstractDataGenerator):
     ----------
     key : PRNGKeyArray
         Jax random key to shuffle batches
-    obs_batch_size : int | None
+    obs_batch_size : tuple[int | None, ...]
         The size of the batch of randomly selected points among
         the `n` points. If None, no minibatch are used.
-    observed_pinn_in : Float[Array, " n_obs nb_pinn_in"]
+    observed_pinn_in : Float[Array, " n_obs nb_pinn_in"] | tuple[Float[Array, " n_obs nb_pinn_in"], ...]
         Observed values corresponding to the input of the PINN
         (eg. the time at which we recorded the observations). The first
         dimension must corresponds to the number of observed_values.
         The second dimension depends on the input dimension of the PINN,
         that is `1` for ODE, `n_dim_x` for stationnary PDE and `n_dim_x + 1`
         for non-stationnary PDE.
-    observed_values : Float[Array, " n_obs, nb_pinn_out"]
+        Can be a tuple of such arguments to support multidatasets, see below.
+    observed_values : Float[Array, " n_obs, nb_pinn_out"] | tuple[Float[Array, " n_obs, nb_pinn_out"], ...]
         Observed values that the PINN should learn to fit. The first
         dimension must be aligned with observed_pinn_in.
-    observed_eq_params : dict[str, Float[Array, " n_obs 1"]], default=None
+        Can be a tuple of such arguments to support multidatasets, see below.
+    observed_eq_params : dict[str, Float[Array, " n_obs 1"]] | tuple[dict[str, Float[Array, " n_obs 1"], ...]], default=None
         A dict with keys corresponding to
         the parameter name. The keys must match the keys in
         `params["eq_params"]`, ie., if only some parameters are observed, other
@@ -66,6 +148,7 @@ class DataGeneratorObservations(AbstractDataGenerator):
         with values corresponding to the parameter value for which we also
         have observed_pinn_in and observed_values. Hence the first
         dimension must be aligned with observed_pinn_in and observed_values.
+        Can be a tuple of such arguments to support multidatasets, see below.
         Optional argument.
     sharding_device : jax.sharding.Sharding, default=None
         Default None. An optional sharding object to constraint the storage
@@ -75,73 +158,303 @@ class DataGeneratorObservations(AbstractDataGenerator):
         can still be performed on other devices (*e.g.* GPU, TPU or
         any pre-defined Sharding) thanks to the `obs_batch_sharding`
         arguments of `jinns.solve()`. Read `jinns.solve()` doc for more info.
+
+    ** New in jinns v1.7.1:** We provide the possibility of specifying several
+    datasets of observations, this can serve a variety of purposes, as for
+    example, provided different observations for different channels of the
+    solution (by using the `obs_slice` attribute of Loss objects, see the
+    notebook on this topic in the documentation). To
+    provide several datasets, it suffices to pass `observed_values` or
+    `observed_pinn_in` or `observed_eq_params` as tuples. If you have
+    several `observed_values` but the same `observed_eq_params` and or the
+    same `observed_pinn_in`, the two latter should not be duplicated in
+    tuples of the same length of `observed_values`. The last sentence
+    remains true when interchanging the terms `observed_values`,
+    `observed_pinn_in` and `observed_eq_params` in any position.
+    This is not a syntaxic sugar.
+    This is a real necessity to avoid duplicating data in the dataclasses
+    attributes, to gain speed and memory performances. This
+    internally handled with dynamic freezing of non vectorized
+    arguments (see code...). Note that the data are still duplicated at
+    the moment of the vectorial operation (jax tree map), and at during
+    that operation memory reaches its peaks and the memory usage is the
+    same in the DG with duplication and DG without duplication. However,
+    after that, the memory usage goes down again for the version without
+    duplication so we avoid memory overflow that could happen, for example,
+    when computing costly dynamic losses while storing a uselessly big
+    DataGeneratorObservations
+    (see the test script at
+    `jinns/tests/dataGenerator_tests/profiling_DataGeneratorObservations.py`
+    ).
+
     """
 
     key: PRNGKeyArray
-    obs_batch_size: int | None = eqx.field(static=True)
-    observed_pinn_in: Float[Array, " n_obs nb_pinn_in"]
-    observed_values: Float[Array, " n_obs nb_pinn_out"]
-    observed_eq_params: eqx.Module | None
-    sharding_device: jax.sharding.Sharding | None
+    obs_batch_size: tuple[int | None, ...] = eqx.field(static=True)
+    observed_pinn_in: tuple[Float[Array, " n_obs nb_pinn_in"], ...]
+    observed_values: tuple[Float[Array, " n_obs nb_pinn_out"], ...]
+    observed_eq_params: tuple[eqx.Module | None, ...]
+    sharding_device: jax.sharding.Sharding | None = eqx.field(static=True)
 
-    n: int = eqx.field(init=False, static=True)
-    curr_idx: int = eqx.field(init=False)
-    indices: Array = eqx.field(init=False)
+    n: tuple[int, ...] = eqx.field(init=False, static=True)
+    curr_idx: tuple[int, ...] = eqx.field(init=False)
+    indices: tuple[Array, ...] = eqx.field(init=False)
 
     def __init__(
         self,
         *,
         key: PRNGKeyArray,
-        obs_batch_size: int | None = None,
-        observed_pinn_in: Float[Array, " n_obs nb_pinn_in"],
-
-
+        obs_batch_size: tuple[int | None, ...] | int | None = None,
+        observed_pinn_in: tuple[Float[Array, " n_obs nb_pinn_in"], ...]
+        | Float[Array, " n_obs nb_pinn_in"],
+        observed_values: tuple[Float[Array, " n_obs nb_pinn_out"], ...]
+        | Float[Array, " n_obs nb_pinn_out"],
+        observed_eq_params: tuple[InputEqParams, ...] | InputEqParams | None = None,
         sharding_device: jax.sharding.Sharding | None = None,
     ) -> None:
+        """ """
         super().__init__()
         self.key = key
-
+
+        if not isinstance(observed_values, tuple):
+            observed_values = (observed_values,)
+        if not isinstance(observed_pinn_in, tuple):
+            observed_pinn_in = (observed_pinn_in,)
+        if observed_eq_params is not None:
+            if not isinstance(observed_eq_params, tuple):
+                observed_eq_params = (observed_eq_params,)
+        else:
+            observed_eq_params = (None,)
+
+        # now if values, pinn_in, and eq_params does not have same length (as
+        # tuples), we must find the longest one and the other either must be
+        # length 1 or must be the same length as the longest
+        len_longest_tuple = max(
+            map(len, (observed_values, observed_pinn_in, observed_eq_params))
+        )
+        longest_tuple = max(
+            (observed_values, observed_pinn_in, observed_eq_params), key=len
+        )
+        if len(observed_values) != 1 and len(observed_values) != len_longest_tuple:
+            raise ValueError(
+                "If observed_values is a tuple, it should"
+                " be of length 1 (one array, the same for"
+                " all the pinn_in and eq_params entries) or be of the same"
+                " length as the longest tuple of entries (1 to 1 matching)"
+            )
+        if len(observed_pinn_in) != 1 and len(observed_pinn_in) != len_longest_tuple:
+            raise ValueError(
+                "If observed_pinn_in is a tuple, it should"
+                " be of length 1 (one array, the same for"
+                " all the values and eq_params entries) or be of the same"
+                " length as the longest tuple of entries (1 to 1 matching)"
+            )
+        if (
+            len(observed_eq_params) != 1
+            and len(observed_eq_params) != len_longest_tuple
+        ):
+            raise ValueError(
+                "If observed_eq_params is a tuple, it should"
+                " be of length 1 (one array, the same for"
+                " all the values and pinn_in entries) or be of the same"
+                " length as the longest tuple of entries (1 to 1 matching)"
+            )
+
+        ### Start check first axis
+
+        def check_first_axis(*, values, pinn_in_array):
+            if values.shape[0] != pinn_in_array.shape[0]:
+                raise ValueError(
+                    "Each matching elements of self.observed_pinn_in and self.observed_values must have same first axis"
+                )
+            return values
+
+        tree_map_args = tuple(
+            ({"values": v}, {"pinn_in_array": p})
+            for v, p in zip_longest(observed_values, observed_pinn_in)
+        )
+        fixed_args = ()
+        if len(observed_values) != len(observed_pinn_in):
+            if len(observed_pinn_in) == 1:
+                check_first_axis = partial(
+                    check_first_axis, pinn_in_array=observed_pinn_in[0]
+                )
+                fixed_args = fixed_args + ("pinn_in_array",)
+            if len(observed_values) == 1:
+                check_first_axis = partial(check_first_axis, values=observed_values[0])
+                fixed_args = fixed_args + ("values",)
+        # ... and then we do the tree map. Note that in the tree.map below,
+        # self.observed_eq_params can have None leaves
+        # tree_map_args is a tuple of tuple dicts: 1) outer tuples are those we
+        # will vectorize over 2) inside tuples to be able to unpack
+        # dynamically (i.e. varying nb of elements to pass to fun)
+        # 3) then the dicts are merged to feed the kw only function
+        # tree.map cannot directly feed a kw only
+        # function such as check_first_axis (so we pass through
+        # the decorator)
+        jax.tree.map(
+            _merge_dict_arguments(check_first_axis, fixed_args),
+            tree_map_args,
+            is_leaf=lambda x: (isinstance(x, tuple) and isinstance(x[0], dict)),
+        )
+
+        ### End check first axis
+
         self.observed_pinn_in = observed_pinn_in
         self.observed_values = observed_values
+        if observed_eq_params == (None,):
+            self.observed_eq_params = observed_eq_params  # pyright: ignore
+            # (this is resolved later on one instanciating DGObservedParams)
+        else:
+            self.observed_eq_params = jax.tree.map(
+                lambda d: {
+                    k: v[:, None] if len(v.shape) == 1 else v for k, v in d.items()
+                },
+                observed_eq_params,
+                is_leaf=lambda x: isinstance(x, dict),
+            )
 
-
-
-
+        self.observed_pinn_in = jax.tree.map(
+            lambda x: x[:, None] if len(x.shape) == 1 else x, self.observed_pinn_in
+        )
+
+        self.observed_values = jax.tree.map(
+            lambda x: x[:, None] if len(x.shape) == 1 else x, self.observed_values
+        )
+
+        ### Start check first axis 2
+        def check_first_axis2(*, eq_params_dict, pinn_in_array):
+            if eq_params_dict is not None:
+                for _, v in eq_params_dict.items():
+                    if v.shape[0] != pinn_in_array.shape[0]:
+                        raise ValueError(
+                            "Each matching elements of self.observed_pinn_in and self.observed_eq_params must have the same first axis"
+                        )
+
+        # the following tree_map_args will work if all lengths are equal either
+        # 1 or more
+        tree_map_args = tuple(
+            ({"eq_params_dict": e}, {"pinn_in_array": p})
+            for e, p in zip_longest(self.observed_eq_params, self.observed_pinn_in)
+        )
+        fixed_args = ()
+        if len(self.observed_eq_params) != len(self.observed_pinn_in):
+            if len(self.observed_pinn_in) == 1:
+                check_first_axis2 = partial(
+                    check_first_axis2, pinn_in_array=self.observed_pinn_in[0]
+                )
+                fixed_args = fixed_args + ("pinn_in_array",)
+
+            if len(self.observed_eq_params) == 1:
+                check_first_axis2 = partial(
+                    check_first_axis2, eq_params_dict=self.observed_eq_params[0]
+                )
+                fixed_args = fixed_args + ("eq_params_dict",)
+        jax.tree.map(
+            _merge_dict_arguments(
+                check_first_axis2, fixed_args
+            ),  # https://stackoverflow.com/a/42421497
+            tree_map_args,
+            is_leaf=lambda x: (isinstance(x, tuple) and isinstance(x[0], dict)),
+        )
+
+        ### End check first axis 2
+
+        ### Start check ndim
+
+        def check_ndim(*, values, pinn_in_array, eq_params_dict):
+            if values.ndim > 2:
+                raise ValueError(
+                    "Each element of self.observed_pinn_in must have 2 dimensions"
+                )
+            if pinn_in_array.ndim > 2:
+                raise ValueError(
+                    "Each element of self.observed_values must have 2 dimensions"
+                )
+            if eq_params_dict is not None:
+                for _, v in eq_params_dict.items():
+                    if v.ndim > 2:
+                        raise ValueError(
+                            "Each value of observed_eq_params must have 2 dimensions"
+                        )
+
+        # the following tree_map_args will work if all lengths are equal either
+        # 1 or more
+        tree_map_args = tuple(
+            ({"eq_params_dict": e}, {"pinn_in_array": p}, {"values": v})
+            for e, p, v in zip_longest(
+                self.observed_eq_params, self.observed_pinn_in, self.observed_values
             )
-
-
-
-
-
-
-
-
+        )
+        # now, if some shape are different, it can only be because there are 1
+        # while we expect a fixed n (thanks to the early tests above)
+        # then we must fix the arguments that are single leaf pytree
+        # and keep track of the arguments that are fixed to be able to remove
+        # them in the wrapper
+        fixed_args = ()
+        if len(self.observed_eq_params) != len(self.observed_pinn_in) or len(
+            self.observed_eq_params
+        ) != len(self.observed_values):
+            if len(self.observed_pinn_in) == 1:
+                check_ndim = partial(check_ndim, pinn_in_array=self.observed_pinn_in[0])
+                fixed_args = fixed_args + ("pinn_in_array",)
+            if len(self.observed_eq_params) == 1:
+                check_ndim = partial(
+                    check_ndim, eq_params_dict=self.observed_eq_params[0]
+                )
+                fixed_args = fixed_args + ("eq_params_dict",)
+            if len(self.observed_values) == 1:
+                check_ndim = partial(check_ndim, values=self.observed_values[0])
+                fixed_args = fixed_args + ("values",)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        #
-        #
-
-
+        jax.tree.map(
+            _merge_dict_arguments(check_ndim, fixed_args),
+            tree_map_args,
+            is_leaf=lambda x: (isinstance(x, tuple) and isinstance(x[0], dict)),
+        )
+        ### End check ndim
+
+        # longest_tuple will be used for correct jax tree map broadcast. Note
+        # that even though self.observed_pinn_in and self.observed_values and
+        # self.observed_eq_params does
+        # not have the same len (as tuples), their components (jnp.arrays) do
+        # have the same first axis. This is worked out by all the previous
+        # checks
+        self.n = jax.tree.map(
+            lambda o: o.shape[0],
+            tuple(_ for _ in jax.tree.leaves(longest_tuple)),  # jax.tree.leaves
+            # because if longest_tuple is eq_params then it is a dict but we do
+            # not want self.n to have the dict tree structure
+        )
+
+        if isinstance(obs_batch_size, int) or obs_batch_size is None:
+            self.obs_batch_size = jax.tree.map(
+                lambda _: obs_batch_size,
+                tuple(_ for _ in jax.tree.leaves(longest_tuple)),  # jax tree leaves
+                # because if longest_tuple is eq_params then it is a dict but we do
+                # not want self.n to have the dict tree structure
             )
+        elif isinstance(obs_batch_size, tuple):
+            if len(obs_batch_size) != len_longest_tuple and len(obs_batch_size) != 1:
+                raise ValueError(
+                    "If obs_batch_size is a tuple, it must me"
+                    " of length 1 or of length equal to the"
+                    " maximum length between values, pinn_in and"
+                    " eq_params."
+                )
+            self.obs_batch_size = obs_batch_size
         else:
-
+            raise ValueError("obs_batch_size must be an int, a tuple or None")
 
-
+        # After all the checks
+        # Convert the dict of observed parameters to the internal
+        # `DGObservedParams`
+        # class used by Jinns.
+        self.observed_eq_params = tuple(
+            DGObservedParams(o_, "DGObservedParams")
+            for o_ in self.observed_eq_params
+            if o_ is not None
+        )
 
         self.sharding_device = sharding_device
         if self.sharding_device is not None:
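A usage sketch of the multidataset support documented in the hunk above. The shapes, values and batch size are invented, and the public import path `jinns.data.DataGeneratorObservations` is assumed; two sets of observed values share a single array of observation inputs without that array being duplicated in the data generator.

```python
import jax
import jax.numpy as jnp
import jinns

key = jax.random.PRNGKey(0)

# Invented observation data: one common set of observation times shared by two
# datasets of observed values (e.g. two observed channels of the solution).
t_obs = jnp.linspace(0.0, 1.0, 100)[:, None]   # (n_obs, 1)
values_channel_0 = jnp.sin(t_obs)              # (n_obs, 1)
values_channel_1 = jnp.cos(t_obs)              # (n_obs, 1)

obs_dg = jinns.data.DataGeneratorObservations(
    key=key,
    obs_batch_size=32,
    # a single array here: it is broadcast against the tuple of observed_values
    # through the internal argument freezing instead of being duplicated
    observed_pinn_in=t_obs,
    observed_values=(values_channel_0, values_channel_1),
)

# get_batch() now returns the updated generator and one ObsBatchDict per dataset
obs_dg, obs_batches = obs_dg.get_batch()
print(len(obs_batches))  # 2
```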
@@ -155,28 +468,39 @@ class DataGeneratorObservations(AbstractDataGenerator):
                 self.observed_eq_params, self.sharding_device
             )
 
-
-
-
-
-        self.
+        # When self.obs_batch_size leaf is None we will have self.curr_idx leaf
+        # to None. (Previous behaviour would put an unused self.curr_idx to 0)
+        self.curr_idx = jax.tree.map(
+            lambda bs, n: bs + n if bs is not None else None,
+            self.obs_batch_size,
+            self.n,
+            is_leaf=lambda x: x is None,
+        )
         # For speed and to avoid duplicating data what is really
         # shuffled is a vector of indices
+        self.indices = jax.tree.map(jnp.arange, self.n)
         if self.sharding_device is not None:
             self.indices = jax.lax.with_sharding_constraint(
-
+                self.indices, self.sharding_device
             )
-        else:
-            self.indices = jnp.arange(self.n)
 
-
-
-
-
+        if not isinstance(self.key, tuple):
+            # recall post_init is the only place with _init_ where we can set
+            # self attribute in a in-place way
+            self.key = jax.tree.unflatten(
+                jax.tree.structure(self.n),
+                jax.random.split(self.key, len(jax.tree.leaves(self.n))),
+            )
 
     def _get_operands(
         self,
-    ) -> tuple[
+    ) -> tuple[
+        tuple[PRNGKeyArray, ...],
+        tuple[Int[Array, " n"], ...],
+        tuple[int, ...],
+        tuple[int | None, ...],
+        None,
+    ]:
         return (
             self.key,
             self.indices,
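A toy sketch of what this per-dataset bookkeeping produces (the sizes are invented, not from the package): `curr_idx` keeps a `None` leaf wherever `obs_batch_size` is `None`, one index vector is built per dataset, and one PRNG key is derived per dataset.

```python
import jax
import jax.numpy as jnp

# Invented per-dataset sizes: two datasets, the second one without minibatching.
n = (100, 50)
obs_batch_size = (32, None)

# curr_idx keeps None wherever obs_batch_size is None; is_leaf is required so
# that None entries are visited as leaves instead of being skipped as empty
# subtrees.
curr_idx = jax.tree.map(
    lambda bs, n_: bs + n_ if bs is not None else None,
    obs_batch_size,
    n,
    is_leaf=lambda x: x is None,
)
# One index vector and one PRNG key per dataset, mirroring the hunk above.
indices = jax.tree.map(jnp.arange, n)
keys = jax.tree.unflatten(
    jax.tree.structure(n),
    jax.random.split(jax.random.PRNGKey(0), len(jax.tree.leaves(n))),
)

print(curr_idx)  # (132, None)
```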
@@ -185,59 +509,143 @@ class DataGeneratorObservations(AbstractDataGenerator):
             None,
         )
 
+    @staticmethod
     def obs_batch(
-
-
+        *,
+        n,
+        obs_batch_size,
+        observed_pinn_in,
+        observed_values,
+        observed_eq_params,
+        curr_idx,
+        key,
+        indices,
+    ) -> tuple[PRNGKeyArray, Array, Int, ObsBatchDict]:
         """
         Return an update DataGeneratorObservations instance and an ObsBatchDict
         """
-        if
+        if obs_batch_size is None or obs_batch_size == n:
             # Avoid unnecessary reshuffling
-            return
-
-
-
-
-
+            return (
+                key,
+                indices,
+                curr_idx,
+                ObsBatchDict(
+                    {
+                        "pinn_in": observed_pinn_in,
+                        "val": observed_values,
+                        "eq_params": observed_eq_params,
+                    }
+                ),
             )
 
-
-
-
-
+        new_key, new_indices, new_curr_idx = _reset_or_increment(
+            curr_idx + obs_batch_size,
+            n,
+            (key, indices, curr_idx, obs_batch_size, None),  # type: ignore
            # ignore since the case self.obs_batch_size is None has been
            # handled above
        )
-        new = eqx.tree_at(
-            lambda m: (m.key, m.indices, m.curr_idx),  # type: ignore
-            self,
-            new_attributes,
-        )
 
         minib_indices = jax.lax.dynamic_slice(
-
-            start_indices=(
-            slice_sizes=(
+            new_indices,
+            start_indices=(new_curr_idx,),
+            slice_sizes=(obs_batch_size,),
         )
 
         obs_batch: ObsBatchDict = {
             "pinn_in": jnp.take(
-
+                observed_pinn_in, minib_indices, unique_indices=True, axis=0
             ),
             "val": jnp.take(
-
+                observed_values, minib_indices, unique_indices=True, axis=0
             ),
             "eq_params": jax.tree_util.tree_map(
                 lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),  # type: ignore
-
+                observed_eq_params,
             ),
         }
-        return
+        return new_key, new_indices, new_curr_idx, obs_batch
 
     def get_batch(
         self,
-    ) -> tuple[Self, ObsBatchDict]:
+    ) -> tuple[Self, tuple[ObsBatchDict, ...]]:
         """
         Generic method to return a batch
         """
-
+        # the following tree map over DataGeneratorObservations.obs_batch, must
+        # be handled with pre-fixed arguments when, for memory reason,
+        # observed_pinn_in or observed_values or observed_eq_params have not
+        # does not have the same length. If all tuples are of size 1, this
+        # should work totally transparently
+        args = (
+            self.observed_eq_params,
+            self.observed_pinn_in,
+            self.observed_values,
+            self.n,
+            self.obs_batch_size,
+            self.curr_idx,
+            self.key,
+            self.indices,
+        )
+
+        tree_map_args = tuple(
+            (
+                {"observed_eq_params": e},
+                {"observed_pinn_in": p},
+                {"observed_values": v},
+                {"n": n},
+                {"obs_batch_size": b},
+                {"curr_idx": c},
+                {"key": k},
+                {"indices": i},
+            )
+            for e, p, v, n, b, c, k, i in zip_longest(*args)
+        )
+        fixed_args = ()
+        obs_batch_fun = DataGeneratorObservations.obs_batch
+        if len(set(map(len, args))) > 1:  # at least 2 lengths differ
+            # but since values, pinn_in and equations are the arguments that
+            # generates all the others, it suffices to potentially fix the
+            # former
+            if len(self.observed_pinn_in) == 1:
+                obs_batch_fun = partial(
+                    obs_batch_fun, observed_pinn_in=self.observed_pinn_in[0]
+                )
+                fixed_args = fixed_args + ("observed_pinn_in",)
+            if len(self.observed_eq_params) == 1:
+                obs_batch_fun = partial(
+                    obs_batch_fun, observed_eq_params=self.observed_eq_params[0]
+                )
+                fixed_args = fixed_args + ("observed_eq_params",)
+            if len(self.observed_values) == 1:
+                obs_batch_fun = partial(
+                    obs_batch_fun, observed_values=self.observed_values[0]
+                )
+                fixed_args = fixed_args + ("observed_values",)
+
+        ret = jax.tree.map(
+            _merge_dict_arguments(obs_batch_fun, fixed_args),
+            tree_map_args,
+            is_leaf=lambda x: (isinstance(x, tuple) and isinstance(x[0], dict)),
+        )
+        new_key = jax.tree.map(
+            lambda l: l[0], ret, is_leaf=lambda x: isinstance(x, tuple) and len(x) == 4
+        )  # we must not traverse the second level
+        new_indices = jax.tree.map(
+            lambda l: l[1], ret, is_leaf=lambda x: isinstance(x, tuple) and len(x) == 4
+        )
+        new_curr_idx = jax.tree.map(
+            lambda l: l[2], ret, is_leaf=lambda x: isinstance(x, tuple) and len(x) == 4
+        )
+        obs_batch_tuple = jax.tree.map(
+            lambda l: l[3], ret, is_leaf=lambda x: isinstance(x, tuple) and len(x) == 4
+        )
+
+        new = eqx.tree_at(
+            lambda m: (m.key, m.indices, m.curr_idx),
+            self,
+            (new_key, new_indices, new_curr_idx),
+        )
+
+        return new, obs_batch_tuple
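The final unpacking in `get_batch` relies on `jax.tree.map` with an `is_leaf` predicate that stops at the per-dataset 4-tuples returned by `obs_batch`. A toy sketch of that mechanism, with placeholder strings standing in for the real keys, indices and batch dicts:

```python
import jax

# Placeholder stand-ins for the per-dataset results of obs_batch: one
# (key, indices, curr_idx, obs_batch) 4-tuple per dataset.
ret = (
    ("key_0", "indices_0", "curr_idx_0", "obs_batch_0"),
    ("key_1", "indices_1", "curr_idx_1", "obs_batch_1"),
)


def is_result(x):
    # Stop the tree traversal at the 4-tuples so that l[0], ..., l[3] pick one
    # component out of each dataset's result without descending further.
    return isinstance(x, tuple) and len(x) == 4


new_key = jax.tree.map(lambda l: l[0], ret, is_leaf=is_result)
obs_batch_tuple = jax.tree.map(lambda l: l[3], ret, is_leaf=is_result)

print(new_key)          # ('key_0', 'key_1')
print(obs_batch_tuple)  # ('obs_batch_0', 'obs_batch_1')
```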