returnn-1.20251027.232712-py3-none-any.whl → returnn-1.20260119.15400-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. returnn/PKG-INFO +2 -2
  2. returnn/__old_mod_loader__.py +26 -2
  3. returnn/_setup_info_generated.py +2 -2
  4. returnn/datasets/lm.py +130 -42
  5. returnn/datasets/meta.py +93 -43
  6. returnn/datasets/postprocessing.py +597 -108
  7. returnn/datasets/util/vocabulary.py +90 -0
  8. returnn/frontend/__init__.py +1 -0
  9. returnn/frontend/_backend.py +41 -0
  10. returnn/frontend/_native/__init__.py +22 -0
  11. returnn/frontend/_numpy_backend.py +7 -0
  12. returnn/frontend/_utils.py +1 -1
  13. returnn/frontend/array_.py +48 -2
  14. returnn/frontend/assert_.py +35 -0
  15. returnn/frontend/attention.py +54 -20
  16. returnn/frontend/conv.py +273 -54
  17. returnn/frontend/device.py +14 -1
  18. returnn/frontend/encoder/conformer.py +20 -0
  19. returnn/frontend/encoder/transformer.py +2 -0
  20. returnn/frontend/loss.py +222 -3
  21. returnn/frontend/math_.py +54 -14
  22. returnn/native_op.cpp +182 -172
  23. returnn/native_op.py +36 -31
  24. returnn/sprint/cache.py +12 -13
  25. returnn/tensor/_dim_extra.py +7 -7
  26. returnn/tensor/_tensor_extra.py +10 -10
  27. returnn/tensor/utils.py +8 -5
  28. returnn/tf/frontend_layers/_backend.py +7 -3
  29. returnn/tf/layers/basic.py +27 -40
  30. returnn/tf/native_op.py +27 -63
  31. returnn/tf/network.py +1 -1
  32. returnn/tf/util/basic.py +22 -197
  33. returnn/torch/engine.py +157 -6
  34. returnn/torch/frontend/_backend.py +280 -29
  35. returnn/torch/frontend/bridge.py +61 -0
  36. returnn/torch/frontend/compile_helper.py +106 -0
  37. returnn/torch/util/array_.py +30 -0
  38. returnn/torch/util/assert_.py +122 -0
  39. returnn/torch/util/exception_helper.py +7 -1
  40. returnn/torch/util/native_op.py +885 -0
  41. returnn/torch/util/native_op_code_compiler.py +308 -0
  42. returnn/util/basic.py +6 -7
  43. returnn/util/better_exchook.py +4 -0
  44. returnn/util/cuda_env.py +332 -0
  45. returnn/util/debug.py +12 -2
  46. returnn/util/file_cache.py +15 -1
  47. returnn/util/fsa.py +17 -13
  48. returnn/util/native_code_compiler.py +104 -47
  49. returnn/util/task_system.py +1 -1
  50. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +2 -2
  51. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +54 -48
  52. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
  53. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
  54. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
@@ -4,20 +4,32 @@ Provides :class:`PostprocessingDataset`.
  
  from __future__ import annotations
  
+ from collections import deque
  from itertools import islice
  import numpy
  from numpy.random import RandomState
- from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar
+ import select
+ import sys
+ import threading
+ from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, TypeVar
  
+ from returnn.config import SubProcCopyGlobalConfigPreInitFunc
  from returnn.datasets.basic import DatasetSeq
  from returnn.datasets.util.strings import str_to_numpy_array
  from returnn.datasets.util.vocabulary import Vocabulary
  from returnn.tensor import Tensor, TensorDict
  from returnn.tensor.dim import Dim
- from returnn.util import basic as util
- from .basic import init_dataset
+ from returnn.util import basic as util, better_exchook
+ from returnn.util.multi_proc_non_daemonic_spawn import NonDaemonicSpawnContext
+ from .basic import Dataset, init_dataset
  from .cached2 import CachedDataset2
  
+ # noinspection PyProtectedMember
+ from multiprocessing.connection import Connection as mpConnection
+
+ _mp = NonDaemonicSpawnContext(process_pre_init_func=SubProcCopyGlobalConfigPreInitFunc())
+
+
  __all__ = ["PostprocessingDataset", "LaplaceOrdering", "Sequential"]
  
  
@@ -31,8 +43,15 @@ class PostprocessingDataset(CachedDataset2):
  SpecAugment or speed perturbation into the data loading pipeline.
  
  The integration into the data loading pipeline makes it easy to distribute the
- data processing work across multiple CPU cores using `MultiProcDataset` and in
- turn frees the GPU from data preprocessing tasks.
+ data processing work across multiple CPU cores and in turn frees the GPU from
+ data preprocessing tasks.
+
+ Multiprocessing can either be done using :class:`MultiProcDataset` or by setting
+ `num_workers > 0` on this class.
+
+ The latter only applies parallelism to the postprocessing functions themselves,
+ and does not duplicate the underlying dataset once per worker.
+ This is often fast enough and has the advantage of lower memory consumption.
  
  Example usage::
  
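As a minimal sketch of how the new worker-based parallelism could be enabled in a config: the wrapped ``HDFDataset`` entry and the ``_mapper`` function are illustrative assumptions, only ``num_workers``/``buf_size`` come from this diff::

    from returnn.tensor import TensorDict

    def _mapper(data: TensorDict, **kwargs) -> TensorDict:
        # hypothetical per-seq postprocessor; a real one would e.g. apply SpecAugment
        return data

    train = {
        "class": "PostprocessingDataset",
        "dataset": {"class": "HDFDataset", "files": ["train.hdf"]},  # assumed wrapped dataset
        "map_seq": _mapper,
        "num_workers": 4,  # new in this version: parallel postprocessing workers
        "buf_size": 8,  # new in this version: per-worker prefetch buffer
    }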
@@ -61,8 +80,8 @@ class PostprocessingDataset(CachedDataset2):
  The postprocessor functions operate on ``TensorDict``s, which have entries for
  all data keys in the underlying dataset.
  
- There may also be additional "meta" entries in the tensor dicts, like ``complete_frac``
- and ``seq_tag``.
+ There may also be additional "meta" entries in the tensor dicts, like ``complete_frac``,
+ ``seq_idx`` and ``seq_tag``.
  These should be copied over in a manner that is reasonable for the use case at hand and
  ensures forwards compatibility as far as reasonably possible.
  
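A hedged sketch of a per-seq mapper that honors this contract; the feature key ``"data"`` and the float dtype of its features are assumptions, the meta keys match the docstring above::

    from returnn.tensor import TensorDict

    def map_seq(data: TensorDict, **kwargs) -> TensorDict:
        # modify the payload in place; the meta entries stay in the dict untouched
        feat = data.data["data"]  # assumed feature key with float features
        feat.raw_tensor = feat.raw_tensor * 0.9  # stand-in for a real transformation
        # if a fresh TensorDict were constructed instead, the meta entries should
        # be carried over explicitly, e.g.:
        #   for key in ("complete_frac", "seq_idx", "seq_tag"):
        #       if key in data.data:
        #           new_dict.data[key] = data.data[key]
        return data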
@@ -93,11 +112,14 @@ class PostprocessingDataset(CachedDataset2):
  
  def __init__(
  self,
+ *,
  dataset: Dict[str, Any],
  map_seq: Optional[Callable] = None,
  map_seq_stream: Optional[Callable] = None,
  map_outputs: Optional[Dict[str, Any]] = None,
  map_seq_stream_preserves_num_seqs: Optional[bool] = None,
+ buf_size: int = 1,
+ num_workers: int = 0,
  **kwargs,
  ):
  """
@@ -123,6 +145,11 @@ class PostprocessingDataset(CachedDataset2):
  Example: `map_outputs={"data": {"dim": 42}}`
  :param map_seq_stream_preserves_num_seqs: whether the function in map_seq_stream preserves the number of
  sequences, i.e. for every input sequence there is exactly one output sequence.
+ :param buf_size: Buffer size for each worker, number of seqs to prefetch. Must be > 0.
+ :param num_workers: If > 0, configures the number of worker processes to use for data postprocessing.
+ Only the postprocessing is distributed across subprocesses,
+ the underlying dataset is only instantiated once.
+ This usually has lower memory consumption than using :class:`MultiProcDataset`.
  :param kwargs: see :class:`CachedDataset2`, :class:`Dataset`
  """
  super().__init__(**kwargs)
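For the streaming variant, a minimal sketch; the names are illustrative, and ``epoch``/``rng`` passed by the dataset are absorbed via ``**kwargs``. Since this example drops sequences, ``map_seq_stream_preserves_num_seqs`` would have to be left unset or ``False``::

    from typing import Iterator
    from returnn.tensor import TensorDict

    def map_seq_stream(it: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict]:
        # hypothetical across-seq mapper: keep only every second seq
        for i, t_dict in enumerate(it):
            if i % 2 == 0:
                yield t_dict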
@@ -136,6 +163,11 @@ class PostprocessingDataset(CachedDataset2):
  if map_seq and map_seq_stream_preserves_num_seqs is not None:
  raise ValueError(f"{self}: map_seq_stream_preserves_num_seqs is only allowed with map_seq_stream")
  
+ if buf_size < 1:
+ raise ValueError(f"{self}: buf_size must be > 0, but got {buf_size}")
+ if num_workers < 0:
+ raise ValueError(f"{self}: num_workers must be >= 0, but got {num_workers}")
+
  self._dataset_def = dataset
  self._map_seq = map_seq
  self._map_seq_stream = map_seq_stream
@@ -144,7 +176,6 @@ class PostprocessingDataset(CachedDataset2):
  assert map_seq_stream_preserves_num_seqs is None or isinstance(map_seq_stream_preserves_num_seqs, bool)
  self._map_seq_stream_preserves_num_seqs = map_seq_stream_preserves_num_seqs
  self._map_outputs = map_outputs
- self._rng = RandomState(self._get_random_seed_for_epoch(0))
  self._seq_list_for_validation: Optional[List[str]] = None
  
  self._dataset = init_dataset(self._dataset_def, parent_dataset=self)
@@ -154,6 +185,14 @@ class PostprocessingDataset(CachedDataset2):
  self._data_iter: Optional[Iterator[Tuple[int, TensorDict]]] = None
  self._data_iter_produced_num_seqs = 0
  
+ self._buf_size = buf_size
+ # Ensure only one feeder thread at a time accesses the wrapped dataset to
+ # prevent race conditions while moving from one epoch to the next.
+ self._dataset_lock = threading.Lock()
+ self._multi_proc_data_iter: Optional[_MultiProcDataIter] = None # store for cleanup
+ self._num_workers = num_workers
+ self._worker_procs: Optional[List[_WorkerProcParent]] = None
+
  self._in_tensor_dict_template = TensorDict(
  {name: self._make_tensor_template_from_input(name) for name in self._dataset.get_data_keys()}
  )
@@ -166,7 +205,11 @@ class PostprocessingDataset(CachedDataset2):
  self.labels = self._dataset.labels.copy()
  # update only after _out_tensor_dict_template has been created from _in_tensor_dict_template
  self._in_tensor_dict_template.update(
- {"complete_frac": {"dims": (), "dtype": "float32"}, "seq_tag": {"dims": (), "dtype": "string"}},
+ {
+ "complete_frac": {"dims": (), "dtype": "float32"},
+ "seq_idx": {"dims": (), "dtype": "int32"},
+ "seq_tag": {"dims": (), "dtype": "string"},
+ },
  auto_convert=True,
  )
  self.num_outputs = {
@@ -201,14 +244,41 @@ class PostprocessingDataset(CachedDataset2):
  if seq_order is not None:
  raise ValueError("map_seq_stream is set, cannot specify custom seq_order")
  
+ if self._multi_proc_data_iter is not None:
+ self._multi_proc_data_iter.stop()
+ self._multi_proc_data_iter = None
+
  if epoch is None and seq_list is None and seq_order is None:
  self._num_seqs = 0
  return True
  
- self._rng = RandomState(self._get_random_seed_for_epoch(epoch=epoch))
- assert self._dataset is not None
- self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
- self._data_iter = enumerate(self._build_mapping_iter())
+ if self._num_workers > 0:
+ self._lazy_init_worker_procs()
+ assert self._worker_procs is not None and len(self._worker_procs) == self._num_workers
+ parent_conns, child_conns = zip(*[_mp.Pipe() for _ in range(self._num_workers)])
+ base_rng_seed = self._get_random_seed_for_epoch(epoch=epoch) * 683859 * self._num_workers
+ for i, (worker, child_conn) in enumerate(zip(self._worker_procs, child_conns)):
+ worker.init_seq_order(
+ epoch=epoch,
+ rng_seed=(base_rng_seed + 30411 * i) % (2**32 - 1),
+ seq_list=seq_list,
+ seq_pipe=child_conn,
+ )
+ data_iter = self._multi_proc_data_iter = self._init_multi_proc_data_iter(
+ epoch=epoch, feeder_to_worker_conns=parent_conns, seq_list=seq_list, seq_order=seq_order
+ )
+ else:
+ self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
+ data_iter = _build_mapping_iter(
+ _iterate_dataset(self._dataset, in_tensor_dict_template=self._in_tensor_dict_template),
+ map_seq=self._map_seq,
+ map_seq_stream=self._map_seq_stream,
+ epoch=epoch,
+ out_tensor_dict_template=self._out_tensor_dict_template,
+ rng=RandomState(self._get_random_seed_for_epoch(epoch=epoch)),
+ seq_list_for_validation=seq_list,
+ )
+ self._data_iter = enumerate(data_iter)
  self._data_iter_produced_num_seqs = 0
  self._seq_list_for_validation = seq_list
  if self._map_seq_stream is None or self._map_seq_stream_preserves_num_seqs is True:
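The per-worker RNG seeds are derived from the epoch seed as in the hunk above; a standalone sketch of that arithmetic, with a made-up epoch seed value::

    num_workers = 4
    epoch_seed = 12345  # stand-in for self._get_random_seed_for_epoch(epoch=epoch)
    base_rng_seed = epoch_seed * 683859 * num_workers
    # each worker gets a distinct seed, kept within numpy's 32-bit RandomState range
    worker_seeds = [(base_rng_seed + 30411 * i) % (2**32 - 1) for i in range(num_workers)]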
@@ -220,6 +290,24 @@ class PostprocessingDataset(CachedDataset2):
  pass # some datasets don't know their num_seqs
  return True
  
+ def __del__(self):
+ if self._multi_proc_data_iter is not None:
+ self._multi_proc_data_iter.stop(join=True)
+ self._multi_proc_data_iter = None
+ if not self._worker_procs:
+ return
+ got_exception = False
+ for parent in self._worker_procs:
+ # noinspection PyBroadException
+ try:
+ parent.exit(join=False)
+ except Exception:
+ got_exception = True
+ if got_exception:
+ return
+ for parent in self._worker_procs:
+ util.try_run(parent.worker_proc.join)
+
  def get_current_seq_order(self):
  """:return: current seq order of wrapped dataset, if map_seq_stream is not used"""
  if self._map_seq_stream is not None:
@@ -256,6 +344,19 @@ class PostprocessingDataset(CachedDataset2):
  assert self._dataset is not None
  return self._dataset.supports_sharding()
  
+ def finish_epoch(self, *, free_resources=False):
+ """finish_epoch"""
+ super().finish_epoch(free_resources=free_resources)
+ if not free_resources:
+ return
+ if self._multi_proc_data_iter is not None:
+ self._multi_proc_data_iter.stop(join=True)
+ self._multi_proc_data_iter = None
+ if self._worker_procs is not None:
+ for wp in self._worker_procs:
+ wp.exit(join=True)
+ self._worker_procs = None
+
  def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]:
  while True:
  try:
@@ -286,101 +387,6 @@ class PostprocessingDataset(CachedDataset2):
  seq = DatasetSeq(complete_frac=complete_frac, features=features, seq_idx=seq_idx, seq_tag=seq_tag)
  return seq
  
- def _build_mapping_iter(self) -> Iterator[TensorDict]:
- """
- :return: an iterator applying both the segment level and across-segment transformations on the given dataset
- """
-
- def _validate_tensor_dict_iter(inner: Iterator[TensorDict]) -> Iterator[TensorDict]:
- last_complete_frac = 0.0
- for t_dict in inner:
- assert isinstance(t_dict, TensorDict), (
- f"postprocessing mapper function must produce a {TensorDict.__name__}, "
- f"but got a {type(t_dict).__name__}"
- )
- if "complete_frac" in t_dict.data: # sanity check complete_frac
- complete_frac = float(t_dict.data["complete_frac"].raw_tensor)
- assert 0.0 <= complete_frac <= 1.0, f"complete_frac must be in [0, 1], but got {complete_frac}"
- assert complete_frac >= last_complete_frac, (
- "complete_frac must be monotonically increasing, "
- f"but got {complete_frac} after {last_complete_frac}"
- )
- last_complete_frac = complete_frac
- for data_key, out_t in self._out_tensor_dict_template.data.items():
- in_t = t_dict.data[data_key]
- assert in_t.ndim == out_t.batch_ndim, (
- f"Dim number mismatch for {data_key}: {in_t.ndim} != {out_t.batch_ndim}. "
- "Postprocessing data tensors must not have a batch dimension."
- )
- assert in_t.dtype == out_t.dtype, (
- f"dtype mismatch for {data_key}: '{in_t.dtype}' != '{out_t.dtype}'"
- )
- for i, (in_dim, out_shape) in enumerate(zip(in_t.dims, out_t.shape)):
- assert in_dim.dimension is None or in_dim.dimension == out_shape, (
- f"Dim {i} mismatch on {data_key}: "
- f"{in_dim.dimension} must either be `None` or equal {out_shape}"
- )
- yield t_dict
-
- data_iter = self._iterate_dataset()
- if self._map_seq_stream is not None:
- data_iter = self._map_seq_stream(data_iter, epoch=self.epoch, rng=self._rng, **util.get_fwd_compat_kwargs())
- assert isinstance(data_iter, Iterator), (
- f"map_seq_stream must produce an {Iterator.__name__}, but produced {type(data_iter).__name__}"
- )
- return _validate_tensor_dict_iter(data_iter)
-
- def _iterate_dataset(self) -> Iterator[TensorDict]:
- """
- :return: generator providing data samples in the form of a TensorDict
- """
- data_keys = self._dataset.get_data_keys()
-
- seq_index = 0
- while self._dataset.is_less_than_num_seqs(seq_index):
- self._dataset.load_seqs(seq_index, seq_index + 1)
-
- tensor_dict = self._in_tensor_dict_template.copy_template()
- for data_key in data_keys:
- tensor_dict.data[data_key].raw_tensor = self._dataset.get_data(seq_index, data_key)
-
- complete_frac = self._dataset.get_complete_frac(seq_index, allow_only_lr_suitable=True)
- comp_frac_raw_tensor = None
- if complete_frac is not None:
- comp_frac_raw_tensor = numpy.array(complete_frac, dtype=numpy.float32)
- tensor_dict.data["complete_frac"].raw_tensor = comp_frac_raw_tensor
- seq_tag_raw_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index))
- tensor_dict.data["seq_tag"].raw_tensor = seq_tag_raw_tensor
-
- if self._map_seq is not None:
- tensor_dict = self._map_seq(
- tensor_dict, epoch=self.epoch, seq_idx=seq_index, rng=self._rng, **util.get_fwd_compat_kwargs()
- )
- assert isinstance(tensor_dict, TensorDict), (
- f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}"
- )
-
- # Re-adding the seq_tag/complete_frac here causes no harm in case they are dropped
- # since we don't add/drop any segments w/ the non-iterator postprocessing function.
- if "complete_frac" not in tensor_dict.data and comp_frac_raw_tensor is not None:
- tensor_dict.data["complete_frac"] = Tensor(
- "complete_frac", dims=(), dtype="float32", raw_tensor=comp_frac_raw_tensor
- )
- if "seq_tag" not in tensor_dict.data:
- tensor_dict.data["seq_tag"] = Tensor(
- "seq_tag", dims=(), dtype="string", raw_tensor=seq_tag_raw_tensor
- )
-
- if self._seq_list_for_validation is not None:
- seq_tag = self._seq_list_for_validation[seq_index]
- tag_of_seq = tensor_dict.data["seq_tag"].raw_tensor.item()
- assert tag_of_seq == seq_tag, (
- f"seq tag mismath: {tag_of_seq} != {seq_tag} for seq index {seq_index} when seq list is given"
- )
-
- yield tensor_dict
- seq_index += 1
-
  def _make_tensor_template_from_input(self, data_key: str) -> Tensor:
  dtype = self._dataset.get_data_dtype(data_key)
  if dtype == "string":
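The hunk below implements a simple request/response protocol over multiprocessing pipes: every ``("get_seq", ...)`` request is answered by a ``("seq", item)`` message, and ``None`` signals exhaustion. A minimal standalone sketch of that pattern, independent of RETURNN and with plain strings as payloads::

    import multiprocessing as mp

    def _worker(conn):
        # answer each "get_seq" request with a "seq" message; None means exhausted
        for item in ("a", "b", None):
            msg, _ = conn.recv()
            assert msg == "get_seq"
            conn.send(("seq", item))
        conn.close()

    if __name__ == "__main__":
        parent_conn, child_conn = mp.Pipe()
        proc = mp.Process(target=_worker, args=(child_conn,), daemon=True)
        proc.start()
        child_conn.close()  # as in the diff: close the child end in the parent
        while True:
            parent_conn.send(("get_seq", {}))
            _, seq = parent_conn.recv()
            if seq is None:
                break
            print(seq)  # prints "a", then "b"
        proc.join()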
@@ -399,6 +405,489 @@ class PostprocessingDataset(CachedDataset2):
  sparse_dim.vocab = Vocabulary.create_vocab_from_labels(self._dataset.labels[data_key])
  return Tensor(data_key, dims=dims, dtype=dtype, sparse_dim=sparse_dim)
  
+ def _lazy_init_worker_procs(self):
+ if self._worker_procs is not None:
+ return
+ self._worker_procs = [
+ _WorkerProcParent(
+ name=f"{self.__class__.__name__} {self.name} worker",
+ buffer_size=self._buf_size,
+ index=i,
+ map_seq=self._map_seq,
+ map_seq_stream=self._map_seq_stream,
+ out_tensor_dict_template=self._out_tensor_dict_template,
+ )
+ for i in range(self._num_workers)
+ ]
+
+ def _init_multi_proc_data_iter(
+ self,
+ *,
+ epoch: int,
+ feeder_to_worker_conns: Sequence[mpConnection],
+ seq_list: Optional[List[str]] = None,
+ seq_order: Optional[List[int]] = None,
+ ) -> _MultiProcDataIter:
+ assert len(feeder_to_worker_conns) == self._num_workers
+
+ quit_event = threading.Event()
+ dataset_thread = threading.Thread(
+ target=self._init_seq_order_and_distribute_seqs_to_children,
+ kwargs={
+ "epoch": epoch,
+ "quit_event": quit_event,
+ "seq_list": seq_list,
+ "seq_order": seq_order,
+ "worker_conns": feeder_to_worker_conns,
+ },
+ name=f"{self.__class__.__name__} feeder ep {epoch}",
+ )
+ # parent_conns are not closed here, because they move to a different thread, not process,
+ # and so they must remain open.
+ dataset_thread.start()
+ data_iter = _MultiProcDataIter(
+ dataset_thread=dataset_thread, quit_event=quit_event, worker_procs=self._worker_procs
+ )
+ return data_iter
+
+ def _init_seq_order_and_distribute_seqs_to_children(
+ self,
+ *,
+ epoch: int,
+ quit_event: threading.Event,
+ seq_list: Optional[List[str]] = None,
+ seq_order: Optional[List[int]] = None,
+ worker_conns: Sequence[mpConnection],
+ ):
+ """
+ Initialize the wrapped dataset and distribute the contained sequences to the child worker processes.
+ """
+
+ assert self._buf_size > 0
+ assert len(worker_conns) > 0
+ assert self._num_workers > 0
+
+ caches: List[deque[TensorDict]] = [deque() for _ in range(len(worker_conns))]
+
+ def _any_conn_ready() -> bool:
+ ready, _, _ = select.select(worker_conns, [], [], 0)
+ return len(ready) > 0
+
+ def _maybe_distrib_seq(*, timeout=0.1):
+ assert timeout >= 0.0
+ # do not block indefinitely to periodically check the quit_event
+ ready_conns, _, _ = select.select(worker_conns, [], [], timeout)
+ assert len(worker_conns) == len(caches)
+ for child_queue, cache in zip(worker_conns, caches):
+ if child_queue not in ready_conns:
+ continue
+ msg, _ = child_queue.recv()
+ assert msg == "get_seq"
+ tensor_dict = cache.popleft() if len(cache) > 0 else None
+ child_queue.send(("seq", tensor_dict))
+
+ # Lock ensures that only one thread at a time accesses the wrapped dataset.
+ # This protects against issues while moving from one epoch to the next.
+ with self._dataset_lock:
+ self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
+ data_iter = _iterate_dataset(self._dataset, in_tensor_dict_template=self._in_tensor_dict_template)
+ data_iter = enumerate(data_iter)
+
+ def _add_to_cache() -> bool:
+ try:
+ idx, tensor_dict = next(data_iter)
+ caches[idx % len(caches)].append(tensor_dict)
+ return True
+ except StopIteration:
+ return False
+
+ while not quit_event.is_set():
+ # fetch seqs until all caches have at least one seq,
+ # if no child is waiting for seqs also fill until buf_size
+ while any(len(cache) == 0 for cache in caches) or (
+ sum(len(cache) for cache in caches) < self._buf_size and not _any_conn_ready()
+ ):
+ if not _add_to_cache():
+ break
+ if all(len(c) == 0 for c in caches):
+ break
+ try:
+ _maybe_distrib_seq()
+ except (BrokenPipeError, EOFError):
+ # queue is closed, i.e. the worker process crashed for some reason -> stop
+ break
+
+ for queue in worker_conns:
+ try:
+ queue.send(("seq", None))
+ except (BrokenPipeError, EOFError):
+ # queue is already closed, i.e. the worker process died
+ pass
+ finally:
+ queue.close()
+
+
+ def _iterate_dataset(dataset: Dataset, *, in_tensor_dict_template: TensorDict) -> Iterator[TensorDict]:
+ """
+ :return: generator providing data samples in the form of a TensorDict
+ """
+ data_keys = dataset.get_data_keys()
+
+ seq_index = 0
+ while dataset.is_less_than_num_seqs(seq_index):
+ dataset.load_seqs(seq_index, seq_index + 1)
+
+ tensor_dict = in_tensor_dict_template.copy_template()
+ for data_key in data_keys:
+ tensor_dict.data[data_key].raw_tensor = dataset.get_data(seq_index, data_key)
+
+ complete_frac = dataset.get_complete_frac(seq_index, allow_only_lr_suitable=True)
+ if complete_frac is not None:
+ comp_frac_raw_tensor = numpy.array(complete_frac, dtype=numpy.float32)
+ tensor_dict.data["complete_frac"].raw_tensor = comp_frac_raw_tensor
+ seq_idx_raw_tensor = numpy.array(seq_index, dtype=numpy.int32)
+ tensor_dict.data["seq_idx"].raw_tensor = seq_idx_raw_tensor
+ seq_tag_raw_tensor = str_to_numpy_array(dataset.get_tag(seq_index))
+ tensor_dict.data["seq_tag"].raw_tensor = seq_tag_raw_tensor
+
+ yield tensor_dict
+ seq_index += 1
+
+
+ def _build_mapping_iter(
+ data_iter: Iterator[TensorDict],
+ *,
+ map_seq: Optional[Callable] = None,
+ map_seq_stream: Optional[Callable] = None,
+ epoch: int,
+ out_tensor_dict_template: TensorDict,
+ rng: RandomState,
+ seq_list_for_validation: Optional[List[str]] = None,
+ ) -> Iterator[TensorDict]:
+ """
+ Build an iterator applying the mapping functions on the given dataset iterator.
+
+ :param data_iter: iterator providing data samples in the form of a TensorDict
+ :param map_seq: see :class:`PostprocessingDataset`
+ :param map_seq_stream: see :class:`PostprocessingDataset`
+ :param epoch: current epoch number
+ :param out_tensor_dict_template: template for the output TensorDicts, used for validation
+ :param rng: random number generator to use
+ :param seq_list_for_validation: optional list of seq tags to validate against when processing the data
+ :return: an iterator applying both the segment level and across-segment transformations on the given dataset
+ """
+
+ def _validate_tensor_dict_iter(inner: Iterator[TensorDict]) -> Iterator[TensorDict]:
+ last_complete_frac = 0.0
+ for t_dict in inner:
+ assert isinstance(t_dict, TensorDict), (
+ f"postprocessing mapper function must produce a {TensorDict.__name__}, "
+ f"but got a {type(t_dict).__name__}"
+ )
+ if "complete_frac" in t_dict.data: # sanity check complete_frac
+ complete_frac = float(t_dict.data["complete_frac"].raw_tensor)
+ assert 0.0 <= complete_frac <= 1.0, f"complete_frac must be in [0, 1], but got {complete_frac}"
+ assert complete_frac >= last_complete_frac, (
+ "complete_frac must be monotonically increasing, "
+ f"but got {complete_frac} after {last_complete_frac}"
+ )
+ last_complete_frac = complete_frac
+ for data_key, out_t in out_tensor_dict_template.data.items():
+ in_t = t_dict.data[data_key]
+ assert in_t.ndim == out_t.batch_ndim, (
+ f"Dim number mismatch for {data_key}: {in_t.ndim} != {out_t.batch_ndim}. "
+ "Postprocessing data tensors must not have a batch dimension."
+ )
+ assert in_t.dtype == out_t.dtype, f"dtype mismatch for {data_key}: '{in_t.dtype}' != '{out_t.dtype}'"
+ for i, (in_dim, out_shape) in enumerate(zip(in_t.dims, out_t.shape)):
+ assert in_dim.dimension is None or in_dim.dimension == out_shape, (
+ f"Dim {i} mismatch on {data_key}: {in_dim.dimension} must either be `None` or equal {out_shape}"
+ )
+ yield t_dict
+
+ def _apply_map_seq(tensor_dict: TensorDict) -> TensorDict:
+ comp_frac_raw_tensor = (
+ tensor_dict.data["complete_frac"].raw_tensor if "complete_frac" in tensor_dict.data else None
+ )
+ seq_index_raw = tensor_dict.data["seq_idx"].raw_tensor
+ seq_index = int(seq_index_raw.item())
+ seq_tag_raw_tensor = tensor_dict.data["seq_tag"].raw_tensor
+
+ tensor_dict = map_seq(tensor_dict, epoch=epoch, seq_idx=seq_index, rng=rng, **util.get_fwd_compat_kwargs())
+ assert isinstance(tensor_dict, TensorDict), (
+ f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}"
+ )
+
+ # Re-adding the complete_frac/seq_idx/seq_tag here causes no harm in case they are dropped
+ # since we don't add/drop any segments w/ the non-iterator postprocessing function.
+ if "complete_frac" not in tensor_dict.data and comp_frac_raw_tensor is not None:
+ tensor_dict.data["complete_frac"] = Tensor(
+ "complete_frac", dims=(), dtype="float32", raw_tensor=comp_frac_raw_tensor
+ )
+ if "seq_idx" not in tensor_dict.data:
+ tensor_dict.data["seq_idx"] = Tensor("seq_idx", dims=(), dtype="int32", raw_tensor=seq_index_raw)
+ if "seq_tag" not in tensor_dict.data:
+ tensor_dict.data["seq_tag"] = Tensor("seq_tag", dims=(), dtype="string", raw_tensor=seq_tag_raw_tensor)
+
+ if seq_list_for_validation is not None:
+ seq_tag = seq_list_for_validation[seq_index]
+ tag_of_seq = tensor_dict.data["seq_tag"].raw_tensor.item()
+ assert tag_of_seq == seq_tag, (
+ f"seq tag mismatch: {tag_of_seq} != {seq_tag} for seq index {seq_index} when seq list is given"
+ )
+
+ return tensor_dict
+
+ assert map_seq or map_seq_stream, "need to specify either map_seq or map_seq_stream"
+ assert not (map_seq and map_seq_stream), "cannot set both map_seq and map_seq_stream"
+ if map_seq is not None:
+ data_iter = (_apply_map_seq(t_dict) for t_dict in data_iter)
+ if map_seq_stream is not None:
+ data_iter = map_seq_stream(data_iter, epoch=epoch, rng=rng, **util.get_fwd_compat_kwargs())
+ assert isinstance(data_iter, Iterator), (
+ f"map_seq_stream must produce an {Iterator.__name__}, but produced {type(data_iter).__name__}"
+ )
+ return _validate_tensor_dict_iter(data_iter)
+
+
+ class _MultiProcDataIter:
+ """
+ Data iter that pulls from the worker processes in a well-defined order and
+ manages the lifetime of the feeder thread.
+
+ Also ensures monotonicity of complete_frac, which would otherwise be no longer
+ guaranteed if there is more than one worker.
+ """
+
+ def __init__(
+ self, *, dataset_thread: threading.Thread, quit_event: threading.Event, worker_procs: List[_WorkerProcParent]
+ ):
+ self.dataset_thread = dataset_thread
+ self.quit_event = quit_event
+ assert len(worker_procs) > 0
+ self.worker_procs = worker_procs
+
+ self._complete_frac = 0.0 # need to force monotonicity
+ self._workers_exhausted = [False for _ in range(len(worker_procs))]
+ self._worker_idx = 0
+
+ def __iter__(self):
+ return self
+
+ def __next__(self) -> Optional[TensorDict]:
+ if self.quit_event.is_set():
+ raise StopIteration
+
+ while True:
+ if all(self._workers_exhausted):
+ break
+ worker_idx = self._worker_idx
+ self._worker_idx = (self._worker_idx + 1) % len(self.worker_procs)
+ if self._workers_exhausted[worker_idx]:
+ continue
+ seq = self.worker_procs[worker_idx].get_seq()
+ if seq is not None:
+ return self._ensure_complete_frac_monotonic(seq)
+ self._workers_exhausted[worker_idx] = True
+
+ # when we reach this point, all workers are exhausted and we stop
+ self.stop()
+ raise StopIteration
+
+ def stop(self, *, join=True):
+ """
+ Stop the iterator and the dataset thread.
+
+ Once this is called, the iterator cannot be used anymore.
+ """
+ if self.quit_event.is_set():
+ return
+ self.quit_event.set()
+ if join:
+ util.try_run(self.dataset_thread.join)
+
+ def _ensure_complete_frac_monotonic(self, seq: TensorDict) -> TensorDict:
+ """
+ Enforce monotonicity of `complete_frac` in the given `TensorDict`.
+ """
+ if "complete_frac" not in seq.data:
+ return seq
+ complete_frac = float(seq.data["complete_frac"].raw_tensor)
+ assert 0.0 <= complete_frac <= 1.0, f"complete_frac must be in [0, 1], but got {complete_frac}"
+ self._complete_frac = max(complete_frac, self._complete_frac)
+ seq.data["complete_frac"].raw_tensor = numpy.array(self._complete_frac, dtype=numpy.float32)
+ return seq
+
+ def __del__(self):
+ # noinspection PyBroadException
+ try:
+ self.stop(join=False)
+ except Exception:
+ pass
+
+
+ class _WorkerProcParent:
+ def __init__(
+ self,
+ *,
+ buffer_size: int,
+ index: int,
+ name: str,
+ map_seq: Optional[Callable],
+ map_seq_stream: Optional[Callable],
+ out_tensor_dict_template: TensorDict,
+ ):
+ parent_conn, child_conn = _mp.Pipe()
+ self.parent_conn = parent_conn
+
+ self.worker_proc = _mp.Process(
+ name=f"{name} worker {index}",
+ target=_worker_proc_loop,
+ args=(index, child_conn, buffer_size, map_seq, map_seq_stream, out_tensor_dict_template),
+ daemon=True,
+ )
+ self.worker_proc.start()
+
+ # Make sure the child connection is closed here.
+ # It stays open in the child, until the child dies.
+ # When that happens, now any consecutive read on the pipe
+ # should yield an exception -- which is what we want,
+ # otherwise it would just hang.
+ child_conn.close()
+
+ def init_seq_order(
+ self,
+ *,
+ epoch: int,
+ rng_seed: int,
+ seq_list: Optional[List[str]],
+ seq_pipe: mpConnection,
+ ):
+ """init_seq_order"""
+ args = {"epoch": epoch, "rng_seed": rng_seed, "seq_list": seq_list, "seq_pipe": seq_pipe}
+ self.parent_conn.send(("init_seq_order", args))
+ msg, _ = self.parent_conn.recv()
+ assert msg == "init_seq_order"
+ # seq_pipe is owned by the child process,
+ # and so must be closed in the parent to avoid hangs
+ seq_pipe.close()
+
+ def get_seq(self) -> Optional[TensorDict]:
+ """get_seq"""
+ self.parent_conn.send(("get_seq", {}))
+ msg, seq = self.parent_conn.recv()
+ assert msg == "seq"
+ return seq
+
+ def exit(self, *, join: bool = True):
+ """exit"""
+ self.parent_conn.send(("exit", {}))
+ if join:
+ self.worker_proc.join()
+
+ def __del__(self):
+ # noinspection PyBroadException
+ try:
+ self.exit(join=False)
+ except Exception:
+ pass
+ else:
+ util.try_run(self.worker_proc.join)
+
+
+ def _worker_proc_loop(
+ index: int,
+ parent_conn: mpConnection,
+ buffer_size: int,
+ map_seq: Optional[Callable],
+ map_seq_stream: Optional[Callable],
+ out_tensor_dict_template: TensorDict,
+ ):
+ if sys.platform == "linux":
+ with open("/proc/self/comm", "w") as f:
+ f.write(f"PP worker {index}")
+ better_exchook.setup_all()
+
+ assert isinstance(buffer_size, int) and buffer_size > 0
+ assert isinstance(index, int)
+ assert isinstance(parent_conn, mpConnection)
+
+ cache: deque[TensorDict] = deque()
+
+ data_iter: Optional[Iterator[TensorDict]] = None
+ feeder_conn: Optional[mpConnection] = None
+
+ def _add_to_cache():
+ nonlocal data_iter
+ if data_iter is None:
+ return False
+ try:
+ seq = next(data_iter)
+ except StopIteration:
+ data_iter = None
+ return False
+ cache.append(seq)
+ return True
+
+ def _iter_pipe(q: mpConnection) -> Iterator[TensorDict]:
+ assert isinstance(q, mpConnection)
+
+ while True:
+ try:
+ q.send(("get_seq", None))
+ seq_msg, item = q.recv()
+ except (BrokenPipeError, EOFError):
+ # queue is closed
+ break
+ assert seq_msg == "seq"
+ if item is None:
+ break
+ assert isinstance(item, TensorDict)
+ yield item
+
+ try:
+ while True:
+ while len(cache) < buffer_size and not parent_conn.poll():
+ if not _add_to_cache():
+ break
+ msg, kwargs = parent_conn.recv()
+ if msg == "exit":
+ break
+ elif msg == "get_seq":
+ if not cache:
+ _add_to_cache()
+ parent_conn.send(("seq", cache.popleft() if cache else None))
+ elif msg == "init_seq_order":
+ epoch = kwargs["epoch"]
+ if sys.platform == "linux":
+ with open("/proc/self/comm", "w") as f:
+ f.write(f"PP worker {index} ep {epoch}")
+ if feeder_conn is not None:
+ feeder_conn.close()
+ feeder_conn = kwargs["seq_pipe"]
+ data_iter = _build_mapping_iter(
+ _iter_pipe(feeder_conn),
+ epoch=epoch,
+ map_seq=map_seq,
+ map_seq_stream=map_seq_stream,
+ out_tensor_dict_template=out_tensor_dict_template,
+ rng=RandomState(kwargs["rng_seed"]),
+ seq_list_for_validation=kwargs["seq_list"],
+ )
+ assert isinstance(data_iter, Iterator)
+ cache.clear()
+ parent_conn.send(("init_seq_order", None))
+ else:
+ raise Exception(f"unknown msg {msg!r}")
+ except KeyboardInterrupt: # when parent dies
+ pass
+ except EOFError: # when parent dies
+ pass
+ finally:
+ if feeder_conn is not None:
+ feeder_conn.close()
+ parent_conn.close()
+
  
  class LaplaceOrdering(Callable[[Iterator[TensorDict]], Iterator[TensorDict]]):
  """