returnn 1.20250901.123052__py3-none-any.whl → 1.20260105.192646__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +2 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/config.py +1 -1
- returnn/datasets/basic.py +29 -13
- returnn/datasets/distrib_files.py +61 -3
- returnn/datasets/generating.py +12 -21
- returnn/datasets/huggingface.py +434 -0
- returnn/datasets/lm.py +20 -0
- returnn/datasets/meta.py +179 -60
- returnn/datasets/multi_proc.py +1 -1
- returnn/datasets/postprocessing.py +597 -108
- returnn/datasets/text_dict.py +1 -1
- returnn/datasets/util/vocabulary.py +90 -0
- returnn/frontend/_backend.py +7 -0
- returnn/frontend/array_.py +54 -1
- returnn/frontend/attention.py +54 -20
- returnn/frontend/conv.py +273 -54
- returnn/frontend/decoder/transformer.py +36 -17
- returnn/frontend/encoder/conformer.py +1 -0
- returnn/frontend/encoder/transformer.py +2 -0
- returnn/frontend/loss.py +40 -1
- returnn/frontend/module.py +8 -1
- returnn/frontend/nested.py +9 -0
- returnn/native_op.cpp +80 -0
- returnn/sprint/cache.py +12 -13
- returnn/tensor/_dim_extra.py +51 -29
- returnn/tensor/_tensor_extra.py +6 -1
- returnn/tensor/utils.py +7 -4
- returnn/tf/frontend_layers/_backend.py +11 -2
- returnn/tf/frontend_low_level/_backend.py +15 -0
- returnn/tf/layers/basic.py +16 -38
- returnn/tf/native_op.py +11 -58
- returnn/tf/network.py +1 -1
- returnn/tf/util/basic.py +19 -0
- returnn/torch/data/returnn_dataset_wrapper.py +9 -3
- returnn/torch/engine.py +67 -2
- returnn/torch/frontend/_backend.py +119 -7
- returnn/torch/util/diagnose_gpu.py +65 -31
- returnn/torch/util/exception_helper.py +7 -1
- returnn/util/basic.py +6 -7
- returnn/util/better_exchook.py +4 -0
- returnn/util/collect_outputs_dict.py +79 -0
- returnn/util/debug.py +11 -2
- returnn/util/file_cache.py +42 -4
- returnn/util/task_system.py +1 -1
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/METADATA +2 -2
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/RECORD +50 -48
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/LICENSE +0 -0
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/WHEEL +0 -0
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/top_level.txt +0 -0
|
@@ -4,20 +4,32 @@ Provides :class:`PostprocessingDataset`.
|
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
6
|
|
|
7
|
+
from collections import deque
|
|
7
8
|
from itertools import islice
|
|
8
9
|
import numpy
|
|
9
10
|
from numpy.random import RandomState
|
|
10
|
-
|
|
11
|
+
import select
|
|
12
|
+
import sys
|
|
13
|
+
import threading
|
|
14
|
+
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, TypeVar
|
|
11
15
|
|
|
16
|
+
from returnn.config import SubProcCopyGlobalConfigPreInitFunc
|
|
12
17
|
from returnn.datasets.basic import DatasetSeq
|
|
13
18
|
from returnn.datasets.util.strings import str_to_numpy_array
|
|
14
19
|
from returnn.datasets.util.vocabulary import Vocabulary
|
|
15
20
|
from returnn.tensor import Tensor, TensorDict
|
|
16
21
|
from returnn.tensor.dim import Dim
|
|
17
|
-
from returnn.util import basic as util
|
|
18
|
-
from .
|
|
22
|
+
from returnn.util import basic as util, better_exchook
|
|
23
|
+
from returnn.util.multi_proc_non_daemonic_spawn import NonDaemonicSpawnContext
|
|
24
|
+
from .basic import Dataset, init_dataset
|
|
19
25
|
from .cached2 import CachedDataset2
|
|
20
26
|
|
|
27
|
+
# noinspection PyProtectedMember
|
|
28
|
+
from multiprocessing.connection import Connection as mpConnection
|
|
29
|
+
|
|
30
|
+
_mp = NonDaemonicSpawnContext(process_pre_init_func=SubProcCopyGlobalConfigPreInitFunc())
|
|
31
|
+
|
|
32
|
+
|
|
21
33
|
__all__ = ["PostprocessingDataset", "LaplaceOrdering", "Sequential"]
|
|
22
34
|
|
|
23
35
|
|
|
@@ -31,8 +43,15 @@ class PostprocessingDataset(CachedDataset2):
|
|
|
31
43
|
SpecAugment or speed perturbation into the data loading pipeline.
|
|
32
44
|
|
|
33
45
|
The integration into the data loading pipeline makes it easy to distribute the
|
|
34
|
-
data processing work across multiple CPU cores
|
|
35
|
-
|
|
46
|
+
data processing work across multiple CPU cores and in turn frees the GPU from
|
|
47
|
+
data preprocessing tasks.
|
|
48
|
+
|
|
49
|
+
Multiprocessing can either be done using :class:`MultiProcDataset` or by setting
|
|
50
|
+
`num_workers > 0` on this class.
|
|
51
|
+
|
|
52
|
+
The latter only applies parallelism to the post-processing functions themselves,
|
|
53
|
+
and does not duplicate the underlying dataset once per worker.
|
|
54
|
+
This is often fast enough and has the advantage of lower memory consumption.
|
|
36
55
|
|
|
37
56
|
Example usage::
|
|
38
57
|
|
|
@@ -61,8 +80,8 @@ class PostprocessingDataset(CachedDataset2):
|
|
|
61
80
|
The postprocessor functions operate on ``TensorDict``s, which have entries for
|
|
62
81
|
all data keys in the underlying dataset.
|
|
63
82
|
|
|
64
|
-
There may also be additional "meta" entries in the tensor dicts, like ``complete_frac
|
|
65
|
-
and ``seq_tag``.
|
|
83
|
+
There may also be additional "meta" entries in the tensor dicts, like ``complete_frac``,
|
|
84
|
+
``seq_idx`` and ``seq_tag``.
|
|
66
85
|
These should be copied over in a manner that is reasonable for the use case at hand and
|
|
67
86
|
ensures forwards compatibility as well as reasonably possible.
|
|
68
87
|
|
|
@@ -93,11 +112,14 @@ class PostprocessingDataset(CachedDataset2):
|
|
|
93
112
|
|
|
94
113
|
def __init__(
|
|
95
114
|
self,
|
|
115
|
+
*,
|
|
96
116
|
dataset: Dict[str, Any],
|
|
97
117
|
map_seq: Optional[Callable] = None,
|
|
98
118
|
map_seq_stream: Optional[Callable] = None,
|
|
99
119
|
map_outputs: Optional[Dict[str, Any]] = None,
|
|
100
120
|
map_seq_stream_preserves_num_seqs: Optional[bool] = None,
|
|
121
|
+
buf_size: int = 1,
|
|
122
|
+
num_workers: int = 0,
|
|
101
123
|
**kwargs,
|
|
102
124
|
):
|
|
103
125
|
"""
|
|
@@ -123,6 +145,11 @@ class PostprocessingDataset(CachedDataset2):
|
|
|
123
145
|
Example: `map_outputs={"data": {"dim": 42}}`
|
|
124
146
|
:param map_seq_stream_preserves_num_seqs: whether the function in map_seq_stream preserves the number of
|
|
125
147
|
sequences, i.e. for every input sequence there is exactly one output sequence.
|
|
148
|
+
:param buf_size: Buffer size for each worker, number of seqs to prefetch. Must be > 0.
|
|
149
|
+
:param num_workers: If > 0, configures the number of worker processes to use for data postprocessing.
|
|
150
|
+
Only the postprocessing is distributed across subprocesses,
|
|
151
|
+
the underlying dataset is only instantiated once.
|
|
152
|
+
This usually has lower memory consumption than using :class:`MultiProcDataset`.
|
|
126
153
|
:param kwargs: see :class:`CachedDataset2`, :class:`Dataset`
|
|
127
154
|
"""
|
|
128
155
|
super().__init__(**kwargs)
|
|
@@ -136,6 +163,11 @@ class PostprocessingDataset(CachedDataset2):
|
|
|
136
163
|
if map_seq and map_seq_stream_preserves_num_seqs is not None:
|
|
137
164
|
raise ValueError(f"{self}: map_seq_stream_preserves_num_seqs is only allowed with map_seq_stream")
|
|
138
165
|
|
|
166
|
+
if buf_size < 1:
|
|
167
|
+
raise ValueError(f"{self}: buf_size must be > 0, but got {buf_size}")
|
|
168
|
+
if num_workers < 0:
|
|
169
|
+
raise ValueError(f"{self}: num_workers must be >= 0, but got {num_workers}")
|
|
170
|
+
|
|
139
171
|
self._dataset_def = dataset
|
|
140
172
|
self._map_seq = map_seq
|
|
141
173
|
self._map_seq_stream = map_seq_stream
|
|
@@ -144,7 +176,6 @@ class PostprocessingDataset(CachedDataset2):
|
|
|
144
176
|
assert map_seq_stream_preserves_num_seqs is None or isinstance(map_seq_stream_preserves_num_seqs, bool)
|
|
145
177
|
self._map_seq_stream_preserves_num_seqs = map_seq_stream_preserves_num_seqs
|
|
146
178
|
self._map_outputs = map_outputs
|
|
147
|
-
self._rng = RandomState(self._get_random_seed_for_epoch(0))
|
|
148
179
|
self._seq_list_for_validation: Optional[List[str]] = None
|
|
149
180
|
|
|
150
181
|
self._dataset = init_dataset(self._dataset_def, parent_dataset=self)
|
|
@@ -154,6 +185,14 @@ class PostprocessingDataset(CachedDataset2):
|
|
|
154
185
|
self._data_iter: Optional[Iterator[Tuple[int, TensorDict]]] = None
|
|
155
186
|
self._data_iter_produced_num_seqs = 0
|
|
156
187
|
|
|
188
|
+
self._buf_size = buf_size
|
|
189
|
+
# Ensure only one feeder thread at a time accesses the wrapped dataset to
|
|
190
|
+
# prevent race conditions while moving from one epoch to the next.
|
|
191
|
+
self._dataset_lock = threading.Lock()
|
|
192
|
+
self._multi_proc_data_iter: Optional[_MultiProcDataIter] = None # store for cleanup
|
|
193
|
+
self._num_workers = num_workers
|
|
194
|
+
self._worker_procs: Optional[List[_WorkerProcParent]] = None
|
|
195
|
+
|
|
157
196
|
self._in_tensor_dict_template = TensorDict(
|
|
158
197
|
{name: self._make_tensor_template_from_input(name) for name in self._dataset.get_data_keys()}
|
|
159
198
|
)
|
|
@@ -166,7 +205,11 @@ class PostprocessingDataset(CachedDataset2):
|
|
|
166
205
|
self.labels = self._dataset.labels.copy()
|
|
167
206
|
# update only after _out_tensor_dict_template has been created from _in_tensor_dict_template
|
|
168
207
|
self._in_tensor_dict_template.update(
|
|
169
|
-
{
|
|
208
|
+
{
|
|
209
|
+
"complete_frac": {"dims": (), "dtype": "float32"},
|
|
210
|
+
"seq_idx": {"dims": (), "dtype": "int32"},
|
|
211
|
+
"seq_tag": {"dims": (), "dtype": "string"},
|
|
212
|
+
},
|
|
170
213
|
auto_convert=True,
|
|
171
214
|
)
|
|
172
215
|
self.num_outputs = {
|
|
@@ -201,14 +244,41 @@ class PostprocessingDataset(CachedDataset2):
|
|
|
201
244
|
if seq_order is not None:
|
|
202
245
|
raise ValueError("map_seq_stream is set, cannot specify custom seq_order")
|
|
203
246
|
|
|
247
|
+
if self._multi_proc_data_iter is not None:
|
|
248
|
+
self._multi_proc_data_iter.stop()
|
|
249
|
+
self._multi_proc_data_iter = None
|
|
250
|
+
|
|
204
251
|
if epoch is None and seq_list is None and seq_order is None:
|
|
205
252
|
self._num_seqs = 0
|
|
206
253
|
return True
|
|
207
254
|
|
|
208
|
-
self.
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
255
|
+
if self._num_workers > 0:
|
|
256
|
+
self._lazy_init_worker_procs()
|
|
257
|
+
assert self._worker_procs is not None and len(self._worker_procs) == self._num_workers
|
|
258
|
+
parent_conns, child_conns = zip(*[_mp.Pipe() for _ in range(self._num_workers)])
|
|
259
|
+
base_rng_seed = self._get_random_seed_for_epoch(epoch=epoch) * 683859 * self._num_workers
|
|
260
|
+
for i, (worker, child_conn) in enumerate(zip(self._worker_procs, child_conns)):
|
|
261
|
+
worker.init_seq_order(
|
|
262
|
+
epoch=epoch,
|
|
263
|
+
rng_seed=(base_rng_seed + 30411 * i) % (2**32 - 1),
|
|
264
|
+
seq_list=seq_list,
|
|
265
|
+
seq_pipe=child_conn,
|
|
266
|
+
)
|
|
267
|
+
data_iter = self._multi_proc_data_iter = self._init_multi_proc_data_iter(
|
|
268
|
+
epoch=epoch, feeder_to_worker_conns=parent_conns, seq_list=seq_list, seq_order=seq_order
|
|
269
|
+
)
|
|
270
|
+
else:
|
|
271
|
+
self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
|
|
272
|
+
data_iter = _build_mapping_iter(
|
|
273
|
+
_iterate_dataset(self._dataset, in_tensor_dict_template=self._in_tensor_dict_template),
|
|
274
|
+
map_seq=self._map_seq,
|
|
275
|
+
map_seq_stream=self._map_seq_stream,
|
|
276
|
+
epoch=epoch,
|
|
277
|
+
out_tensor_dict_template=self._out_tensor_dict_template,
|
|
278
|
+
rng=RandomState(self._get_random_seed_for_epoch(epoch=epoch)),
|
|
279
|
+
seq_list_for_validation=seq_list,
|
|
280
|
+
)
|
|
281
|
+
self._data_iter = enumerate(data_iter)
|
|
212
282
|
self._data_iter_produced_num_seqs = 0
|
|
213
283
|
self._seq_list_for_validation = seq_list
|
|
214
284
|
if self._map_seq_stream is None or self._map_seq_stream_preserves_num_seqs is True:
|
|
@@ -220,6 +290,24 @@ class PostprocessingDataset(CachedDataset2):
|
|
|
220
290
|
pass # some datasets don't know their num_seqs
|
|
221
291
|
return True
|
|
222
292
|
|
|
293
|
+
def __del__(self):
    """
    Best-effort cleanup when the dataset object is garbage collected.

    Stops the feeder thread of the current epoch (if any) and shuts down the worker processes (if any).

    ``__del__`` also runs for partially constructed instances, e.g. when ``__init__`` raised a
    ``ValueError`` for an invalid ``buf_size`` / ``num_workers`` before these attributes were assigned.
    Therefore the attributes are read with ``getattr`` and a ``None`` default.
    """
    data_iter = getattr(self, "_multi_proc_data_iter", None)
    if data_iter is not None:
        data_iter.stop(join=True)
        self._multi_proc_data_iter = None
    worker_procs = getattr(self, "_worker_procs", None)
    if not worker_procs:
        return
    got_exception = False
    for parent in worker_procs:
        # noinspection PyBroadException
        try:
            parent.exit(join=False)
        except Exception:
            # keep going, so that every remaining worker still gets its exit request
            got_exception = True
    if got_exception:
        # NOTE(review): presumably we skip joining here to avoid blocking on a worker that could not be told to exit
        return
    # First request all workers to exit (without joining), then join them one by one.
    for parent in worker_procs:
        util.try_run(parent.worker_proc.join)
|
|
310
|
+
|
|
223
311
|
def get_current_seq_order(self):
|
|
224
312
|
""":return: current seq order of wrapped dataset, if map_seq_stream is not used"""
|
|
225
313
|
if self._map_seq_stream is not None:
|
|
@@ -256,6 +344,19 @@ class PostprocessingDataset(CachedDataset2):
|
|
|
256
344
|
assert self._dataset is not None
|
|
257
345
|
return self._dataset.supports_sharding()
|
|
258
346
|
|
|
347
|
+
def finish_epoch(self, *, free_resources=False):
    """
    Finish the current epoch.

    :param free_resources: if True, additionally stop the feeder thread of the current epoch (if any)
        and shut down all worker processes (joining them). When ``num_workers > 0``, the workers are
        created again on demand by the next ``init_seq_order``.
    """
    super().finish_epoch(free_resources=free_resources)
    if free_resources:
        if self._multi_proc_data_iter is not None:
            self._multi_proc_data_iter.stop(join=True)
            self._multi_proc_data_iter = None
        workers = self._worker_procs
        if workers is not None:
            for worker in workers:
                worker.exit(join=True)
            self._worker_procs = None
|
|
359
|
+
|
|
259
360
|
def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]:
|
|
260
361
|
while True:
|
|
261
362
|
try:
|
|
@@ -286,101 +387,6 @@ class PostprocessingDataset(CachedDataset2):
|
|
|
286
387
|
seq = DatasetSeq(complete_frac=complete_frac, features=features, seq_idx=seq_idx, seq_tag=seq_tag)
|
|
287
388
|
return seq
|
|
288
389
|
|
|
289
|
-
def _build_mapping_iter(self) -> Iterator[TensorDict]:
|
|
290
|
-
"""
|
|
291
|
-
:return: an iterator applying both the segment level and across-segment transformations on the given dataset
|
|
292
|
-
"""
|
|
293
|
-
|
|
294
|
-
def _validate_tensor_dict_iter(inner: Iterator[TensorDict]) -> Iterator[TensorDict]:
|
|
295
|
-
last_complete_frac = 0.0
|
|
296
|
-
for t_dict in inner:
|
|
297
|
-
assert isinstance(t_dict, TensorDict), (
|
|
298
|
-
f"postprocessing mapper function must produce a {TensorDict.__name__}, "
|
|
299
|
-
f"but got a {type(t_dict).__name__}"
|
|
300
|
-
)
|
|
301
|
-
if "complete_frac" in t_dict.data: # sanity check complete_frac
|
|
302
|
-
complete_frac = float(t_dict.data["complete_frac"].raw_tensor)
|
|
303
|
-
assert 0.0 <= complete_frac <= 1.0, f"complete_frac must be in [0, 1], but got {complete_frac}"
|
|
304
|
-
assert complete_frac >= last_complete_frac, (
|
|
305
|
-
"complete_frac must be monotonically increasing, "
|
|
306
|
-
f"but got {complete_frac} after {last_complete_frac}"
|
|
307
|
-
)
|
|
308
|
-
last_complete_frac = complete_frac
|
|
309
|
-
for data_key, out_t in self._out_tensor_dict_template.data.items():
|
|
310
|
-
in_t = t_dict.data[data_key]
|
|
311
|
-
assert in_t.ndim == out_t.batch_ndim, (
|
|
312
|
-
f"Dim number mismatch for {data_key}: {in_t.ndim} != {out_t.batch_ndim}. "
|
|
313
|
-
"Postprocessing data tensors must not have a batch dimension."
|
|
314
|
-
)
|
|
315
|
-
assert in_t.dtype == out_t.dtype, (
|
|
316
|
-
f"dtype mismatch for {data_key}: '{in_t.dtype}' != '{out_t.dtype}'"
|
|
317
|
-
)
|
|
318
|
-
for i, (in_dim, out_shape) in enumerate(zip(in_t.dims, out_t.shape)):
|
|
319
|
-
assert in_dim.dimension is None or in_dim.dimension == out_shape, (
|
|
320
|
-
f"Dim {i} mismatch on {data_key}: "
|
|
321
|
-
f"{in_dim.dimension} must either be `None` or equal {out_shape}"
|
|
322
|
-
)
|
|
323
|
-
yield t_dict
|
|
324
|
-
|
|
325
|
-
data_iter = self._iterate_dataset()
|
|
326
|
-
if self._map_seq_stream is not None:
|
|
327
|
-
data_iter = self._map_seq_stream(data_iter, epoch=self.epoch, rng=self._rng, **util.get_fwd_compat_kwargs())
|
|
328
|
-
assert isinstance(data_iter, Iterator), (
|
|
329
|
-
f"map_seq_stream must produce an {Iterator.__name__}, but produced {type(data_iter).__name__}"
|
|
330
|
-
)
|
|
331
|
-
return _validate_tensor_dict_iter(data_iter)
|
|
332
|
-
|
|
333
|
-
def _iterate_dataset(self) -> Iterator[TensorDict]:
|
|
334
|
-
"""
|
|
335
|
-
:return: generator providing data samples in the form of a TensorDict
|
|
336
|
-
"""
|
|
337
|
-
data_keys = self._dataset.get_data_keys()
|
|
338
|
-
|
|
339
|
-
seq_index = 0
|
|
340
|
-
while self._dataset.is_less_than_num_seqs(seq_index):
|
|
341
|
-
self._dataset.load_seqs(seq_index, seq_index + 1)
|
|
342
|
-
|
|
343
|
-
tensor_dict = self._in_tensor_dict_template.copy_template()
|
|
344
|
-
for data_key in data_keys:
|
|
345
|
-
tensor_dict.data[data_key].raw_tensor = self._dataset.get_data(seq_index, data_key)
|
|
346
|
-
|
|
347
|
-
complete_frac = self._dataset.get_complete_frac(seq_index, allow_only_lr_suitable=True)
|
|
348
|
-
comp_frac_raw_tensor = None
|
|
349
|
-
if complete_frac is not None:
|
|
350
|
-
comp_frac_raw_tensor = numpy.array(complete_frac, dtype=numpy.float32)
|
|
351
|
-
tensor_dict.data["complete_frac"].raw_tensor = comp_frac_raw_tensor
|
|
352
|
-
seq_tag_raw_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index))
|
|
353
|
-
tensor_dict.data["seq_tag"].raw_tensor = seq_tag_raw_tensor
|
|
354
|
-
|
|
355
|
-
if self._map_seq is not None:
|
|
356
|
-
tensor_dict = self._map_seq(
|
|
357
|
-
tensor_dict, epoch=self.epoch, seq_idx=seq_index, rng=self._rng, **util.get_fwd_compat_kwargs()
|
|
358
|
-
)
|
|
359
|
-
assert isinstance(tensor_dict, TensorDict), (
|
|
360
|
-
f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}"
|
|
361
|
-
)
|
|
362
|
-
|
|
363
|
-
# Re-adding the seq_tag/complete_frac here causes no harm in case they are dropped
|
|
364
|
-
# since we don't add/drop any segments w/ the non-iterator postprocessing function.
|
|
365
|
-
if "complete_frac" not in tensor_dict.data and comp_frac_raw_tensor is not None:
|
|
366
|
-
tensor_dict.data["complete_frac"] = Tensor(
|
|
367
|
-
"complete_frac", dims=(), dtype="float32", raw_tensor=comp_frac_raw_tensor
|
|
368
|
-
)
|
|
369
|
-
if "seq_tag" not in tensor_dict.data:
|
|
370
|
-
tensor_dict.data["seq_tag"] = Tensor(
|
|
371
|
-
"seq_tag", dims=(), dtype="string", raw_tensor=seq_tag_raw_tensor
|
|
372
|
-
)
|
|
373
|
-
|
|
374
|
-
if self._seq_list_for_validation is not None:
|
|
375
|
-
seq_tag = self._seq_list_for_validation[seq_index]
|
|
376
|
-
tag_of_seq = tensor_dict.data["seq_tag"].raw_tensor.item()
|
|
377
|
-
assert tag_of_seq == seq_tag, (
|
|
378
|
-
f"seq tag mismath: {tag_of_seq} != {seq_tag} for seq index {seq_index} when seq list is given"
|
|
379
|
-
)
|
|
380
|
-
|
|
381
|
-
yield tensor_dict
|
|
382
|
-
seq_index += 1
|
|
383
|
-
|
|
384
390
|
def _make_tensor_template_from_input(self, data_key: str) -> Tensor:
|
|
385
391
|
dtype = self._dataset.get_data_dtype(data_key)
|
|
386
392
|
if dtype == "string":
|
|
@@ -399,6 +405,489 @@ class PostprocessingDataset(CachedDataset2):
|
|
|
399
405
|
sparse_dim.vocab = Vocabulary.create_vocab_from_labels(self._dataset.labels[data_key])
|
|
400
406
|
return Tensor(data_key, dims=dims, dtype=dtype, sparse_dim=sparse_dim)
|
|
401
407
|
|
|
408
|
+
def _lazy_init_worker_procs(self):
    """
    Create the worker processes (``num_workers`` many) unless they already exist.

    ``finish_epoch(free_resources=True)`` resets them to None, so they get recreated on the next call.
    """
    if self._worker_procs is not None:
        return  # already created earlier
    procs = []
    for worker_index in range(self._num_workers):
        procs.append(
            _WorkerProcParent(
                name=f"{self.__class__.__name__} {self.name} worker",
                buffer_size=self._buf_size,
                index=worker_index,
                map_seq=self._map_seq,
                map_seq_stream=self._map_seq_stream,
                out_tensor_dict_template=self._out_tensor_dict_template,
            )
        )
    self._worker_procs = procs
|
|
422
|
+
|
|
423
|
+
def _init_multi_proc_data_iter(
    self,
    *,
    epoch: int,
    feeder_to_worker_conns: Sequence[mpConnection],
    seq_list: Optional[List[str]] = None,
    seq_order: Optional[List[int]] = None,
) -> _MultiProcDataIter:
    """
    Start the feeder thread for ``epoch`` and wrap the worker processes into a data iterator.

    The feeder thread runs :func:`_init_seq_order_and_distribute_seqs_to_children`. It initializes
    the wrapped dataset for the epoch and hands its seqs to the workers through ``feeder_to_worker_conns``.

    :param epoch: epoch to initialize the wrapped dataset for
    :param feeder_to_worker_conns: one connection per worker process, used by the feeder thread
    :param seq_list: passed on to ``init_seq_order`` of the wrapped dataset
    :param seq_order: passed on to ``init_seq_order`` of the wrapped dataset
    :return: iterator pulling the postprocessed seqs from the workers; it also owns the feeder thread
    """
    assert len(feeder_to_worker_conns) == self._num_workers

    stop_signal = threading.Event()
    feeder = threading.Thread(
        target=self._init_seq_order_and_distribute_seqs_to_children,
        kwargs=dict(
            epoch=epoch,
            quit_event=stop_signal,
            seq_list=seq_list,
            seq_order=seq_order,
            worker_conns=feeder_to_worker_conns,
        ),
        name=f"{self.__class__.__name__} feeder ep {epoch}",
    )
    # The feeder connections are not closed here. They are handed to another thread of this
    # same process (not to another process), so they must stay open; the feeder closes them itself.
    feeder.start()
    return _MultiProcDataIter(dataset_thread=feeder, quit_event=stop_signal, worker_procs=self._worker_procs)
|
|
452
|
+
|
|
453
|
+
def _init_seq_order_and_distribute_seqs_to_children(
    self,
    *,
    epoch: int,
    quit_event: threading.Event,
    seq_list: Optional[List[str]] = None,
    seq_order: Optional[List[int]] = None,
    worker_conns: Sequence[mpConnection],
):
    """
    Initialize the wrapped dataset and distribute the contained sequences to the child worker processes.

    This runs in the feeder thread. Seq ``i`` of the wrapped dataset is assigned to worker
    ``i % len(worker_conns)`` and cached in that worker's deque. It is handed out whenever the worker
    requests a seq (``"get_seq"`` message).

    On a regular end, every worker is sent the end-of-data marker ``("seq", None)``. A regular end is
    one of: the dataset is exhausted, ``quit_event`` is set, or a worker connection broke.

    All connections are closed in any case, also when an exception occurs (e.g. in the wrapped dataset).
    In that case no end-of-data marker is sent and the exception propagates. Closing the connections
    releases workers blocked on receiving from them (presumably they then see an ``EOFError``) instead
    of leaving them waiting for a reply that never comes.

    :param epoch: epoch to initialize the wrapped dataset for
    :param quit_event: when set, stop distributing seqs as soon as possible
    :param seq_list: passed on to ``init_seq_order`` of the wrapped dataset
    :param seq_order: passed on to ``init_seq_order`` of the wrapped dataset
    :param worker_conns: one connection per worker process
    """

    assert self._buf_size > 0
    assert len(worker_conns) > 0
    assert self._num_workers > 0

    # one FIFO of not yet requested seqs per worker
    caches: List[deque[TensorDict]] = [deque() for _ in range(len(worker_conns))]

    def _any_conn_ready() -> bool:
        # non-blocking check whether any worker has sent a request
        ready, _, _ = select.select(worker_conns, [], [], 0)
        return len(ready) > 0

    def _maybe_distrib_seq(*, timeout=0.1):
        assert timeout >= 0.0
        # do not block indefinitely to periodically check the quit_event
        ready_conns, _, _ = select.select(worker_conns, [], [], timeout)
        assert len(worker_conns) == len(caches)
        for child_queue, cache in zip(worker_conns, caches):
            if child_queue not in ready_conns:
                continue
            msg, _ = child_queue.recv()
            assert msg == "get_seq"
            # The fill loop below only lets a cache run empty once the dataset is exhausted,
            # so None here tells the worker that there are no more seqs for it.
            tensor_dict = cache.popleft() if len(cache) > 0 else None
            child_queue.send(("seq", tensor_dict))

    # Lock ensures that only one thread at a time accesses the wrapped dataset.
    # This protects against issues while moving from one epoch to the next.
    with self._dataset_lock:
        completed = False
        try:
            self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
            data_iter = _iterate_dataset(self._dataset, in_tensor_dict_template=self._in_tensor_dict_template)
            data_iter = enumerate(data_iter)

            def _add_to_cache() -> bool:
                # round-robin assignment: seq idx goes to worker idx % num_workers
                try:
                    idx, tensor_dict = next(data_iter)
                    caches[idx % len(caches)].append(tensor_dict)
                    return True
                except StopIteration:
                    return False

            while not quit_event.is_set():
                # fetch seqs until all caches have at least one seq,
                # if no child is waiting for seqs also fill until buf_size
                while any(len(cache) == 0 for cache in caches) or (
                    sum(len(cache) for cache in caches) < self._buf_size and not _any_conn_ready()
                ):
                    if not _add_to_cache():
                        break
                if all(len(c) == 0 for c in caches):
                    break
                try:
                    _maybe_distrib_seq()
                except (BrokenPipeError, EOFError):
                    # queue is closed, i.e. the worker process crashed for some reason -> stop
                    break
            completed = True
        finally:
            for queue in worker_conns:
                try:
                    if completed:
                        # regular end: tell the worker that there are no more seqs
                        queue.send(("seq", None))
                except (BrokenPipeError, EOFError):
                    # queue is already closed, i.e. the worker process died
                    pass
                finally:
                    queue.close()
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def _iterate_dataset(dataset: Dataset, *, in_tensor_dict_template: TensorDict) -> Iterator[TensorDict]:
    """
    Lazily walk over ``dataset`` seq by seq, starting at seq index 0.

    For every seq, a fresh ``copy_template()`` of ``in_tensor_dict_template`` is filled with:
    - the data of all data keys of the dataset,
    - ``complete_frac`` as float32 scalar. It is only filled when the dataset returns a value;
      otherwise the entry stays as copied from the template,
    - ``seq_idx`` as int32 scalar,
    - ``seq_tag`` as string array.

    The callers call ``init_seq_order`` on ``dataset`` before iterating.

    :param dataset: dataset to read from
    :param in_tensor_dict_template: template with entries for all data keys and the meta entries above
    :return: generator yielding one TensorDict per seq
    """
    keys = dataset.get_data_keys()

    idx = 0
    while dataset.is_less_than_num_seqs(idx):
        dataset.load_seqs(idx, idx + 1)

        out = in_tensor_dict_template.copy_template()
        entries = out.data
        for key in keys:
            entries[key].raw_tensor = dataset.get_data(idx, key)

        frac = dataset.get_complete_frac(idx, allow_only_lr_suitable=True)
        if frac is not None:
            entries["complete_frac"].raw_tensor = numpy.array(frac, dtype=numpy.float32)
        entries["seq_idx"].raw_tensor = numpy.array(idx, dtype=numpy.int32)
        entries["seq_tag"].raw_tensor = str_to_numpy_array(dataset.get_tag(idx))

        yield out
        idx += 1
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
def _build_mapping_iter(
    data_iter: Iterator[TensorDict],
    *,
    map_seq: Optional[Callable] = None,
    map_seq_stream: Optional[Callable] = None,
    epoch: int,
    out_tensor_dict_template: TensorDict,
    rng: RandomState,
    seq_list_for_validation: Optional[List[str]] = None,
) -> Iterator[TensorDict]:
    """
    Build an iterator applying the mapping functions on the given dataset iterator.

    Exactly one of ``map_seq`` and ``map_seq_stream`` must be given.
    ``map_seq`` is applied lazily per seq. ``map_seq_stream`` is called right away with the (lazy)
    input iterator. The result is wrapped into a lazy validation generator.

    :param data_iter: iterator providing data samples in the form of a TensorDict
    :param map_seq: see :class:`PostprocessingDataset`
    :param map_seq_stream: see :class:`PostprocessingDataset`
    :param epoch: current epoch number
    :param out_tensor_dict_template: template for the output TensorDicts, used for validation
    :param rng: random number generator to use
    :param seq_list_for_validation: optional list of seq tags to validate against when processing the data
    :return: an iterator applying both the segment level and across-segment transformations on the given dataset
    """

    def _validate_tensor_dict_iter(inner: Iterator[TensorDict]) -> Iterator[TensorDict]:
        last_complete_frac = 0.0
        for t_dict in inner:
            assert isinstance(t_dict, TensorDict), (
                f"postprocessing mapper function must produce a {TensorDict.__name__}, "
                f"but got a {type(t_dict).__name__}"
            )
            # Sanity check complete_frac. The entry can presumably be present without a raw tensor,
            # e.g. when the wrapped dataset provides no complete_frac and the entry of the input
            # template was passed through unchanged. float(None) would fail, so treat that like a missing entry.
            if "complete_frac" in t_dict.data and t_dict.data["complete_frac"].raw_tensor is not None:
                complete_frac = float(t_dict.data["complete_frac"].raw_tensor)
                assert 0.0 <= complete_frac <= 1.0, f"complete_frac must be in [0, 1], but got {complete_frac}"
                assert complete_frac >= last_complete_frac, (
                    "complete_frac must be monotonically increasing, "
                    f"but got {complete_frac} after {last_complete_frac}"
                )
                last_complete_frac = complete_frac
            for data_key, out_t in out_tensor_dict_template.data.items():
                in_t = t_dict.data[data_key]
                assert in_t.ndim == out_t.batch_ndim, (
                    f"Dim number mismatch for {data_key}: {in_t.ndim} != {out_t.batch_ndim}. "
                    "Postprocessing data tensors must not have a batch dimension."
                )
                assert in_t.dtype == out_t.dtype, f"dtype mismatch for {data_key}: '{in_t.dtype}' != '{out_t.dtype}'"
                for i, (in_dim, out_shape) in enumerate(zip(in_t.dims, out_t.shape)):
                    assert in_dim.dimension is None or in_dim.dimension == out_shape, (
                        f"Dim {i} mismatch on {data_key}: {in_dim.dimension} must either be `None` or equal {out_shape}"
                    )
            yield t_dict

    def _apply_map_seq(tensor_dict: TensorDict) -> TensorDict:
        # remember the meta entries, so they can be restored if map_seq drops them
        comp_frac_raw_tensor = (
            tensor_dict.data["complete_frac"].raw_tensor if "complete_frac" in tensor_dict.data else None
        )
        seq_index_raw = tensor_dict.data["seq_idx"].raw_tensor
        seq_index = int(seq_index_raw.item())
        seq_tag_raw_tensor = tensor_dict.data["seq_tag"].raw_tensor

        tensor_dict = map_seq(tensor_dict, epoch=epoch, seq_idx=seq_index, rng=rng, **util.get_fwd_compat_kwargs())
        assert isinstance(tensor_dict, TensorDict), (
            f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}"
        )

        # Re-adding the complete_frac/seq_idx/seq_tag here causes no harm in case they are dropped
        # since we don't add/drop any segments w/ the non-iterator postprocessing function.
        if "complete_frac" not in tensor_dict.data and comp_frac_raw_tensor is not None:
            tensor_dict.data["complete_frac"] = Tensor(
                "complete_frac", dims=(), dtype="float32", raw_tensor=comp_frac_raw_tensor
            )
        if "seq_idx" not in tensor_dict.data:
            tensor_dict.data["seq_idx"] = Tensor("seq_idx", dims=(), dtype="int32", raw_tensor=seq_index_raw)
        if "seq_tag" not in tensor_dict.data:
            tensor_dict.data["seq_tag"] = Tensor("seq_tag", dims=(), dtype="string", raw_tensor=seq_tag_raw_tensor)

        if seq_list_for_validation is not None:
            seq_tag = seq_list_for_validation[seq_index]
            tag_of_seq = tensor_dict.data["seq_tag"].raw_tensor.item()
            assert tag_of_seq == seq_tag, (
                f"seq tag mismatch: {tag_of_seq} != {seq_tag} for seq index {seq_index} when seq list is given"
            )

        return tensor_dict

    assert map_seq or map_seq_stream, "need to specify either map_seq or map_seq_stream"
    assert not (map_seq and map_seq_stream), "cannot set both map_seq and map_seq_stream"
    if map_seq is not None:
        data_iter = (_apply_map_seq(t_dict) for t_dict in data_iter)
    if map_seq_stream is not None:
        data_iter = map_seq_stream(data_iter, epoch=epoch, rng=rng, **util.get_fwd_compat_kwargs())
        assert isinstance(data_iter, Iterator), (
            f"map_seq_stream must produce an {Iterator.__name__}, but produced {type(data_iter).__name__}"
        )
    return _validate_tensor_dict_iter(data_iter)
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
class _MultiProcDataIter:
    """
    Data iter that pulls from the worker processes in a well-defined order and
    manages the lifetime of the feeder thread.

    The workers are polled round-robin, one sequence per turn.
    A worker whose ``get_seq()`` returns ``None`` is marked as exhausted and skipped from then on.
    Iteration ends once all workers are exhausted (which also calls :func:`stop`),
    or as soon as the quit event is set.

    Also ensures monotonicity of complete_frac, which would otherwise be no longer
    guaranteed if there is more than one worker.
    """

    def __init__(
        self, *, dataset_thread: threading.Thread, quit_event: threading.Event, worker_procs: List[_WorkerProcParent]
    ):
        """
        :param dataset_thread: the dataset (feeder) thread, joined in :func:`stop`
        :param quit_event: set by :func:`stop`. Once set, this iterator does not yield anything anymore.
            (Presumably also observed by ``dataset_thread`` to make it quit -- confirm at the caller.)
        :param worker_procs: the workers to pull the sequences from. Must not be empty.
        """
        self.dataset_thread = dataset_thread
        self.quit_event = quit_event
        assert len(worker_procs) > 0
        self.worker_procs = worker_procs

        self._complete_frac = 0.0  # highest complete_frac emitted so far, used to force monotonicity
        self._workers_exhausted = [False for _ in range(len(worker_procs))]
        self._worker_idx = 0  # next worker to ask, in round-robin order

    def __iter__(self):
        return self

    def __next__(self) -> TensorDict:
        """
        :return: the next sequence, with its ``complete_frac`` (if present) made monotonic.
            Never returns ``None``.
        :raises StopIteration: if stopped, or once all workers are exhausted
        """
        if self.quit_event.is_set():
            raise StopIteration

        while not all(self._workers_exhausted):
            worker_idx = self._worker_idx
            self._worker_idx = (self._worker_idx + 1) % len(self.worker_procs)
            if self._workers_exhausted[worker_idx]:
                continue
            seq = self.worker_procs[worker_idx].get_seq()
            if seq is not None:
                return self._ensure_complete_frac_monotonic(seq)
            # None means this worker has no more data for this epoch.
            self._workers_exhausted[worker_idx] = True

        # when we reach this point, all workers are exhausted and we stop
        self.stop()
        raise StopIteration

    def stop(self, *, join=True):
        """
        Stop the iterator and the dataset thread.

        Once this is called, the iterator cannot be used anymore.
        Calling it again is a no-op.
        Note that this does not send anything to the worker processes.

        :param join: whether to wait for the dataset thread to finish
        """
        if self.quit_event.is_set():
            return
        self.quit_event.set()
        if join:
            util.try_run(self.dataset_thread.join)

    def _ensure_complete_frac_monotonic(self, seq: TensorDict) -> TensorDict:
        """
        Enforce monotonicity of `complete_frac` in the given `TensorDict`.

        The value is replaced in-place by the maximum of itself and all previously seen values,
        stored as a float32 numpy scalar array.
        Sequences without ``complete_frac`` are returned unchanged.

        :param seq: sequence as received from a worker, modified in-place
        :return: ``seq``
        """
        if "complete_frac" not in seq.data:
            return seq
        complete_frac = float(seq.data["complete_frac"].raw_tensor)
        assert 0.0 <= complete_frac <= 1.0, f"complete_frac must be in [0, 1], but got {complete_frac}"
        self._complete_frac = max(complete_frac, self._complete_frac)
        seq.data["complete_frac"].raw_tensor = numpy.array(self._complete_frac, dtype=numpy.float32)
        return seq

    def __del__(self):
        # Best-effort cleanup during garbage collection.
        # Do not join here, to not block inside a finalizer.
        # noinspection PyBroadException
        try:
            self.stop(join=False)
        except Exception:
            pass
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
class _WorkerProcParent:
    """
    Parent-side handle of a single postprocessing worker process.

    Starts a daemon process running :func:`_worker_proc_loop` and talks to it via a pipe.
    Every request is a ``(msg, kwargs)`` tuple, sent over ``self.parent_conn``.
    """

    def __init__(
        self,
        *,
        buffer_size: int,
        index: int,
        name: str,
        map_seq: Optional[Callable],
        map_seq_stream: Optional[Callable],
        out_tensor_dict_template: TensorDict,
    ):
        """
        :param buffer_size: max number of sequences the worker prefetches, forwarded to the worker
        :param index: worker index, used in the process name
        :param name: base name for the worker process
        :param map_seq: per-sequence mapping function, forwarded to the worker
        :param map_seq_stream: stream mapping function, forwarded to the worker
        :param out_tensor_dict_template: forwarded to the worker
        """
        parent_conn, child_conn = _mp.Pipe()
        self.parent_conn = parent_conn

        self.worker_proc = _mp.Process(
            name=f"{name} worker {index}",
            target=_worker_proc_loop,
            args=(index, child_conn, buffer_size, map_seq, map_seq_stream, out_tensor_dict_template),
            daemon=True,
        )
        self.worker_proc.start()

        # Make sure the child connection is closed here.
        # It stays open in the child, until the child dies.
        # When that happens, now any consecutive read on the pipe
        # should yield an exception -- which is what we want,
        # otherwise it would just hang.
        child_conn.close()

    def init_seq_order(
        self,
        *,
        epoch: int,
        rng_seed: int,
        seq_list: Optional[List[str]],
        seq_pipe: mpConnection,
    ):
        """
        Ask the worker to start a new epoch, reading its raw sequences from ``seq_pipe``.

        Blocks until the worker acknowledges.
        ``seq_pipe`` is handed over to the worker and is always closed in this (parent) process afterwards,
        also when the handshake fails.

        :param epoch: the new epoch
        :param rng_seed: seed for the worker-side random state
        :param seq_list: if given, the worker validates the seq tags against it
        :param seq_pipe: connection the worker pulls raw sequences from
        """
        args = {"epoch": epoch, "rng_seed": rng_seed, "seq_list": seq_list, "seq_pipe": seq_pipe}
        try:
            self.parent_conn.send(("init_seq_order", args))
            msg, _ = self.parent_conn.recv()
            assert msg == "init_seq_order", f"expected reply 'init_seq_order' from worker, got {msg!r}"
        finally:
            # seq_pipe is owned by the child process,
            # and so must be closed in the parent to avoid hangs.
            # Do this also when the handshake failed (e.g. the worker died),
            # otherwise the other end of seq_pipe never sees EOF.
            seq_pipe.close()

    def get_seq(self) -> Optional[TensorDict]:
        """
        Fetch the next sequence from the worker. Blocks until the worker replies.

        :return: the next sequence, or ``None`` if the worker has no more data
        """
        self.parent_conn.send(("get_seq", {}))
        msg, seq = self.parent_conn.recv()
        assert msg == "seq", f"expected reply 'seq' from worker, got {msg!r}"
        return seq

    def exit(self, *, join: bool = True):
        """
        Ask the worker process to exit.

        :param join: whether to wait for the worker process to terminate
        """
        self.parent_conn.send(("exit", {}))
        if join:
            self.worker_proc.join()

    def __del__(self):
        # Best-effort cleanup during garbage collection:
        # ask the worker to exit, and only if that message could be sent, wait for the process.
        # noinspection PyBroadException
        try:
            self.exit(join=False)
        except Exception:
            pass
        else:
            util.try_run(self.worker_proc.join)
|
|
796
|
+
|
|
797
|
+
|
|
798
|
+
def _worker_proc_loop(
    index: int,
    parent_conn: mpConnection,
    buffer_size: int,
    map_seq: Optional[Callable],
    map_seq_stream: Optional[Callable],
    out_tensor_dict_template: TensorDict,
):
    """
    Main loop of a postprocessing worker process (the process target used by :class:`_WorkerProcParent`).

    Serves the requests received on ``parent_conn``:

    - ``("init_seq_order", kwargs)``: (re)build the mapping pipeline for a new epoch.
      The raw sequences are pulled from ``kwargs["seq_pipe"]``.
      The prefetch cache is dropped, and the worker replies ``("init_seq_order", None)``.
    - ``("get_seq", _)``: replies ``("seq", seq)``, where ``seq`` is ``None`` once the pipeline is exhausted.
    - ``("exit", _)``: leaves the loop.

    While no request is pending, up to ``buffer_size`` mapped sequences are prefetched.

    :param index: worker index, used for the process name
    :param parent_conn: this process's end of the control pipe to the parent
    :param buffer_size: max number of prefetched sequences, must be > 0
    :param map_seq: per-sequence mapping function, passed on to :func:`_build_mapping_iter`
    :param map_seq_stream: stream mapping function, passed on to :func:`_build_mapping_iter`
    :param out_tensor_dict_template: passed on to :func:`_build_mapping_iter`
    """

    def _set_proc_name(proc_name: str):
        # Sets the Linux process name (/proc/self/comm, see proc(5)).
        # This is purely cosmetic, so failing to set it (e.g. in a sandbox where this is not writable)
        # must not kill the worker.
        if sys.platform != "linux":
            return
        try:
            with open("/proc/self/comm", "w") as f:
                f.write(proc_name)
        except OSError:
            pass

    _set_proc_name(f"PP worker {index}")
    better_exchook.setup_all()

    assert isinstance(buffer_size, int) and buffer_size > 0
    assert isinstance(index, int)
    assert isinstance(parent_conn, mpConnection)

    cache: deque[TensorDict] = deque()  # prefetched, already mapped sequences

    data_iter: Optional[Iterator[TensorDict]] = None  # None before the first init_seq_order, or when exhausted
    feeder_conn: Optional[mpConnection] = None  # source of the raw sequences for the current epoch

    def _add_to_cache() -> bool:
        """Pull one mapped sequence into the cache. Returns False if there is nothing (more) to pull."""
        nonlocal data_iter
        if data_iter is None:
            return False
        try:
            seq = next(data_iter)
        except StopIteration:
            data_iter = None
            return False
        cache.append(seq)
        return True

    def _iter_pipe(q: mpConnection) -> Iterator[TensorDict]:
        """Yield raw sequences by requesting them one by one via ``q``, until ``None`` is received or it is closed."""
        assert isinstance(q, mpConnection)

        while True:
            try:
                q.send(("get_seq", None))
                seq_msg, item = q.recv()
            except (BrokenPipeError, EOFError):
                # queue is closed
                break
            assert seq_msg == "seq"
            if item is None:
                break
            assert isinstance(item, TensorDict)
            yield item

    try:
        while True:
            # Prefetch only while no request is pending, so that requests get answered promptly.
            while len(cache) < buffer_size and not parent_conn.poll():
                if not _add_to_cache():
                    break
            msg, kwargs = parent_conn.recv()
            if msg == "exit":
                break
            elif msg == "get_seq":
                if not cache:
                    _add_to_cache()
                parent_conn.send(("seq", cache.popleft() if cache else None))
            elif msg == "init_seq_order":
                epoch = kwargs["epoch"]
                _set_proc_name(f"PP worker {index} ep {epoch}")
                if feeder_conn is not None:
                    feeder_conn.close()
                feeder_conn = kwargs["seq_pipe"]
                data_iter = _build_mapping_iter(
                    _iter_pipe(feeder_conn),
                    epoch=epoch,
                    map_seq=map_seq,
                    map_seq_stream=map_seq_stream,
                    out_tensor_dict_template=out_tensor_dict_template,
                    rng=RandomState(kwargs["rng_seed"]),
                    seq_list_for_validation=kwargs["seq_list"],
                )
                assert isinstance(data_iter, Iterator)
                cache.clear()  # drop sequences prefetched from the previous epoch
                parent_conn.send(("init_seq_order", None))
            else:
                raise Exception(f"unknown msg {msg!r}")
    except KeyboardInterrupt:  # when parent dies
        pass
    except EOFError:  # when parent dies
        pass
    finally:
        if feeder_conn is not None:
            feeder_conn.close()
        parent_conn.close()
|
|
890
|
+
|
|
402
891
|
|
|
403
892
|
class LaplaceOrdering(Callable[[Iterator[TensorDict]], Iterator[TensorDict]]):
|
|
404
893
|
"""
|