returnn 1.20250227.110407__py3-none-any.whl → 1.20250228.104237__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of returnn might be problematic. Click here for more details.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250227.110407
3
+ Version: 1.20250228.104237
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,2 +1,2 @@
1
- version = '1.20250227.110407'
2
- long_version = '1.20250227.110407+git.f069d06'
1
+ version = '1.20250228.104237'
2
+ long_version = '1.20250228.104237+git.bd23951'
@@ -337,7 +337,13 @@ class BucketOrderingIterDataPipe(torch.utils.data.IterDataPipe):
337
337
  """
338
338
 
339
339
  def __init__(
340
- self, dataset: torch.utils.data.IterableDataset, *, buckets: Sequence[Tuple[int, int]], length_key: str
340
+ self,
341
+ dataset: torch.utils.data.IterableDataset,
342
+ *,
343
+ buckets: Sequence[Tuple[int, int]],
344
+ length_key: str,
345
+ random_bucket_prob: float = 0.0,
346
+ seed: Optional[int] = None,
341
347
  ):
342
348
  """
343
349
  :param dataset: dataset to apply bucket batching to
@@ -345,9 +351,17 @@ class BucketOrderingIterDataPipe(torch.utils.data.IterDataPipe):
345
351
  Segments longer than the largest size limit configured in the buckets are dropped. To avoid dropping
346
352
  any segments make sure your largest bucket allows segments larger than your longest training segment.
347
353
  :param length_key: data key to take as length measure
354
+ :param random_bucket_prob: Probability of putting a segment not into the best-fitting bucket, but into
355
+ a randomly chosen still-fitting bucket.
356
+ This increases seq length variation within the buckets at the cost of slightly more padding.
357
+ :param seed: random seed
348
358
  """
349
359
  self._dataset = dataset
350
360
  self._length_key = length_key
361
+ assert random_bucket_prob >= 0.0
362
+ self._random_bucket_prob = random_bucket_prob
363
+ self._rng = numpy.random.RandomState()
364
+ self._seed = seed % (2**32) if seed is not None else None
351
365
 
352
366
  assert buckets, "empty bucket batching configuration"
353
367
  if not all(size > 0 and max_seqs > 0 for size, max_seqs in buckets):
@@ -367,6 +381,12 @@ class BucketOrderingIterDataPipe(torch.utils.data.IterDataPipe):
367
381
  if bucket_idx >= len(self._max_seq_lens):
368
382
  # seg is too long, drop it
369
383
  continue
384
+ if (
385
+ self._random_bucket_prob > 0.0
386
+ and bucket_idx < len(self._max_seq_lens) - 1
387
+ and self._rng.rand() < self._random_bucket_prob
388
+ ):
389
+ bucket_idx = self._rng.randint(bucket_idx, len(self._max_bucket_sizes))
370
390
  buckets[bucket_idx].append(data_dict)
371
391
  if len(buckets[bucket_idx]) >= self._max_bucket_sizes[bucket_idx]:
372
392
  yield buckets[bucket_idx]
@@ -383,6 +403,21 @@ class BucketOrderingIterDataPipe(torch.utils.data.IterDataPipe):
383
403
  def __getitem__(self, index):
384
404
  raise Exception(f"{self.__class__.__name__}.__getitem__ is not supported")
385
405
 
406
+ def set_seed(self, seed: int) -> BucketOrderingIterDataPipe:
407
+ """
408
+ Sets the seed for the next invocation of ``__iter__``, for compatibility with
409
+ ``torch.utils.data.graph_settings.apply_random_seed``.
410
+ """
411
+ self._seed = seed % (2**32) # seed must be within [0, 2**32) for seeding RandomState
412
+ return self
413
+
414
+ def reset(self):
415
+ """resets the internal state of the data pipe"""
416
+ if self._seed is None:
417
+ self._seed = int(2**31 + torch.empty((), dtype=torch.int32).random_().item())
418
+ self._rng.seed(self._seed)
419
+ self._seed = None
420
+
386
421
 
387
422
  def get_batching_iterable_dataset_from_config(
388
423
  *, dataset: torch.utils.data.IterableDataset, config: Config, train: bool
@@ -497,7 +532,7 @@ class ShufflingDataPipe(torch.utils.data.IterDataPipe):
497
532
  self._buffer_size = buffer_size
498
533
  self._monotonic_data_keys = monotonic_data_keys
499
534
  self._rng = numpy.random.RandomState()
500
- self._seed = seed
535
+ self._seed = seed % (2**32) if seed is not None else None
501
536
 
502
537
  def __iter__(self):
503
538
  # The implementation is very similar to the PostprocessingDataset's combinator LaplaceOrdering.
@@ -547,8 +582,10 @@ class ShufflingDataPipe(torch.utils.data.IterDataPipe):
547
582
 
548
583
  def reset(self):
549
584
  """resets the internal state of the data pipe"""
585
+ self._buffer.clear()
586
+ self._next_buffer.clear()
550
587
  if self._seed is None:
551
- self._seed = int(torch.empty((), dtype=torch.int32).random_().item())
588
+ self._seed = int(2**31 + torch.empty((), dtype=torch.int32).random_().item())
552
589
  self._rng.seed(self._seed)
553
590
  self._seed = None
554
591
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250227.110407
3
+ Version: 1.20250228.104237
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,9 +1,9 @@
1
- returnn/PKG-INFO,sha256=7HDw-iONH2fjzNktp78WrdUg15TWTFaD-8VGJITrmF0,5215
1
+ returnn/PKG-INFO,sha256=WyCpdpBjUpCVa51XYBBPnFFDZgy5pufGhH0PkInknok,5215
2
2
  returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
3
3
  returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
4
4
  returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
5
5
  returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
6
- returnn/_setup_info_generated.py,sha256=OM77n2pKDjDW1JMMarYcFEKDu0Dw8yhE2fdmNGkW-Lk,77
6
+ returnn/_setup_info_generated.py,sha256=AsveviXvDuDQsTAWwDCtvPC3oMGwJWCILWlydtbAeak,77
7
7
  returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
8
8
  returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
9
9
  returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -211,7 +211,7 @@ returnn/torch/engine.py,sha256=sU9A96icaj65uaEkX4i4aUK3IrB2S19_Fb9_sueB_JE,77426
211
211
  returnn/torch/updater.py,sha256=GqtBvZpElPVMm0lq84JPl4NVLFFETZAzAbR0rTomSao,28249
212
212
  returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
213
213
  returnn/torch/data/extern_data.py,sha256=_uT_9_gd5HIh1IoRsrebVG-nufSnb7fgC5jyU05GxJg,7580
214
- returnn/torch/data/pipeline.py,sha256=mwbvocYe8dOxRpbGOyH6S1QO_egwfn9FGFG-NGGyOsA,27823
214
+ returnn/torch/data/pipeline.py,sha256=mA6R1QU9vvRmfaUBvdqI9jQeIB3O-01ODcpmXs1SZ-w,29458
215
215
  returnn/torch/data/queued_data_iter.py,sha256=PoOsGHdHVZjTmcyfq_ZOw--P6hyfTdmAWIRGq_Z_nLM,888
216
216
  returnn/torch/data/returnn_dataset_wrapper.py,sha256=2CaDapzrlqahANuq-nyVAtv5ENHuM8A7okORwYJDisg,8006
217
217
  returnn/torch/data/tensor_utils.py,sha256=-Teqi--LLbt6q_5mDRdoHZHmPgSdC83W706ukif_YiU,1284
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
253
253
  returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
254
254
  returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
255
255
  returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
256
- returnn-1.20250227.110407.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
- returnn-1.20250227.110407.dist-info/METADATA,sha256=7HDw-iONH2fjzNktp78WrdUg15TWTFaD-8VGJITrmF0,5215
258
- returnn-1.20250227.110407.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
259
- returnn-1.20250227.110407.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
- returnn-1.20250227.110407.dist-info/RECORD,,
256
+ returnn-1.20250228.104237.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
+ returnn-1.20250228.104237.dist-info/METADATA,sha256=WyCpdpBjUpCVa51XYBBPnFFDZgy5pufGhH0PkInknok,5215
258
+ returnn-1.20250228.104237.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
259
+ returnn-1.20250228.104237.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
+ returnn-1.20250228.104237.dist-info/RECORD,,