jinns 1.7.0__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,6 +5,8 @@ Define the DataGenerators modules
5
5
  from __future__ import (
6
6
  annotations,
7
7
  ) # https://docs.python.org/3/library/typing.html#constant
8
+ from functools import partial
9
+ from itertools import zip_longest
8
10
  import equinox as eqx
9
11
  import jax
10
12
  import jax.numpy as jnp
@@ -27,6 +29,84 @@ if TYPE_CHECKING:
27
29
  # before hand: this is not practical, let us not get mad at this
28
30
 
29
31
 
32
+ def _merge_dict_arguments(fun, fixed_args):
33
+ """
34
+ A decorator that turns a call on a tuple of single-key dictionaries
+ into a call to `fun` with the merged dictionary unpacked as keyword
+ arguments.
+
+ This is used to build a tree map call dynamically, where an arbitrary number
+ of arguments are fixed before the tree map call.
+
+ We need this because the function to be called is keyword-only, but
+ jax.tree.map does not support keyword arguments, so we go through this
+ decorator.
43
+
44
+ Example of usage:
45
+ ```
46
+ def f(*, a, b):
47
+ '''
48
+ kw only function that we would like to call in a jax.tree.map
49
+ but we do not know which arguments will be fixed before runtime
50
+ and jax.tree.map does not allow for kw
51
+ '''
52
+ # Do whatever you need to
53
+
54
+ # Then, by any means (at runtime), we determine the arguments that will be
55
+ # fixed during the jax.tree.map; we store them as a tuple of strings, for
+ # example:
+ fixed_args = ("a",)
58
+ # and we actually fix them for all this to have some sense:
59
+ f = partial(f, a=observed_pinn_in[0])
60
+
61
+ # We also need to construct a tuple of tuples of single-key dictionaries.
+ # We want jax.tree.map to be applied over each inner tuple. In each inner
+ # tuple we have single-key dictionaries (one dict per argument of `f`,
+ # the argument name being the key). The dictionaries will be handled by
+ # the decorator:
+ tree_map_args = ( # jax.tree.map iterates over each element of this tuple
+ ({"a": None}, # None is expected here since a is fixed
+ {"b": observed_values[0]}), # any useful value
+
+ ({"a": None}, {"b": observed_values[1]}),
+ ({"a": None}, {"b": observed_values[2]}),
+ )
73
+ # Then we can call
74
+ jax.tree.map(
75
+ _merge_dict_arguments(f, fixed_args),
76
+ tree_map_args,
77
+ is_leaf=lambda x: (isinstance(x, tuple) and isinstance(x[0], dict)), # force iteration over outer tuple only
78
+ )
79
+ ```
80
+ In the code sample above, the gain lies in the fact that
+ `a=observed_pinn_in[0]` is not duplicated, which saves memory for most of
+ the runtime while still enabling the jax.tree.map call (which, however,
+ duplicates `a` only for the duration of the computation). Indeed, the
+ direct way would be to construct something like:
86
+ ```
87
+ tree_map_args = (
88
+ (observed_pinn_in[0], observed_values[0]),
89
+ (observed_pinn_in[0], observed_values[1]),
90
+ (observed_pinn_in[0], observed_values[2]),
91
+ )
92
+ ```
93
+ which would keep the duplicated data in memory for the whole runtime.
94
+ """
95
+
96
+ def wrapper(tuple_of_dict):
97
+ d = {}
98
+ # the for loop below is needed because the unpack operator is not
+ # allowed inside a dict comprehension for now: https://stackoverflow.com/a/37584733
100
+ for d_ in tuple_of_dict:
101
+ if len(d_.keys()) != 1:
102
+ raise ValueError("Problem here, we expect 1-key-dict")
103
+ if list(d_.keys())[0] not in fixed_args:
104
+ d.update(d_)
105
+ return fun(**d)
106
+
107
+ return wrapper
108
+
109
+
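As a concrete companion to the docstring example above, here is a minimal, self-contained sketch of the same pattern with made-up data (the function `f`, the arrays and the variable names are illustrative only; only `_merge_dict_arguments` comes from this module):
```python
from functools import partial

import jax
import jax.numpy as jnp


def f(*, a, b):
    # keyword-only function we want to evaluate once per dataset
    return jnp.sum(a) + jnp.sum(b)


observed_pinn_in = (jnp.ones((5, 1)),)  # a single array shared by all datasets
observed_values = tuple(jnp.full((5, 1), float(i)) for i in range(3))

fixed_args = ("a",)
f_fixed = partial(f, a=observed_pinn_in[0])  # fix `a` once instead of duplicating it

tree_map_args = tuple(({"a": None}, {"b": v}) for v in observed_values)

results = jax.tree.map(
    _merge_dict_arguments(f_fixed, fixed_args),
    tree_map_args,
    is_leaf=lambda x: isinstance(x, tuple) and isinstance(x[0], dict),
)
# results is a tuple with one value per dataset
```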
30
110
  class DGObservedParams(metaclass=DictToModuleMeta):
31
111
  """
32
112
  However, static type checkers cannot know that DGObservedParams inherits from
@@ -45,20 +125,22 @@ class DataGeneratorObservations(AbstractDataGenerator):
45
125
  ----------
46
126
  key : PRNGKeyArray
47
127
  Jax random key to shuffle batches
48
- obs_batch_size : int | None
128
+ obs_batch_size : tuple[int | None, ...]
49
129
  The size of the batch of randomly selected points among
50
130
  the `n` points. If None, no minibatches are used.
51
- observed_pinn_in : Float[Array, " n_obs nb_pinn_in"]
131
+ observed_pinn_in : Float[Array, " n_obs nb_pinn_in"] | tuple[Float[Array, " n_obs nb_pinn_in"], ...]
52
132
  Observed values corresponding to the input of the PINN
53
133
  (e.g. the time at which we recorded the observations). The first
54
134
  dimension must correspond to the number of observed_values.
55
135
  The second dimension depends on the input dimension of the PINN,
56
136
  that is `1` for ODE, `n_dim_x` for stationary PDE and `n_dim_x + 1`
57
137
  for non-stationary PDE.
58
- observed_values : Float[Array, " n_obs, nb_pinn_out"]
138
+ Can be a tuple of such arrays to support multiple datasets, see below.
139
+ observed_values : Float[Array, " n_obs nb_pinn_out"] | tuple[Float[Array, " n_obs nb_pinn_out"], ...]
59
140
  Observed values that the PINN should learn to fit. The first
60
141
  dimension must be aligned with observed_pinn_in.
61
- observed_eq_params : dict[str, Float[Array, " n_obs 1"]], default={}
142
+ Can be a tuple of such arrays to support multiple datasets, see below.
143
+ observed_eq_params : dict[str, Float[Array, " n_obs 1"]] | tuple[dict[str, Float[Array, " n_obs 1"]], ...], default=None
62
144
  A dict with keys corresponding to
63
145
  the parameter name. The keys must match the keys in
64
146
  `params["eq_params"]`, i.e., if only some parameters are observed, other
@@ -66,6 +148,7 @@ class DataGeneratorObservations(AbstractDataGenerator):
66
148
  with values corresponding to the parameter value for which we also
67
149
  have observed_pinn_in and observed_values. Hence the first
68
150
  dimension must be aligned with observed_pinn_in and observed_values.
151
+ Can be a tuple of such dicts to support multiple datasets, see below.
69
152
  Optional argument.
70
153
  sharding_device : jax.sharding.Sharding, default=None
71
154
  Default None. An optional sharding object to constrain the storage
@@ -75,73 +158,303 @@ class DataGeneratorObservations(AbstractDataGenerator):
75
158
  can still be performed on other devices (*e.g.* GPU, TPU or
76
159
  any pre-defined Sharding) thanks to the `obs_batch_sharding`
77
160
  arguments of `jinns.solve()`. Read `jinns.solve()` doc for more info.
161
+
162
+ **New in jinns v1.7.1:** It is now possible to specify several datasets of
+ observations. This can serve a variety of purposes, for example providing
+ different observations for different channels of the solution (by using the
+ `obs_slice` attribute of Loss objects, see the notebook on this topic in the
+ documentation). To provide several datasets, it suffices to pass
+ `observed_values`, `observed_pinn_in` or `observed_eq_params` as tuples. If
+ you have several `observed_values` but the same `observed_eq_params` and/or
+ the same `observed_pinn_in`, the latter two should not be duplicated into
+ tuples of the same length as `observed_values`; the same holds when the
+ roles of `observed_values`, `observed_pinn_in` and `observed_eq_params` are
+ interchanged. This is not syntactic sugar: it is necessary to avoid
+ duplicating data in the dataclass attributes and thus to save memory and
+ gain speed. It is handled internally by dynamically freezing the
+ non-vectorized arguments (see the code). Note that the data are still
+ duplicated at the moment of the vectorized operation (jax.tree.map); during
+ that operation memory usage peaks and is the same with or without
+ duplication in the data generator. Afterwards, however, memory usage goes
+ down again for the version without duplication, so we avoid memory overflows
+ that could happen, for example, when computing costly dynamic losses while
+ storing a needlessly big DataGeneratorObservations
+ (see the test script at
+ `jinns/tests/dataGenerator_tests/profiling_DataGeneratorObservations.py`).
189
+
78
190
  """
79
191
 
80
192
  key: PRNGKeyArray
81
- obs_batch_size: int | None = eqx.field(static=True)
82
- observed_pinn_in: Float[Array, " n_obs nb_pinn_in"]
83
- observed_values: Float[Array, " n_obs nb_pinn_out"]
84
- observed_eq_params: eqx.Module | None
85
- sharding_device: jax.sharding.Sharding | None # = eqx.field(static=True)
193
+ obs_batch_size: tuple[int | None, ...] = eqx.field(static=True)
194
+ observed_pinn_in: tuple[Float[Array, " n_obs nb_pinn_in"], ...]
195
+ observed_values: tuple[Float[Array, " n_obs nb_pinn_out"], ...]
196
+ observed_eq_params: tuple[eqx.Module | None, ...]
197
+ sharding_device: jax.sharding.Sharding | None = eqx.field(static=True)
86
198
 
87
- n: int = eqx.field(init=False, static=True)
88
- curr_idx: int = eqx.field(init=False)
89
- indices: Array = eqx.field(init=False)
199
+ n: tuple[int, ...] = eqx.field(init=False, static=True)
200
+ curr_idx: tuple[int, ...] = eqx.field(init=False)
201
+ indices: tuple[Array, ...] = eqx.field(init=False)
90
202
 
91
203
  def __init__(
92
204
  self,
93
205
  *,
94
206
  key: PRNGKeyArray,
95
- obs_batch_size: int | None = None,
96
- observed_pinn_in: Float[Array, " n_obs nb_pinn_in"],
97
- observed_values: Float[Array, " n_obs nb_pinn_out"],
98
- observed_eq_params: InputEqParams | None = None,
207
+ obs_batch_size: tuple[int | None, ...] | int | None = None,
208
+ observed_pinn_in: tuple[Float[Array, " n_obs nb_pinn_in"], ...]
209
+ | Float[Array, " n_obs nb_pinn_in"],
210
+ observed_values: tuple[Float[Array, " n_obs nb_pinn_out"], ...]
211
+ | Float[Array, " n_obs nb_pinn_out"],
212
+ observed_eq_params: tuple[InputEqParams, ...] | InputEqParams | None = None,
99
213
  sharding_device: jax.sharding.Sharding | None = None,
100
214
  ) -> None:
215
+ """ """
101
216
  super().__init__()
102
217
  self.key = key
103
- self.obs_batch_size = obs_batch_size
218
+
219
+ if not isinstance(observed_values, tuple):
220
+ observed_values = (observed_values,)
221
+ if not isinstance(observed_pinn_in, tuple):
222
+ observed_pinn_in = (observed_pinn_in,)
223
+ if observed_eq_params is not None:
224
+ if not isinstance(observed_eq_params, tuple):
225
+ observed_eq_params = (observed_eq_params,)
226
+ else:
227
+ observed_eq_params = (None,)
228
+
229
+ # now if values, pinn_in, and eq_params do not have the same length (as
+ # tuples), we must find the longest one and the others must either be of
+ # length 1 or be of the same length as the longest
232
+ len_longest_tuple = max(
233
+ map(len, (observed_values, observed_pinn_in, observed_eq_params))
234
+ )
235
+ longest_tuple = max(
236
+ (observed_values, observed_pinn_in, observed_eq_params), key=len
237
+ )
238
+ if len(observed_values) != 1 and len(observed_values) != len_longest_tuple:
239
+ raise ValueError(
240
+ "If observed_values is a tuple, it should"
241
+ " be of length 1 (one array, the same for"
242
+ " all the pinn_in and eq_params entries) or be of the same"
243
+ " length as the longest tuple of entries (1 to 1 matching)"
244
+ )
245
+ if len(observed_pinn_in) != 1 and len(observed_pinn_in) != len_longest_tuple:
246
+ raise ValueError(
247
+ "If observed_pinn_in is a tuple, it should"
248
+ " be of length 1 (one array, the same for"
249
+ " all the values and eq_params entries) or be of the same"
250
+ " length as the longest tuple of entries (1 to 1 matching)"
251
+ )
252
+ if (
253
+ len(observed_eq_params) != 1
254
+ and len(observed_eq_params) != len_longest_tuple
255
+ ):
256
+ raise ValueError(
257
+ "If observed_eq_params is a tuple, it should"
258
+ " be of length 1 (one array, the same for"
259
+ " all the values and pinn_in entries) or be of the same"
260
+ " length as the longest tuple of entries (1 to 1 matching)"
261
+ )
262
+
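The length rule enforced by the three checks above can be summarised with a small standalone sketch (the `_lengths_ok` helper is hypothetical and not part of the class):
```python
def _lengths_ok(*lengths: int) -> bool:
    # every tuple must have length 1 or the maximum length among the three
    longest = max(lengths)
    return all(l == 1 or l == longest for l in lengths)

assert _lengths_ok(3, 1, 1)       # shorter entries are broadcast
assert _lengths_ok(3, 3, 1)       # 1-to-1 matching plus one broadcast entry
assert not _lengths_ok(3, 2, 1)   # 2 is neither 1 nor the longest length (3)
```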
263
+ ### Start check first axis
264
+
265
+ def check_first_axis(*, values, pinn_in_array):
266
+ if values.shape[0] != pinn_in_array.shape[0]:
267
+ raise ValueError(
268
+ "Each matching elements of self.observed_pinn_in and self.observed_values must have same first axis"
269
+ )
270
+ return values
271
+
272
+ tree_map_args = tuple(
273
+ ({"values": v}, {"pinn_in_array": p})
274
+ for v, p in zip_longest(observed_values, observed_pinn_in)
275
+ )
276
+ fixed_args = ()
277
+ if len(observed_values) != len(observed_pinn_in):
278
+ if len(observed_pinn_in) == 1:
279
+ check_first_axis = partial(
280
+ check_first_axis, pinn_in_array=observed_pinn_in[0]
281
+ )
282
+ fixed_args = fixed_args + ("pinn_in_array",)
283
+ if len(observed_values) == 1:
284
+ check_first_axis = partial(check_first_axis, values=observed_values[0])
285
+ fixed_args = fixed_args + ("values",)
286
+ # ... and then we do the tree map. Note that in the tree.map below,
287
+ # self.observed_eq_params can have None leaves
288
+ # tree_map_args is a tuple of tuples of dicts: 1) the outer tuple is what we
+ # vectorize over, 2) the inner tuples make it possible to unpack
+ # a varying number of elements to pass to fun,
+ # 3) then the dicts are merged to feed the keyword-only function.
+ # tree.map cannot directly feed a keyword-only
+ # function such as check_first_axis (so we go through
+ # the decorator)
295
+ jax.tree.map(
296
+ _merge_dict_arguments(check_first_axis, fixed_args),
297
+ tree_map_args,
298
+ is_leaf=lambda x: (isinstance(x, tuple) and isinstance(x[0], dict)),
299
+ )
300
+
301
+ ### End check first axis
302
+
104
303
  self.observed_pinn_in = observed_pinn_in
105
304
  self.observed_values = observed_values
305
+ if observed_eq_params == (None,):
306
+ self.observed_eq_params = observed_eq_params # pyright: ignore
307
+ # (this is resolved later on when instantiating DGObservedParams)
308
+ else:
309
+ self.observed_eq_params = jax.tree.map(
310
+ lambda d: {
311
+ k: v[:, None] if len(v.shape) == 1 else v for k, v in d.items()
312
+ },
313
+ observed_eq_params,
314
+ is_leaf=lambda x: isinstance(x, dict),
315
+ )
106
316
 
107
- if self.observed_pinn_in.shape[0] != self.observed_values.shape[0]:
108
- raise ValueError(
109
- "self.observed_pinn_in and self.observed_values must have same first axis"
317
+ self.observed_pinn_in = jax.tree.map(
318
+ lambda x: x[:, None] if len(x.shape) == 1 else x, self.observed_pinn_in
319
+ )
320
+
321
+ self.observed_values = jax.tree.map(
322
+ lambda x: x[:, None] if len(x.shape) == 1 else x, self.observed_values
323
+ )
324
+
325
+ ### Start check first axis 2
326
+ def check_first_axis2(*, eq_params_dict, pinn_in_array):
327
+ if eq_params_dict is not None:
328
+ for _, v in eq_params_dict.items():
329
+ if v.shape[0] != pinn_in_array.shape[0]:
330
+ raise ValueError(
331
+ "Each matching elements of self.observed_pinn_in and self.observed_eq_params must have the same first axis"
332
+ )
333
+
334
+ # the following tree_map_args will work if all the lengths are equal,
+ # whether to 1 or to a larger value
336
+ tree_map_args = tuple(
337
+ ({"eq_params_dict": e}, {"pinn_in_array": p})
338
+ for e, p in zip_longest(self.observed_eq_params, self.observed_pinn_in)
339
+ )
340
+ fixed_args = ()
341
+ if len(self.observed_eq_params) != len(self.observed_pinn_in):
342
+ if len(self.observed_pinn_in) == 1:
343
+ check_first_axis2 = partial(
344
+ check_first_axis2, pinn_in_array=self.observed_pinn_in[0]
345
+ )
346
+ fixed_args = fixed_args + ("pinn_in_array",)
347
+
348
+ if len(self.observed_eq_params) == 1:
349
+ check_first_axis2 = partial(
350
+ check_first_axis2, eq_params_dict=self.observed_eq_params[0]
351
+ )
352
+ fixed_args = fixed_args + ("eq_params_dict",)
353
+ jax.tree.map(
354
+ _merge_dict_arguments(
355
+ check_first_axis2, fixed_args
356
+ ), # https://stackoverflow.com/a/42421497
357
+ tree_map_args,
358
+ is_leaf=lambda x: (isinstance(x, tuple) and isinstance(x[0], dict)),
359
+ )
360
+
361
+ ### End check first axis 2
362
+
363
+ ### Start check ndim
364
+
365
+ def check_ndim(*, values, pinn_in_array, eq_params_dict):
366
+ if values.ndim > 2:
367
+ raise ValueError(
368
+ "Each element of self.observed_pinn_in must have 2 dimensions"
369
+ )
370
+ if pinn_in_array.ndim > 2:
371
+ raise ValueError(
372
+ "Each element of self.observed_values must have 2 dimensions"
373
+ )
374
+ if eq_params_dict is not None:
375
+ for _, v in eq_params_dict.items():
376
+ if v.ndim > 2:
377
+ raise ValueError(
378
+ "Each value of observed_eq_params must have 2 dimensions"
379
+ )
380
+
381
+ # the following tree_map_args will work if all the lengths are equal,
+ # whether to 1 or to a larger value
383
+ tree_map_args = tuple(
384
+ ({"eq_params_dict": e}, {"pinn_in_array": p}, {"values": v})
385
+ for e, p, v in zip_longest(
386
+ self.observed_eq_params, self.observed_pinn_in, self.observed_values
110
387
  )
111
- if len(self.observed_pinn_in.shape) == 1:
112
- self.observed_pinn_in = self.observed_pinn_in[:, None]
113
- if self.observed_pinn_in.ndim > 2:
114
- raise ValueError("self.observed_pinn_in must have 2 dimensions")
115
- if len(self.observed_values.shape) == 1:
116
- self.observed_values = self.observed_values[:, None]
117
- if self.observed_values.ndim > 2:
118
- raise ValueError("self.observed_values must have 2 dimensions")
388
+ )
389
+ # now, if some tuple lengths differ, it can only be because some are 1
+ # while we expect the longest length (thanks to the early tests above).
+ # We must then fix the arguments that are single-leaf pytrees
+ # and keep track of the arguments that are fixed, to be able to remove
+ # them in the wrapper
394
+ fixed_args = ()
395
+ if len(self.observed_eq_params) != len(self.observed_pinn_in) or len(
396
+ self.observed_eq_params
397
+ ) != len(self.observed_values):
398
+ if len(self.observed_pinn_in) == 1:
399
+ check_ndim = partial(check_ndim, pinn_in_array=self.observed_pinn_in[0])
400
+ fixed_args = fixed_args + ("pinn_in_array",)
401
+ if len(self.observed_eq_params) == 1:
402
+ check_ndim = partial(
403
+ check_ndim, eq_params_dict=self.observed_eq_params[0]
404
+ )
405
+ fixed_args = fixed_args + ("eq_params_dict",)
406
+ if len(self.observed_values) == 1:
407
+ check_ndim = partial(check_ndim, values=self.observed_values[0])
408
+ fixed_args = fixed_args + ("values",)
119
409
 
120
- if observed_eq_params is not None:
121
- for _, v in observed_eq_params.items():
122
- if v.shape[0] != self.observed_pinn_in.shape[0]:
123
- raise ValueError(
124
- "self.observed_pinn_in and the values of"
125
- " self.observed_eq_params must have the same first axis"
126
- )
127
- for k, v in observed_eq_params.items():
128
- if len(v.shape) == 1:
129
- # Reshape to add an axis for 1-d Array
130
- observed_eq_params[k] = v[:, None]
131
- if len(v.shape) > 2:
132
- raise ValueError(
133
- f"Each key of observed_eq_params must have 2"
134
- f"dimensions, key {k} had shape {v.shape}."
135
- )
136
- # Convert the dict of observed parameters to the internal `EqParams`
137
- # class used by Jinns.
138
- self.observed_eq_params = DGObservedParams(
139
- observed_eq_params, "DGObservedParams"
410
+ jax.tree.map(
411
+ _merge_dict_arguments(check_ndim, fixed_args),
412
+ tree_map_args,
413
+ is_leaf=lambda x: (isinstance(x, tuple) and isinstance(x[0], dict)),
414
+ )
415
+ ### End check ndim
416
+
417
+ # longest_tuple will be used for correct jax tree map broadcasting. Note
+ # that even though self.observed_pinn_in, self.observed_values and
+ # self.observed_eq_params may
+ # not have the same length (as tuples), their components (jnp.arrays) do
+ # have the same first axis. This is ensured by all the previous
+ # checks
423
+ self.n = jax.tree.map(
424
+ lambda o: o.shape[0],
425
+ tuple(_ for _ in jax.tree.leaves(longest_tuple)), # jax.tree.leaves
426
+ # because if longest_tuple is eq_params then it is a dict but we do
427
+ # not want self.n to have the dict tree structure
428
+ )
429
+
430
+ if isinstance(obs_batch_size, int) or obs_batch_size is None:
431
+ self.obs_batch_size = jax.tree.map(
432
+ lambda _: obs_batch_size,
433
+ tuple(_ for _ in jax.tree.leaves(longest_tuple)), # jax tree leaves
434
+ # because if longest_tuple is eq_params then it is a dict but we do
435
+ # not want self.obs_batch_size to have the dict tree structure
140
436
  )
437
+ elif isinstance(obs_batch_size, tuple):
438
+ if len(obs_batch_size) != len_longest_tuple and len(obs_batch_size) != 1:
439
+ raise ValueError(
440
+ "If obs_batch_size is a tuple, it must me"
441
+ " of length 1 or of length equal to the"
442
+ " maximum length between values, pinn_in and"
443
+ " eq_params."
444
+ )
445
+ self.obs_batch_size = obs_batch_size
141
446
  else:
142
- self.observed_eq_params = observed_eq_params
447
+ raise ValueError("obs_batch_size must be an int, a tuple or None")
143
448
 
144
- self.n = self.observed_pinn_in.shape[0]
449
+ # After all the checks
450
+ # Convert the dict of observed parameters to the internal
451
+ # `DGObservedParams`
452
+ # class used by Jinns.
453
+ self.observed_eq_params = tuple(
454
+ DGObservedParams(o_, "DGObservedParams")
455
+ for o_ in self.observed_eq_params
456
+ if o_ is not None
457
+ )
145
458
 
146
459
  self.sharding_device = sharding_device
147
460
  if self.sharding_device is not None:
@@ -155,28 +468,39 @@ class DataGeneratorObservations(AbstractDataGenerator):
155
468
  self.observed_eq_params, self.sharding_device
156
469
  )
157
470
 
158
- if self.obs_batch_size is not None:
159
- self.curr_idx = self.n + self.obs_batch_size
160
- # to be sure there is a shuffling at first get_batch()
161
- else:
162
- self.curr_idx = 0
471
+ # When a self.obs_batch_size leaf is None, the corresponding self.curr_idx
+ # leaf is set to None. (The previous behaviour set an unused self.curr_idx to 0)
473
+ self.curr_idx = jax.tree.map(
474
+ lambda bs, n: bs + n if bs is not None else None,
475
+ self.obs_batch_size,
476
+ self.n,
477
+ is_leaf=lambda x: x is None,
478
+ )
163
479
  # For speed and to avoid duplicating data what is really
164
480
  # shuffled is a vector of indices
481
+ self.indices = jax.tree.map(jnp.arange, self.n)
165
482
  if self.sharding_device is not None:
166
483
  self.indices = jax.lax.with_sharding_constraint(
167
- jnp.arange(self.n), self.sharding_device
484
+ self.indices, self.sharding_device
168
485
  )
169
- else:
170
- self.indices = jnp.arange(self.n)
171
486
 
172
- # recall post_init is the only place with _init_ where we can set
173
- # self attribute in a in-place way
174
- self.key, _ = jax.random.split(self.key, 2) # to make it equivalent to
175
- # the call to _reset_batch_idx_and_permute in legacy DG
487
+ if not isinstance(self.key, tuple):
488
+ # recall post_init is the only place with _init_ where we can set
489
+ # self attribute in a in-place way
490
+ self.key = jax.tree.unflatten(
491
+ jax.tree.structure(self.n),
492
+ jax.random.split(self.key, len(jax.tree.leaves(self.n))),
493
+ )
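A small sketch of the key handling just above: the single PRNG key is split into one key per dataset and rearranged into the same tuple structure as `self.n` (the sizes below are made up):
```python
import jax

n = (100, 250, 80)  # hypothetical per-dataset numbers of observations
key = jax.random.PRNGKey(0)
keys = jax.tree.unflatten(
    jax.tree.structure(n),
    list(jax.random.split(key, len(jax.tree.leaves(n)))),
)
# keys is a tuple of three PRNG keys, mirroring the structure of n
```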
176
494
 
177
495
  def _get_operands(
178
496
  self,
179
- ) -> tuple[PRNGKeyArray, Int[Array, " n"], int, int | None, None]:
497
+ ) -> tuple[
498
+ tuple[PRNGKeyArray, ...],
499
+ tuple[Int[Array, " n"], ...],
500
+ tuple[int, ...],
501
+ tuple[int | None, ...],
502
+ None,
503
+ ]:
180
504
  return (
181
505
  self.key,
182
506
  self.indices,
@@ -185,59 +509,143 @@ class DataGeneratorObservations(AbstractDataGenerator):
185
509
  None,
186
510
  )
187
511
 
512
+ @staticmethod
188
513
  def obs_batch(
189
- self,
190
- ) -> tuple[Self, ObsBatchDict]:
514
+ *,
515
+ n,
516
+ obs_batch_size,
517
+ observed_pinn_in,
518
+ observed_values,
519
+ observed_eq_params,
520
+ curr_idx,
521
+ key,
522
+ indices,
523
+ ) -> tuple[PRNGKeyArray, Array, Int, ObsBatchDict]:
191
524
  """
192
525
  Return an update DataGeneratorObservations instance and an ObsBatchDict
193
526
  """
194
- if self.obs_batch_size is None or self.obs_batch_size == self.n:
527
+ if obs_batch_size is None or obs_batch_size == n:
195
528
  # Avoid unnecessary reshuffling
196
- return self, ObsBatchDict(
197
- {
198
- "pinn_in": self.observed_pinn_in,
199
- "val": self.observed_values,
200
- "eq_params": self.observed_eq_params,
201
- }
529
+ return (
530
+ key,
531
+ indices,
532
+ curr_idx,
533
+ ObsBatchDict(
534
+ {
535
+ "pinn_in": observed_pinn_in,
536
+ "val": observed_values,
537
+ "eq_params": observed_eq_params,
538
+ }
539
+ ),
202
540
  )
203
541
 
204
- new_attributes = _reset_or_increment(
205
- self.curr_idx + self.obs_batch_size,
206
- self.n,
207
- self._get_operands(), # type: ignore
542
+ new_key, new_indices, new_curr_idx = _reset_or_increment(
543
+ curr_idx + obs_batch_size,
544
+ n,
545
+ (key, indices, curr_idx, obs_batch_size, None), # type: ignore
208
546
  # ignore since the case self.obs_batch_size is None has been
209
547
  # handled above
210
548
  )
211
- new = eqx.tree_at(
212
- lambda m: (m.key, m.indices, m.curr_idx), # type: ignore
213
- self,
214
- new_attributes,
215
- )
216
549
 
217
550
  minib_indices = jax.lax.dynamic_slice(
218
- new.indices,
219
- start_indices=(new.curr_idx,),
220
- slice_sizes=(new.obs_batch_size,),
551
+ new_indices,
552
+ start_indices=(new_curr_idx,),
553
+ slice_sizes=(obs_batch_size,),
221
554
  )
222
555
 
223
556
  obs_batch: ObsBatchDict = {
224
557
  "pinn_in": jnp.take(
225
- new.observed_pinn_in, minib_indices, unique_indices=True, axis=0
558
+ observed_pinn_in, minib_indices, unique_indices=True, axis=0
226
559
  ),
227
560
  "val": jnp.take(
228
- new.observed_values, minib_indices, unique_indices=True, axis=0
561
+ observed_values, minib_indices, unique_indices=True, axis=0
229
562
  ),
230
563
  "eq_params": jax.tree_util.tree_map(
231
564
  lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0), # type: ignore
232
- new.observed_eq_params,
565
+ observed_eq_params,
233
566
  ),
234
567
  }
235
- return new, obs_batch
568
+ return new_key, new_indices, new_curr_idx, obs_batch
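Since `obs_batch` is now a static, keyword-only function, it can in principle be called directly for a single dataset. The following sketch (made-up data, assuming jinns' `DataGeneratorObservations` is imported) only exercises the early-return path shown above, i.e. no minibatching:
```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
pinn_in = jnp.linspace(0.0, 1.0, 50)[:, None]
values = jnp.sin(pinn_in)

new_key, new_indices, new_curr_idx, batch = DataGeneratorObservations.obs_batch(
    n=50,
    obs_batch_size=None,        # early return: the full dataset is the batch
    observed_pinn_in=pinn_in,
    observed_values=values,
    observed_eq_params=None,
    curr_idx=None,
    key=key,
    indices=jnp.arange(50),
)
```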
236
569
 
237
570
  def get_batch(
238
571
  self,
239
- ) -> tuple[Self, ObsBatchDict]:
572
+ ) -> tuple[Self, tuple[ObsBatchDict, ...]]:
240
573
  """
241
574
  Generic method to return a batch
242
575
  """
243
- return self.obs_batch()
576
+ # the following tree map over DataGeneratorObservations.obs_batch must
+ # be handled with pre-fixed arguments when, for memory reasons,
+ # observed_pinn_in, observed_values and observed_eq_params do not
+ # have the same length. If all tuples are of size 1, this
+ # should work totally transparently
581
+ args = (
582
+ self.observed_eq_params,
583
+ self.observed_pinn_in,
584
+ self.observed_values,
585
+ self.n,
586
+ self.obs_batch_size,
587
+ self.curr_idx,
588
+ self.key,
589
+ self.indices,
590
+ )
591
+
592
+ tree_map_args = tuple(
593
+ (
594
+ {"observed_eq_params": e},
595
+ {"observed_pinn_in": p},
596
+ {"observed_values": v},
597
+ {"n": n},
598
+ {"obs_batch_size": b},
599
+ {"curr_idx": c},
600
+ {"key": k},
601
+ {"indices": i},
602
+ )
603
+ for e, p, v, n, b, c, k, i in zip_longest(*args)
604
+ )
605
+ fixed_args = ()
606
+ obs_batch_fun = DataGeneratorObservations.obs_batch
607
+ if len(set(map(len, args))) > 1: # at least 2 lengths differ
608
+ # but since values, pinn_in and eq_params are the arguments that
+ # generate all the others, it suffices to potentially fix the
+ # former
611
+ if len(self.observed_pinn_in) == 1:
612
+ obs_batch_fun = partial(
613
+ obs_batch_fun, observed_pinn_in=self.observed_pinn_in[0]
614
+ )
615
+ fixed_args = fixed_args + ("observed_pinn_in",)
616
+ if len(self.observed_eq_params) == 1:
617
+ obs_batch_fun = partial(
618
+ obs_batch_fun, observed_eq_params=self.observed_eq_params[0]
619
+ )
620
+ fixed_args = fixed_args + ("observed_eq_params",)
621
+ if len(self.observed_values) == 1:
622
+ obs_batch_fun = partial(
623
+ obs_batch_fun, observed_values=self.observed_values[0]
624
+ )
625
+ fixed_args = fixed_args + ("observed_values",)
626
+
627
+ ret = jax.tree.map(
628
+ _merge_dict_arguments(obs_batch_fun, fixed_args),
629
+ tree_map_args,
630
+ is_leaf=lambda x: (isinstance(x, tuple) and isinstance(x[0], dict)),
631
+ )
632
+ new_key = jax.tree.map(
633
+ lambda l: l[0], ret, is_leaf=lambda x: isinstance(x, tuple) and len(x) == 4
634
+ ) # we must not traverse the second level
635
+ new_indices = jax.tree.map(
636
+ lambda l: l[1], ret, is_leaf=lambda x: isinstance(x, tuple) and len(x) == 4
637
+ )
638
+ new_curr_idx = jax.tree.map(
639
+ lambda l: l[2], ret, is_leaf=lambda x: isinstance(x, tuple) and len(x) == 4
640
+ )
641
+ obs_batch_tuple = jax.tree.map(
642
+ lambda l: l[3], ret, is_leaf=lambda x: isinstance(x, tuple) and len(x) == 4
643
+ )
644
+
645
+ new = eqx.tree_at(
646
+ lambda m: (m.key, m.indices, m.curr_idx),
647
+ self,
648
+ (new_key, new_indices, new_curr_idx),
649
+ )
650
+
651
+ return new, obs_batch_tuple
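For clarity, a toy illustration (placeholder values, not jinns data) of the regrouping performed by the four jax.tree.map calls above: `ret` holds one (key, indices, curr_idx, obs_batch) 4-tuple per dataset and is transposed into per-field tuples:
```python
import jax

# placeholder stand-in for `ret`: one 4-tuple per dataset
ret = ((0, 10, 20, {"pinn_in": 1}), (1, 11, 21, {"pinn_in": 2}))

is_leaf = lambda x: isinstance(x, tuple) and len(x) == 4
new_key = jax.tree.map(lambda l: l[0], ret, is_leaf=is_leaf)          # (0, 1)
obs_batch_tuple = jax.tree.map(lambda l: l[3], ret, is_leaf=is_leaf)  # ({"pinn_in": 1}, {"pinn_in": 2})
```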