returnn 1.20250226.183415__py3-none-any.whl → 1.20250228.101938__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of returnn might be problematic.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: returnn
- Version: 1.20250226.183415
+ Version: 1.20250228.101938
  Summary: The RWTH extensible training framework for universal recurrent neural networks
  Home-page: https://github.com/rwth-i6/returnn/
  Author: Albert Zeyer
@@ -1,2 +1,2 @@
- version = '1.20250226.183415'
- long_version = '1.20250226.183415+git.ba9d72e'
+ version = '1.20250228.101938'
+ long_version = '1.20250228.101938+git.c053cfd'
returnn/datasets/basic.py CHANGED
@@ -16,6 +16,7 @@ from threading import RLock
  from random import Random, random
  import sys
  import os
+ import math
  import numpy
  import functools
  import typing
@@ -937,28 +938,51 @@ class Dataset:
  else:
  # We don't know. So:
  # Some monotonic increasing function in [0,1] which never reaches 1.
- import math
+ return max(1.0e-10, (1 - 1 / ((seq_idx**0.5) / 100 + 1)) * 0.99)

- return max(1.0e-10, 1.0 - math.exp(-seq_idx * 1000))
-
- def get_complete_frac(self, seq_idx):
+ def get_complete_frac(self, sorted_seq_idx: int, *, allow_only_lr_suitable: bool = False) -> Optional[float]:
  """
- :param int seq_idx:
- :return: Returns a fraction (float in [0,1], always > 0) of how far we have advanced
- for this seq in the dataset.
- This does not have to be exact. This is only for the user.
- :rtype: float
+ Tries to calculate exactly how much of the current epoch is completed when
+ having processed seq ``sorted_seq_idx``.
+
+ ``sorted_seq_idx`` cannot be less than the seq index of the previously loaded seqs.
+
+ :param sorted_seq_idx: sorted seq idx
+ :param allow_only_lr_suitable: only return a value when that value is suitable/accurate enough
+ to base LR scheduling on it. If false, this function will return an approximative value
+ when the exact value cannot be calculated (due to unknown ``num_seqs``).
+ Approximative values can be appropriate for e.g. progress bars.
+ :return: continuous value in (0, 1] which represents how much of the current epoch
+ is completed after ``sorted_seq_idx``.
+ If ``allow_only_lr_suitable=True``, returns ``None`` if the value cannot be calculated such
+ that it is accurate enough for LR scheduling, and otherwise bases ``epoch_continuous`` on it
+ for any dynamic learning rate scheduling.
+ As ``sorted_seq_idx`` is monotonic, the return value is also guaranteed to be monotonic.
  """
  # noinspection PyBroadException
  try:
  num_seqs = self.num_seqs
  except Exception: # num_seqs not always available
+ if allow_only_lr_suitable:
+ return None
+
  # noinspection PyBroadException
  try:
  num_seqs = self.estimated_num_seqs
  except Exception: # also not always available
  num_seqs = None # ignore
- return self.generic_complete_frac(seq_idx, num_seqs)
+
+ if math.isinf(num_seqs):
+ if allow_only_lr_suitable:
+ # cannot compute meaningful complete_frac for infinite num_seqs
+ return None
+ else:
+ num_seqs = None
+
+ assert (
+ num_seqs is None or 0 <= sorted_seq_idx < num_seqs
+ ), f"{self}: invalid seq indices: 0 <= seq_idx ({sorted_seq_idx}) < num_seqs ({num_seqs}) violated"
+ return self.generic_complete_frac(sorted_seq_idx, num_seqs)

  @property
  def num_seqs(self) -> int:
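Note on the new fallback above (in what appears to be the generic completion estimate used when num_seqs is unknown): the old exponential estimate is replaced by a slower-growing curve. A minimal standalone sketch, not part of the package, showing that the formula is monotonically increasing in seq_idx and stays strictly below 1:

def approx_complete_frac(seq_idx: int) -> float:
    # same expression as the added line in the diff above
    return max(1.0e-10, (1 - 1 / ((seq_idx**0.5) / 100 + 1)) * 0.99)

assert approx_complete_frac(0) == 1.0e-10
assert approx_complete_frac(100) < approx_complete_frac(10_000) < approx_complete_frac(1_000_000) < 0.99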
@@ -1375,16 +1399,27 @@ class DatasetSeq:
  Encapsulates all data for one sequence.
  """

- def __init__(self, seq_idx, features, targets=None, seq_tag=None):
+ def __init__(
+ self,
+ seq_idx: int,
+ features,
+ *,
+ targets=None,
+ seq_tag: Optional[str] = None,
+ complete_frac: Optional[float] = None,
+ ):
  """
- :param int seq_idx: sorted seq idx in the Dataset
+ :param seq_idx: sorted seq idx in the Dataset
  :param numpy.ndarray|dict[str,numpy.ndarray] features: format 2d (time,feature) (float)
  :param dict[str,numpy.ndarray]|numpy.ndarray|None targets: name -> format 1d (time) (idx of output-feature)
- :param str seq_tag: sequence name / tag
+ :param seq_tag: sequence name / tag
+ :param complete_frac: continuous value in (0, 1] which represents how much of the current epoch
+ has been consumed when this seq is processed
  """
  assert isinstance(seq_idx, (int, numpy.integer))
  self.seq_idx = int(seq_idx)
  self.seq_tag = seq_tag or ("seq-%i" % seq_idx)
+ self.complete_frac = complete_frac
  if not isinstance(features, dict):
  assert isinstance(features, numpy.ndarray)
  features = {"data": features}
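The DatasetSeq constructor above makes targets, seq_tag and complete_frac keyword-only, which is why the call sites further down in this diff switch to targets=... . A hedged usage sketch (values made up for illustration, assuming the module path returnn.datasets.basic from this diff):

import numpy
from returnn.datasets.basic import DatasetSeq

seq = DatasetSeq(
    seq_idx=0,
    features={"data": numpy.zeros((10, 3), dtype="float32")},
    targets={"classes": numpy.zeros((10,), dtype="int32")},
    seq_tag="seq-0",
    complete_frac=0.01,  # fraction of the epoch consumed after this seq
)
assert seq.complete_frac == 0.01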
@@ -228,6 +228,15 @@ class CachedDataset2(Dataset):
  keys.remove("data")
  return keys

+ def get_complete_frac(self, sorted_seq_idx, **kwargs):
+ """
+ :return: fractional completion value for the given sorted_seq_idx
+ """
+ seq = self._get_seq(sorted_seq_idx)
+ if seq is not None and seq.complete_frac is not None:
+ return seq.complete_frac
+ return super().get_complete_frac(sorted_seq_idx, **kwargs)
+
  def is_data_sparse(self, key):
  """
  :param str key: e.g. "data" or "classes"
@@ -601,7 +601,8 @@ def _worker_proc_loop(
  dataset.load_seqs(next_seq_idx, next_seq_idx + 1)
  seq_tag = dataset.get_tag(next_seq_idx)
  features = {data_key: dataset.get_data(next_seq_idx, data_key) for data_key in dataset.get_data_keys()}
- res = DatasetSeq(seq_idx=next_seq_idx, seq_tag=seq_tag, features=features)
+ complete_frac = dataset.get_complete_frac(next_seq_idx, allow_only_lr_suitable=True)
+ res = DatasetSeq(seq_idx=next_seq_idx, seq_tag=seq_tag, features=features, complete_frac=complete_frac)
  cache.append(res)
  next_seq_idx += 1
  return True
returnn/datasets/meta.py CHANGED
@@ -554,6 +554,12 @@ class MetaDataset(CachedDataset2):
  """
  return self.seq_list_ordered[self.default_dataset_key][sorted_seq_idx]

+ def get_complete_frac(self, sorted_seq_idx: int, **kwargs) -> Optional[float]:
+ """
+ :param sorted_seq_idx:
+ """
+ return self.datasets[self.default_dataset_key].get_complete_frac(sorted_seq_idx, **kwargs)
+
  def get_data_keys(self) -> List[str]:
  """data keys"""
  return sorted(self.data_keys)
@@ -75,6 +75,7 @@ class MultiProcDataset(CachedDataset2):
  self._seq_order_proc_parent_conn = None # type: Optional[mpConnection]
  self._seq_order_proc = None # type: Optional[mp.Process]
  self._worker_procs = None # type: Optional[List[mp.Process]]
+ self._cur_max_complete_frac: Optional[float] = None

  if _meta_info_cache:
  # This allows to skip the lazy init in self.initialize().
@@ -246,7 +247,8 @@ class MultiProcDataset(CachedDataset2):
  dataset.load_seqs(next_seq_idx, next_seq_idx + 1)
  seq_tag = dataset.get_tag(next_seq_idx)
  features = {data_key: dataset.get_data(next_seq_idx, data_key) for data_key in dataset.get_data_keys()}
- res = DatasetSeq(seq_idx=next_seq_idx, seq_tag=seq_tag, features=features)
+ complete_frac = dataset.get_complete_frac(next_seq_idx, allow_only_lr_suitable=True)
+ res = DatasetSeq(seq_idx=next_seq_idx, seq_tag=seq_tag, features=features, complete_frac=complete_frac)
  cache.append(res)
  next_seq_idx += 1
  return True
@@ -403,6 +405,7 @@ class MultiProcDataset(CachedDataset2):
  return True

  self._lazy_init()
+ self._cur_max_complete_frac = 0.0

  if self._sharding_method == "dedicated":
  for worker_conn in self._worker_parent_conns:
@@ -441,6 +444,12 @@ class MultiProcDataset(CachedDataset2):
  if data is None:
  return None
  assert isinstance(data, DatasetSeq)
+ # The complete_frac values from the subprocesses are not necessarily monotonic
+ # due to rounding errors in the sharding and such.
+ # We therefore fix them up here. This is valid due to monotonicity of `seq_idx`.
+ max_comp_frac = max(data.complete_frac, self._cur_max_complete_frac)
+ data.complete_frac = max_comp_frac
+ self._cur_max_complete_frac = max_comp_frac
  data.seq_idx = seq_idx
  return data

@@ -154,4 +154,4 @@ class NumpyDumpDataset(Dataset):
  def _add_cache_seq(self, seq_idx, features, targets):
  last_seq_idx = self._get_cache_last_seq_idx()
  assert seq_idx == last_seq_idx + 1
- self.cached_seqs += [DatasetSeq(seq_idx, features, targets)]
+ self.cached_seqs += [DatasetSeq(seq_idx, features, targets=targets)]
@@ -5,6 +5,7 @@ Provides :class:`PostprocessingDataset`.
  from __future__ import annotations

  from itertools import islice
+ import numpy
  from numpy.random import RandomState
  from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar

@@ -57,6 +58,14 @@ class PostprocessingDataset(CachedDataset2):
  },
  }

+ The postprocessor functions operate on ``TensorDict``s, which have entries for
+ all data keys in the underlying dataset.
+
+ There may also be additional "meta" entries in the tensor dicts, like ``complete_frac``
+ and ``seq_tag``.
+ These should be copied over in a manner that is reasonable for the use case at hand and
+ ensures forwards compatibility as well as reasonably possible.
+
  The dataset itself does not support its own seq ordering and relies on the wrapped
  dataset for seq ordering instead. Specifying a ``seq_ordering`` other than ``default``
  results in an error.
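Following the docstring addition above, a hypothetical user-side postprocessor (not part of the package) would keep the meta entries intact. The exact signature that PostprocessingDataset passes to a map_seq-style function (for example extra keyword arguments such as an RNG) may differ, so this sketch simply accepts **kwargs:

def my_map_seq(tensor_dict, **kwargs):
    # Transform the "data" entry in place (assuming the wrapped dataset provides a
    # "data" key); all other entries, including the meta entries "seq_tag" and
    # "complete_frac", are passed through unchanged.
    data = tensor_dict.data["data"]
    data.raw_tensor = data.raw_tensor * 2.0
    return tensor_dict

As far as the checks added in this diff show, dropping complete_frac is tolerated (the dataset then falls back to an approximate value), but keeping it preserves exact progress tracking for LR scheduling.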
@@ -155,7 +164,10 @@ class PostprocessingDataset(CachedDataset2):
  self._out_tensor_dict_template = self._in_tensor_dict_template.copy_template()
  self.labels = self._dataset.labels.copy()
  # update only after _out_tensor_dict_template has been created from _in_tensor_dict_template
- self._in_tensor_dict_template.update({"seq_tag": {"dims": (), "dtype": "string"}}, auto_convert=True)
+ self._in_tensor_dict_template.update(
+ {"complete_frac": {"dims": (), "dtype": "float32"}, "seq_tag": {"dims": (), "dtype": "string"}},
+ auto_convert=True,
+ )
  self.num_outputs = {
  k: (t.sparse_dim.size if t.sparse_dim else t.shape[-1] if len(t.shape) > 0 else 1, t.ndim)
  for k, t in self._out_tensor_dict_template.data.items()
@@ -222,6 +234,15 @@ class PostprocessingDataset(CachedDataset2):
  """:return: dtype of data entry `key`"""
  return self._out_tensor_dict_template.data[key].dtype

+ def get_total_num_seqs(self, *, fast=False):
+ """:return: total num seqs excluding partition_epoch"""
+ if self._map_seq_stream is not None:
+ raise util.OptionalNotImplementedError(
+ f"{self}: get_total_num_seqs not allowed when map_seq_stream is set."
+ )
+ assert self._dataset is not None
+ return self._dataset.get_total_num_seqs(fast=fast)
+
  def supports_sharding(self) -> bool:
  """:return: whether this dataset supports sharding"""
  assert self._dataset is not None
@@ -249,11 +270,12 @@ class PostprocessingDataset(CachedDataset2):
  assert loaded_seq_idx <= seq_idx, "_collect_single_seq must be done monotonically"
  if loaded_seq_idx != seq_idx:
  continue
- seq = DatasetSeq(
- features={k: t.raw_tensor for k, t in tensor_dict.data.items() if k != "seq_tag"},
- seq_idx=seq_idx,
- seq_tag=str(tensor_dict["seq_tag"].raw_tensor),
+ complete_frac = (
+ float(tensor_dict.data["complete_frac"].raw_tensor) if "complete_frac" in tensor_dict.data else None
  )
+ seq_tag = str(tensor_dict.data["seq_tag"].raw_tensor) if "seq_tag" in tensor_dict.data else f"seq-{seq_idx}"
+ features = {k: t.raw_tensor for k, t in tensor_dict.data.items() if k not in ["complete_frac", "seq_tag"]}
+ seq = DatasetSeq(complete_frac=complete_frac, features=features, seq_idx=seq_idx, seq_tag=seq_tag)
  return seq

  def _build_mapping_iter(self) -> Iterator[TensorDict]:
@@ -262,8 +284,20 @@ class PostprocessingDataset(CachedDataset2):
  """

  def _validate_tensor_dict_iter(inner: Iterator[TensorDict]) -> Iterator[TensorDict]:
+ last_complete_frac = 0.0
  for t_dict in inner:
- assert "seq_tag" in t_dict.data, "seq_tag dropped from TensorDict in postprocessing pipeline"
+ assert isinstance(t_dict, TensorDict), (
+ f"postprocessing mapper function must produce a {TensorDict.__name__}, "
+ f"but got a {type(t_dict).__name__}"
+ )
+ if "complete_frac" in t_dict.data: # sanity check complete_frac
+ complete_frac = float(t_dict.data["complete_frac"].raw_tensor)
+ assert 0.0 <= complete_frac <= 1.0, f"complete_frac must be in [0, 1], but got {complete_frac}"
+ assert complete_frac >= last_complete_frac, (
+ "complete_frac must be monotonically increasing, "
+ f"but got {complete_frac} after {last_complete_frac}"
+ )
+ last_complete_frac = complete_frac
  for data_key, out_t in self._out_tensor_dict_template.data.items():
  in_t = t_dict.data[data_key]
  assert (
@@ -294,8 +328,14 @@ class PostprocessingDataset(CachedDataset2):
  tensor_dict = self._in_tensor_dict_template.copy_template()
  for data_key in data_keys:
  tensor_dict.data[data_key].raw_tensor = self._dataset.get_data(seq_index, data_key)
- seq_tag_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index))
- tensor_dict.data["seq_tag"].raw_tensor = seq_tag_tensor
+
+ complete_frac = self._dataset.get_complete_frac(seq_index, allow_only_lr_suitable=True)
+ comp_frac_raw_tensor = None
+ if complete_frac is not None:
+ comp_frac_raw_tensor = numpy.array(complete_frac, dtype=numpy.float32)
+ tensor_dict.data["complete_frac"].raw_tensor = comp_frac_raw_tensor
+ seq_tag_raw_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index))
+ tensor_dict.data["seq_tag"].raw_tensor = seq_tag_raw_tensor

  if self._map_seq is not None:
  tensor_dict = self._map_seq(
@@ -305,10 +345,16 @@ class PostprocessingDataset(CachedDataset2):
  tensor_dict, TensorDict
  ), f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}"

- # Re-adding the seq tag here causes no harm in case it's dropped since we don't
- # add/drop any segments w/ the non-iterator postprocessing function.
+ # Re-adding the seq_tag/complete_frac here causes no harm in case they are dropped
+ # since we don't add/drop any segments w/ the non-iterator postprocessing function.
+ if "complete_frac" not in tensor_dict.data and comp_frac_raw_tensor is not None:
+ tensor_dict.data["complete_frac"] = Tensor(
+ "complete_frac", dims=(), dtype="float32", raw_tensor=comp_frac_raw_tensor
+ )
  if "seq_tag" not in tensor_dict.data:
- tensor_dict.data["seq_tag"].raw_tensor = seq_tag_tensor
+ tensor_dict.data["seq_tag"] = Tensor(
+ "seq_tag", dims=(), dtype="string", raw_tensor=seq_tag_raw_tensor
+ )

  if self._seq_list_for_validation is not None:
  seq_tag = self._seq_list_for_validation[seq_index]
@@ -366,7 +412,12 @@ class LaplaceOrdering(Callable[[Iterator[TensorDict]], Iterator[TensorDict]]):
  seq_buffer = list(islice(iterator, self.num_seqs_per_bin))
  has_ended = False
  while True:
+ # Make sure to not reorder the monotonically increasing values for complete_frac
+ # so that the trainer can calculate the appropriate learning rates.
+ complete_frac_values = [tdict.data["complete_frac"].raw_tensor for tdict in seq_buffer]
  seq_buffer.sort(key=self._get_seq_len, reverse=is_down_phase)
+ for sorted_item, comp_frac in zip(seq_buffer, complete_frac_values):
+ sorted_item.data["complete_frac"].raw_tensor = comp_frac

  next_seq_buffer = []

@@ -83,7 +83,7 @@ class RawWavDataset(CachedDataset2):
  inputFeatures = inputFeatures.astype(np.float32)
  if outputFeatures is not None:
  outputFeatures = outputFeatures.astype(np.float32)
- return DatasetSeq(seq_idx, inputFeatures, outputFeatures)
+ return DatasetSeq(seq_idx, inputFeatures, targets=outputFeatures)

  def _get_num_outputs(self, num_outputs):
  """
@@ -504,7 +504,7 @@ class SprintDatasetBase(Dataset):
  assert seq_idx + 1 == self.next_seq_to_be_added
  self.cond.wait()

- self.added_data += [DatasetSeq(seq_idx, features, targets, seq_tag=segment_name)]
+ self.added_data += [DatasetSeq(seq_idx, features, targets=targets, seq_tag=segment_name)]
  self.cond.notify_all()
  return seq_idx

@@ -588,7 +588,7 @@ class SprintDatasetBase(Dataset):
  """
  self._complete_frac = frac

- def get_complete_frac(self, seq_idx):
+ def get_complete_frac(self, seq_idx, **kwargs):
  """
  :param int seq_idx:
  :rtype: float
@@ -349,7 +349,7 @@ class StereoHdfDataset(StereoDataset):
  elif targets.shape[1] == 1:
  targets = np.reshape(targets.astype(np.int32), (targets.shape[0],))

- return DatasetSeq(seq_idx, inputFeatures, targets)
+ return DatasetSeq(seq_idx, inputFeatures, targets=targets)

  @staticmethod
  def _normalizeVector(v, mean, variance):
@@ -438,4 +438,4 @@ class DatasetWithTimeContext(StereoHdfDataset):
  targets = None
  if "classes" in originalSeq.get_data_keys():
  targets = originalSeq.get_data("classes")
- return DatasetSeq(seq_idx, inputFeatures, targets)
+ return DatasetSeq(seq_idx, inputFeatures, targets=targets)
@@ -21,6 +21,7 @@ other PyTorch datasets more directly, including also HuggingFace datasets.

  from __future__ import annotations
  import bisect
+ import itertools
  from typing import Optional, Any, Sequence, Tuple, Union, List, Dict, Callable
  import sys
  from copy import deepcopy
@@ -65,6 +66,9 @@ def collate_batch(batch: List[Dict[str, numpy.ndarray]]) -> Dict[str, Union[torc
  if key in ("num_seqs", "epoch"):
  res[key] = batch[0][key] # it should always be the same
  continue
+ elif key == "complete_frac":
+ res[key] = max(sample[key] for sample in batch)
+ continue
  ls = [create_tensor(sample[key]) for sample in batch]
  if not ls:
  raise ValueError("batch is empty?")
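The new complete_frac collate rule above reduces the per-sample values with max, i.e. a batch reports the progress of its furthest-advanced sequence. A tiny illustration (not package code):

batch = [{"complete_frac": 0.10}, {"complete_frac": 0.12}, {"complete_frac": 0.11}]
assert max(sample["complete_frac"] for sample in batch) == 0.12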
@@ -122,7 +126,7 @@ class ChunkingIterDataPipe(torch.utils.data.IterDataPipe):

  if not chunking_data_keys:
  chunking_data_keys = list(data_dict.keys()) # use all if not configured separately
- chunking_data_key_black_list = ["seq_tag", "seq_idx", "num_seqs", "epoch"]
+ chunking_data_key_black_list = ["seq_tag", "seq_idx", "num_seqs", "epoch", "complete_frac"]
  for key in chunking_data_key_black_list:
  if key in chunking_data_keys:
  chunking_data_keys.remove(key)
@@ -269,8 +273,15 @@ class BatchingIterDataPipe(torch.utils.data.IterDataPipe):
  epoch = int(data_dict["epoch"])
  seq_idx = int(data_dict["seq_idx"])
  num_seqs = int(data_dict["num_seqs"]) # >=1 if known, otherwise -1
- epoch_continuous = (epoch - 1 + (seq_idx + 1) / num_seqs) if num_seqs > 0 else None
- return {"epoch": epoch, "seq_idx": seq_idx, "epoch_continuous": epoch_continuous, **get_fwd_compat_kwargs()}
+ complete_frac = float(data_dict["complete_frac"]) # >= 0 if known, otherwise -1
+ epoch_continuous = (
+ epoch - 1 + complete_frac
+ if complete_frac >= 0.0
+ else (epoch - 1 + (seq_idx + 1) / num_seqs)
+ if num_seqs > 0
+ else None
+ )
+ return {"epoch": epoch, "epoch_continuous": epoch_continuous, "seq_idx": seq_idx, **get_fwd_compat_kwargs()}

  def __iter__(self):
  """
@@ -455,6 +466,125 @@ class LenFilterDataPipe(torch.utils.data.IterDataPipe):
  raise Exception(f"{self.__class__.__name__}.__getitem__ not supported")


+ class ShufflingDataPipe(torch.utils.data.IterDataPipe):
+ """
+ Data pipe that is similar to ``torch.utils.data.datapipes.iter.Shuffler``,
+ but it will keep certain data keys of the batches in order while shuffling the rest.
+
+ Used for e.g. ``complete_frac`` and ``seq_idx``.
+ """
+
+ def __init__(
+ self,
+ dataset: torch.utils.data.IterableDataset,
+ *,
+ buffer_size: int,
+ monotonic_data_keys: Sequence[str],
+ seed: Optional[int] = None,
+ ):
+ """
+ :param dataset: batches dataset to shuffle
+ :param buffer_size: buffer size for shuffling
+ :param monotonic_data_keys: data keys that will be excluded from shuffling/keep their order
+ :param seed: random seed
+ """
+ super().__init__()
+
+ self._dataset = dataset
+ self._buffer: List[List[Dict[str, Any]]] = []
+ self._next_buffer: List[List[Dict[str, Any]]] = []
+ assert buffer_size > 0
+ self._buffer_size = buffer_size
+ self._monotonic_data_keys = monotonic_data_keys
+ self._rng = numpy.random.RandomState()
+ self._seed = seed
+
+ def __iter__(self):
+ # The implementation is very similar to the PostprocessingDataset's combinator LaplaceOrdering.
+
+ data_iter = iter(self._dataset)
+
+ self._buffer.extend(itertools.islice(data_iter, self._buffer_size))
+ has_ended = False
+
+ while True:
+ # Make sure to not reorder the monotonic values from self._monotonic_data_keys.
+ # These can contain things like complete_frac, which should be kept in order.
+ ordered_data = {
+ key: [data_dict[key] for batch in self._buffer for data_dict in batch]
+ for key in self._monotonic_data_keys
+ }
+ self._rng.shuffle(self._buffer)
+ for key in self._monotonic_data_keys:
+ data_dicts = [data_dict for batch in self._buffer for data_dict in batch]
+ assert len(data_dicts) == len(ordered_data[key])
+ for ordered_value, data_dict in zip(ordered_data[key], data_dicts):
+ data_dict[key] = ordered_value
+
+ for item in self._buffer:
+ yield item
+
+ try:
+ if not has_ended:
+ self._next_buffer.append(next(data_iter))
+ except StopIteration:
+ has_ended = True
+
+ if len(self._buffer) < self._buffer_size:
+ assert has_ended and not self._next_buffer
+ break
+
+ self._buffer.clear()
+ self._buffer, self._next_buffer = self._next_buffer, self._buffer
+
+ def set_seed(self, seed: int) -> ShufflingDataPipe:
+ """
+ Sets the seed for the next invocation of ``__iter__``, for compatibility with
+ ``torch.utils.data.graph_settings.apply_random_seed``.
+ """
+ self._seed = seed % (2**32) # seed must be within [0, 2**32) for seeding RandomState
+ return self
+
+ def reset(self):
+ """resets the internal state of the data pipe"""
+ self._buffer.clear()
+ self._next_buffer.clear()
+ if self._seed is None:
+ self._seed = int(torch.empty((), dtype=torch.int32).random_().item())
+ self._rng.seed(self._seed)
+ self._seed = None
+
+ def __getstate__(self):
+ state = (
+ self._dataset,
+ self._buffer,
+ self._next_buffer,
+ self._buffer_size,
+ self._monotonic_data_keys,
+ self._rng.get_state(),
+ self._seed,
+ )
+ if torch.utils.data.IterDataPipe.getstate_hook is not None:
+ return torch.utils.data.IterDataPipe.getstate_hook(state)
+ return state
+
+ def __setstate__(self, state):
+ (
+ self._dataset,
+ self._buffer,
+ self._next_buffer,
+ self._buffer_size,
+ self._monotonic_data_keys,
+ rng_state,
+ self._seed,
+ ) = state
+ self._rng = numpy.random.RandomState()
+ self._rng.set_state(rng_state)
+
+ def __getitem__(self, index):
+ raise Exception(f"{self.__class__.__name__}.__getitem__ not supported")
+
+
  def create_data_loader_from_batches(
  batches_dataset: torch.utils.data.Dataset, loader_opts: Optional[Dict[str, Any]] = None
  ) -> torch.utils.data.DataLoader:
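The core idea of ShufflingDataPipe above: shuffle a buffer, then write the values of the monotonic_data_keys back in their original order, so that keys like complete_frac and seq_idx stay monotonic across the shuffled stream. A simplified standalone sketch (not package code; the real pipe works on batches of dicts and keeps a second buffer across iterations):

import random

def shuffle_keep_monotonic(buffer, monotonic_keys, rng=random):
    ordered = {key: [d[key] for d in buffer] for key in monotonic_keys}
    rng.shuffle(buffer)  # shuffle the items themselves
    for key in monotonic_keys:
        for value, d in zip(ordered[key], buffer):
            d[key] = value  # restore the original order for this key only
    return buffer

items = [{"seq_idx": i, "complete_frac": (i + 1) / 4, "x": i * 10} for i in range(4)]
shuffle_keep_monotonic(items, monotonic_keys=("seq_idx", "complete_frac"))
assert [d["seq_idx"] for d in items] == [0, 1, 2, 3]  # still monotonic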
@@ -102,6 +102,7 @@ class ReturnnDatasetIterDataPipe(torch.utils.data.IterDataPipe):

  try:
  data_keys = self._dataset.get_data_keys()
+ last_complete_frac = -1

  seq_index = 0
  while self._dataset.is_less_than_num_seqs(seq_index):
@@ -109,11 +110,24 @@ class ReturnnDatasetIterDataPipe(torch.utils.data.IterDataPipe):
  data = {data_key: self._dataset.get_data(seq_index, data_key) for data_key in data_keys}
  data["seq_tag"] = str_to_numpy_array(self._dataset.get_tag(seq_index))
  data["seq_idx"] = numpy.array(seq_index)
- # It's slightly redundant to have num_seqs in each entry,
+
+ # It's slightly redundant to have the following data in each entry,
  # but it's difficult to pass this back to the main proc otherwise.
- data["num_seqs"] = num_seqs
- # epoch is also redundant, but that's the cleanest/simplest way to pass it on to BatchingIterDataPipe.
  data["epoch"] = epoch
+ data["num_seqs"] = num_seqs
+
+ complete_frac = self._dataset.get_complete_frac(seq_index, allow_only_lr_suitable=True)
+ if complete_frac is not None:
+ assert 0.0 <= complete_frac <= 1.0, f"complete_frac must be in [0, 1], but got {complete_frac}"
+ assert complete_frac >= last_complete_frac, (
+ "complete_frac must be monotonically increasing, "
+ f"but got {complete_frac} after {last_complete_frac}"
+ )
+ else:
+ complete_frac = -1
+ data["complete_frac"] = numpy.array(complete_frac, dtype=numpy.float32)
+ last_complete_frac = complete_frac
+
  yield data
  seq_index += 1

returnn/torch/engine.py CHANGED
@@ -399,6 +399,7 @@ class Engine(EngineBase):
  {k: int(util.prod(extern_data_raw[k].shape[:2])) for k in keys_w_seq_len},
  )

+ complete_frac = float(extern_data_raw["complete_frac"])
  num_seqs, last_seq_idx = _get_num_seqs_last_seq_idx(
  report_prefix=report_prefix,
  extern_data_raw=extern_data_raw,
@@ -406,7 +407,13 @@ class Engine(EngineBase):
  prev_num_seqs=num_seqs,
  prev_last_seq_idx=last_seq_idx,
  )
- epoch_continuous = (self.epoch - 1 + (last_seq_idx + 1) / num_seqs) if num_seqs is not None else None
+ epoch_continuous = (
+ self.epoch - 1 + complete_frac
+ if complete_frac >= 0.0
+ else (self.epoch - 1 + (last_seq_idx + 1) / num_seqs)
+ if num_seqs is not None
+ else None
+ )

  # clear the gradients when every gradient accumulation loop starts
  if zero_grad_next_step:
@@ -777,7 +784,9 @@ class Engine(EngineBase):
  # Also note that we are likely using persistent multiprocessing data loader workers,
  # so calling torch.utils.data.graph_settings.apply_random_seed here in the main proc
  # will not have an effect then.
- batches_dataset = torch.utils.data.datapipes.iter.Shuffler(batches_dataset, **online_shuffle_batches)
+ batches_dataset = data_pipeline.ShufflingDataPipe(
+ batches_dataset, monotonic_data_keys=("complete_frac", "seq_idx"), **online_shuffle_batches
+ )

  loader_opts = self.config.typed_value("torch_dataloader_opts") or {}
  assert isinstance(loader_opts, dict), f"config torch_dataloader_opts, expected dict, got {type(loader_opts)}"
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: returnn
- Version: 1.20250226.183415
+ Version: 1.20250228.101938
  Summary: The RWTH extensible training framework for universal recurrent neural networks
  Home-page: https://github.com/rwth-i6/returnn/
  Author: Albert Zeyer
@@ -1,9 +1,9 @@
- returnn/PKG-INFO,sha256=tasZ4y9DTXOoBq1n6RhxHj7GEEim3NIV3shYE_6qnzs,5215
+ returnn/PKG-INFO,sha256=I8nJH2i19lJSp03bggFS1YlTbOT-yFLg8yanKsDGZEk,5215
  returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
  returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
  returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
  returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
- returnn/_setup_info_generated.py,sha256=i9lO16SJCJurcbJrmKJjUX_VLD7LAXplYmS6TPYAzTI,77
+ returnn/_setup_info_generated.py,sha256=J7wtu2Asd11qxOS3X2dv_AblIP3xvjshEattiHywgzQ,77
  returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
  returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
  returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -13,23 +13,23 @@ returnn/native_op.py,sha256=yqpE7SqBqXq77FCVnWMloUwadWlslEk-VzdK7FMpt_U,244411
  returnn/pretrain.py,sha256=MHiXJZqkQFmDVyaYsGpd_Acv20wxl7Pr6s6qJzAT2FI,22648
  returnn/datasets/__init__.py,sha256=PvDlfDOaaopIeUIt0OSvHD2eHZkdkyE-sjMXf35EH5U,390
  returnn/datasets/audio.py,sha256=Gmj7a08dnvYh7Z-G1TNapz42L50AIcDE9JeIZaO1s1M,23334
- returnn/datasets/basic.py,sha256=gLssy9J7nfwm1teWHEPoHPynWUWm1MBCpjqVUPZyZPA,70519
+ returnn/datasets/basic.py,sha256=EhgyOv9bGHY08rCTQpt1HN_vW3djP5RwJuxtbp53neM,72300
  returnn/datasets/bundle_file.py,sha256=KQNrS1MSf-4_idlK0c0KFwON-f5sEK0sWU15WpoMYpE,2380
  returnn/datasets/cached.py,sha256=DIRdWrxBmsZG8O_9eVxBO5mcdo4f5KU-Xb-4wVz59Io,25418
- returnn/datasets/cached2.py,sha256=STojLL2Ivvd0xMfZRlYgzsHKlikYKL-caZCIDCgc_9g,11773
- returnn/datasets/distrib_files.py,sha256=kyqIQILDPAO2TXr39hjslmDxIAc3pkY1UOoj8nuiFXo,27534
+ returnn/datasets/cached2.py,sha256=_6pza3IG68JexaExhj1ld3fP6pE7T-G804driJ9Z_qo,12141
+ returnn/datasets/distrib_files.py,sha256=_UlcrnaU1rA9v6D3H3X4dPhcA--09fNeVnWs9VNo0yg,27656
  returnn/datasets/generating.py,sha256=e2-SXcax7xQ4fkVW_Q5MgOLP6KlB7EQXJi_v64gVAWI,99805
  returnn/datasets/hdf.py,sha256=shif0aQqWWNJ0b6YnycpPjIVNsxjLrA41Y66-_SluGI,66993
  returnn/datasets/lm.py,sha256=h0IHUbze87njKrcD5eT1FRxde7elIio05n-BWiqmjFE,98805
  returnn/datasets/map.py,sha256=kOBJVZmwDhLsOplzDNByIfa0NRSUaMo2Lsy36lBvxrM,10907
- returnn/datasets/meta.py,sha256=wHquywF1C7-YWhcSFSAdDNc0nEHRjE-ks7YIEuDFMIE,94731
- returnn/datasets/multi_proc.py,sha256=7kppiXGiel824HM3GvHegluIxtiNAHafm-e6qh6W7YU,21948
+ returnn/datasets/meta.py,sha256=0wQzRzjShLSYNFoGo_MdR5IT8arxHr9gFjUlEqb2rbY,94969
+ returnn/datasets/multi_proc.py,sha256=aVjsLt2qjHnHOrEYCgIPCwNYE-f1fiGP6eZ8NGAr3A4,22583
  returnn/datasets/normalization_data.py,sha256=wOHrbO3612uWXpzLHHxksDw0qeVmQ42w7byBL9QMh9Q,14618
- returnn/datasets/numpy_dump.py,sha256=c2Xgn8cfWxvRNCBMraMCRuHsbmjVQ05sISlaYWIRlKg,5150
- returnn/datasets/postprocessing.py,sha256=G9QiMP3Qr0RmA1PL6fCXOUfa2e_iPzZq_Nfx_u7SNiI,19980
- returnn/datasets/raw_wav.py,sha256=UyC4dUARb9QL0KOGhYdt96R2N_61JvFSvcyHMT8vMnw,9136
- returnn/datasets/sprint.py,sha256=_RS3IFlI5sgkLmvPqvSirWCi7-yxys_m-EY232ec8sM,55446
- returnn/datasets/stereo.py,sha256=0Df0Omm4T4r60GEFa6sEvZdgkm6keEw-qcvIO4BoJew,17617
+ returnn/datasets/numpy_dump.py,sha256=wl8bKIKAlff2HPJPtuu5wBg3TLOf16d2wLVB4lLAwTM,5158
+ returnn/datasets/postprocessing.py,sha256=Jkad_KHMesdPFFg9NKi7U3sbPw-RzxfUX_vOgJsI7p0,23075
+ returnn/datasets/raw_wav.py,sha256=M7eTHp4CTtLQf3yPTiJY-mSJYgZNxkGV9IFN9J1dq_4,9144
+ returnn/datasets/sprint.py,sha256=YhhdNbBTuL_HCc3asgK3o6vgq5h5nMPH5nBFvsuwVjA,55464
+ returnn/datasets/stereo.py,sha256=PkowC91bZWihIYuIZgyGgPcNwgq5jBvyxxu1nER-VhM,17633
  returnn/datasets/text_dict.py,sha256=BPE73nh6-vtSLy3SiDf4dpFl9RJorE7oO6l5y2FU3MI,9965
  returnn/datasets/util/__init__.py,sha256=rEKhSD6fyhDiQF-x7dUQMwa29JZu72SDm7mYcCcLghY,52
  returnn/datasets/util/feature_extraction.py,sha256=axtXDb9wcNpOmyhmW3WJUj5xda29TKkKvOcGGvq7ExA,23923
@@ -207,13 +207,13 @@ returnn/tf/util/open_fst.py,sha256=sZRDw4TbxvhGqpGdUJWy1ebvlZm4_RPhygpRw9uLAOQ,1
  returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
  returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
  returnn/torch/distributed.py,sha256=skFyutdVztxgTEk3HHJ8S83qRWbNpkNT8Tj16Ic0_hE,6981
- returnn/torch/engine.py,sha256=8BIpdcrpbJL9HrvCX-hISh-14zW9aSrHGvRWT9s0zOk,77103
+ returnn/torch/engine.py,sha256=sU9A96icaj65uaEkX4i4aUK3IrB2S19_Fb9_sueB_JE,77426
  returnn/torch/updater.py,sha256=GqtBvZpElPVMm0lq84JPl4NVLFFETZAzAbR0rTomSao,28249
  returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
  returnn/torch/data/extern_data.py,sha256=_uT_9_gd5HIh1IoRsrebVG-nufSnb7fgC5jyU05GxJg,7580
- returnn/torch/data/pipeline.py,sha256=cIdSVjQHP9gihdfy4Pk2yu1-w572Qk8L2v26RL503qU,23266
+ returnn/torch/data/pipeline.py,sha256=C0CAG_jk1oZwrPlW9WdRTxV9OvPztbqKjwKHnf3lhok,27886
  returnn/torch/data/queued_data_iter.py,sha256=PoOsGHdHVZjTmcyfq_ZOw--P6hyfTdmAWIRGq_Z_nLM,888
- returnn/torch/data/returnn_dataset_wrapper.py,sha256=1Bw82-Ge_8m_DSDXZNqQ3zGDic2HQlp6jysELL0NVK0,7369
+ returnn/torch/data/returnn_dataset_wrapper.py,sha256=2CaDapzrlqahANuq-nyVAtv5ENHuM8A7okORwYJDisg,8006
  returnn/torch/data/tensor_utils.py,sha256=-Teqi--LLbt6q_5mDRdoHZHmPgSdC83W706ukif_YiU,1284
  returnn/torch/frontend/__init__.py,sha256=AA48HZnC17ASuKA0EWy8loZ-Bib_yUtqF4T1wYvjst4,62
  returnn/torch/frontend/_backend.py,sha256=TqyDWNP4XCvJNNGn8jyxaT8BOEjVE24QCUR3qsTIS3A,101242
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
  returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
  returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
  returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
- returnn-1.20250226.183415.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
- returnn-1.20250226.183415.dist-info/METADATA,sha256=tasZ4y9DTXOoBq1n6RhxHj7GEEim3NIV3shYE_6qnzs,5215
- returnn-1.20250226.183415.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
- returnn-1.20250226.183415.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
- returnn-1.20250226.183415.dist-info/RECORD,,
+ returnn-1.20250228.101938.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+ returnn-1.20250228.101938.dist-info/METADATA,sha256=I8nJH2i19lJSp03bggFS1YlTbOT-yFLg8yanKsDGZEk,5215
+ returnn-1.20250228.101938.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+ returnn-1.20250228.101938.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+ returnn-1.20250228.101938.dist-info/RECORD,,