returnn 1.20251006.114241__py3-none-any.whl → 1.20251007.115327__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of returnn might be problematic. Click here for more details.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20251006.114241
3
+ Version: 1.20251007.115327
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
returnn/_setup_info_generated.py CHANGED
@@ -1,2 +1,2 @@
1
- version = '1.20251006.114241'
2
- long_version = '1.20251006.114241+git.7745ba7'
1
+ version = '1.20251007.115327'
2
+ long_version = '1.20251007.115327+git.70a1d5d'
returnn/datasets/basic.py CHANGED
@@ -19,6 +19,7 @@ import os
19
19
  import math
20
20
  import numpy
21
21
  import functools
22
+ import types
22
23
  from typing import TYPE_CHECKING, Optional, Any, Set, Tuple, Union, Type, Dict, Sequence, List, Callable
23
24
 
24
25
  from returnn.log import log
@@ -154,7 +155,7 @@ class Dataset:
154
155
  self.seq_tags_filter = set(self._load_seq_list_file(seq_list_filter_file)) if seq_list_filter_file else None
155
156
  self.unique_seq_tags = unique_seq_tags
156
157
  self._seq_order_seq_lens_file = seq_order_seq_lens_file
157
- self._seq_order_seq_lens_by_idx = None
158
+ self._seq_order_seq_lens_by_idx: Optional[Sequence[Union[int, float]]] = None
158
159
  # There is probably no use case for combining the two, so avoid potential misconfiguration.
159
160
  assert self.partition_epoch == 1 or self.repeat_epoch == 1, (
160
161
  "Combining partition_epoch and repeat_epoch is prohibited."
@@ -486,12 +487,8 @@ class Dataset:
486
487
  """
487
488
  raise NotImplementedError
488
489
 
489
- def _get_seq_order_seq_lens_by_idx(self, seq_idx):
490
- """
491
- :param int seq_idx:
492
- :rtype: int
493
- """
494
- if not self._seq_order_seq_lens_by_idx:
490
+ def _get_seq_order_seq_lens_by_idx(self, seq_idx: int) -> Union[int, float]:
491
+ if self._seq_order_seq_lens_by_idx is None:
495
492
  assert self._seq_order_seq_lens_file
496
493
  if self._seq_order_seq_lens_file.endswith(".gz"):
497
494
  import gzip
@@ -502,11 +499,12 @@ class Dataset:
502
499
  seq_lens = eval(raw)
503
500
  assert isinstance(seq_lens, dict)
504
501
  all_tags = self.get_all_tags()
505
- self._seq_order_seq_lens_by_idx = [seq_lens[tag] for tag in all_tags]
502
+ self._seq_order_seq_lens_by_idx = numpy.array([seq_lens[tag] for tag in all_tags])
503
+ self._get_seq_order_seq_lens_by_idx = self._seq_order_seq_lens_by_idx.__getitem__ # faster
506
504
  return self._seq_order_seq_lens_by_idx[seq_idx]
507
505
 
508
506
  def get_seq_order_for_epoch(
509
- self, epoch: Optional[int], num_seqs: int, get_seq_len: Optional[Callable[[int], int]] = None
507
+ self, epoch: Optional[int], num_seqs: int, get_seq_len: Optional[Callable[[int], Union[int, float]]] = None
510
508
  ) -> Sequence[int]:
511
509
  """
512
510
  Returns the order of the given epoch.
@@ -515,7 +513,7 @@ class Dataset:
515
513
 
516
514
  :param epoch: for 'random', this determines the random seed
517
515
  :param num_seqs:
518
- :param get_seq_len: function (originalSeqIdx: int) -> int
516
+ :param get_seq_len: function (originalSeqIdx: int) -> int|float
519
517
  :return: the order for the given epoch. such that seq_idx -> underlying idx
520
518
  """
521
519
  if epoch is None:
@@ -561,8 +559,9 @@ class Dataset:
561
559
  seq_index = range(num_seqs - 1, -1, -1) # type: Union[range, Sequence[int]]
562
560
  elif seq_ordering_method in ["sorted", "sorted_reverse"]:
563
561
  assert get_seq_len
564
- reverse = -1 if seq_ordering_method == "sorted_reverse" else 1
565
- seq_lens = [reverse * get_seq_len(i) for i in range(num_seqs)]
562
+ seq_lens = _get_seq_len_as_array(get_seq_len, num_seqs)
563
+ if seq_ordering_method == "sorted_reverse":
564
+ seq_lens = -seq_lens
566
565
  seq_index = numpy.argsort(seq_lens, kind="stable")
567
566
  elif seq_ordering_method == "random" or seq_ordering_method.startswith("random:"):
568
567
  tmp = seq_ordering_method.split(":", 1)
@@ -628,7 +627,7 @@ class Dataset:
628
627
  nth = 1
629
628
  else:
630
629
  nth = int(tmp[1])
631
- seq_lens = numpy.array([get_seq_len(i) for i in range(num_seqs)])
630
+ seq_lens = _get_seq_len_as_array(get_seq_len, num_seqs)
632
631
  rnd_seed = self._get_random_seed_for_epoch(epoch=epoch, num_epochs_fixed=nth)
633
632
  random_generator = numpy.random.RandomState(rnd_seed)
634
633
  seq_index = random_generator.permutation(num_seqs) # type: Union[numpy.ndarray, List[int]]
@@ -1757,3 +1756,19 @@ def set_config_extern_data_from_dataset(config, dataset):
1757
1756
  "extern_data",
1758
1757
  {key: _data_kwargs_from_dataset_key(dataset=dataset, key=key) for key in dataset.get_data_keys()},
1759
1758
  )
1759
+
1760
+
1761
+ def _get_seq_len_as_array(get_seq_len: Callable[[int], Union[int, float]], num_seqs: int) -> numpy.ndarray:
1762
+ if num_seqs == 0:
1763
+ return numpy.zeros((0,), dtype=numpy.int32)
1764
+ if isinstance(get_seq_len, (types.BuiltinMethodType, types.MethodWrapperType, types.MethodType)):
1765
+ # Call it once. This might trigger some caching.
1766
+ get_seq_len(0)
1767
+ # Get it again. This might now get us a different (cached) function, e.g. array.__getitem__.
1768
+ get_seq_len = getattr(get_seq_len.__self__, get_seq_len.__name__)
1769
+ assert isinstance(get_seq_len, (types.BuiltinMethodType, types.MethodWrapperType, types.MethodType))
1770
+ obj = get_seq_len.__self__
1771
+ if isinstance(obj, numpy.ndarray) and get_seq_len.__name__ == "__getitem__":
1772
+ assert obj.shape == (num_seqs,)
1773
+ return obj
1774
+ return numpy.array([get_seq_len(i) for i in range(num_seqs)])
returnn/datasets/distrib_files.py CHANGED
@@ -135,7 +135,7 @@ class DistributeFilesDataset(CachedDataset2):
135
135
  def __init__(
136
136
  self,
137
137
  *,
138
- files: Union[List[FileTree], os.PathLike],
138
+ files: Union[List[FileTree], os.PathLike, Callable[[], List[FileTree]]],
139
139
  get_sub_epoch_dataset: Callable[[List[FileTree]], Dict[str, Any]],
140
140
  preload_next_n_sub_epochs: int = 1,
141
141
  buffer_size: int = 1,
@@ -151,6 +151,7 @@ class DistributeFilesDataset(CachedDataset2):
151
151
  can also be specified as a path to a .txt file containing one file per line,
152
152
  or a python file containing the repr of a list of arbitrarily nested python objects,
153
153
  or a JSON file containing a list of arbitarily nested (JSON) objects.
154
+ It can also be a callable which returns such a list.
154
155
  :param get_sub_epoch_dataset: callable which returns a dataset dict for a given subset of files
155
156
  :param preload_next_n_sub_epochs: how many sub epoch datasets to preload
156
157
  :param buffer_size: buffer size for each worker, number of seqs to prefetch
@@ -244,6 +245,11 @@ class DistributeFilesDataset(CachedDataset2):
244
245
  return
245
246
  if isinstance(self.files, list):
246
247
  self._files = self.files
248
+ elif callable(self.files):
249
+ self._files = self.files()
250
+ assert isinstance(self._files, list), (
251
+ f"{self}: callable files {self.files} must return a list, got {type(self._files)}"
252
+ )
247
253
  elif isinstance(self.files, (str, os.PathLike)):
248
254
  _, ext = os.path.splitext(self.files)
249
255
  assert ext, f"{self}: no file extension on file list file {self.files}"
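returnn-1.20251006.114241.dist-info/METADATA → returnn-1.20251007.115327.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@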
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20251006.114241
3
+ Version: 1.20251007.115327
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
returnn-1.20251006.114241.dist-info/RECORD → returnn-1.20251007.115327.dist-info/RECORD CHANGED
@@ -1,9 +1,9 @@
1
- returnn/PKG-INFO,sha256=khjeqHtAYC68kPEh0ltnHLB14E2_UOObt3MGLvgeMTY,5215
1
+ returnn/PKG-INFO,sha256=nPg5FDephK9Q9lpTYYnNiyWNuWX4iqkbzK4OJytg_64,5215
2
2
  returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
3
3
  returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
4
4
  returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
5
5
  returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
6
- returnn/_setup_info_generated.py,sha256=MFwUI_cce-3opoV_eKfpRd891cRwBo6RMbTGriz-HTg,77
6
+ returnn/_setup_info_generated.py,sha256=3Oo-LSdu1OoWfV1GeXf8IFcg75AqFRPpH9erSCeRGPA,77
7
7
  returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
8
8
  returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
9
9
  returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -13,11 +13,11 @@ returnn/native_op.py,sha256=4_NnvfNxsM8GE_FsD6yOg6PZegqIdtJ3Sl1GdBWmFvg,244424
13
13
  returnn/pretrain.py,sha256=MHiXJZqkQFmDVyaYsGpd_Acv20wxl7Pr6s6qJzAT2FI,22648
14
14
  returnn/datasets/__init__.py,sha256=PvDlfDOaaopIeUIt0OSvHD2eHZkdkyE-sjMXf35EH5U,390
15
15
  returnn/datasets/audio.py,sha256=Gmj7a08dnvYh7Z-G1TNapz42L50AIcDE9JeIZaO1s1M,23334
16
- returnn/datasets/basic.py,sha256=_42fQztTZq7jNQrWdFBwulB1bNta17LOTyrD8XJ-7_E,73089
16
+ returnn/datasets/basic.py,sha256=e-hrhJjLq01JT3dYxeeRhw4z1DQC5lxglpw4RscD57U,74136
17
17
  returnn/datasets/bundle_file.py,sha256=KQNrS1MSf-4_idlK0c0KFwON-f5sEK0sWU15WpoMYpE,2380
18
18
  returnn/datasets/cached.py,sha256=RyefRjSDdp-HveK-2vLy2C6BIHcpqQ_lNvUKlIa4QAI,25412
19
19
  returnn/datasets/cached2.py,sha256=oJOq2lWRQpxm6kyUKW1w5qZBd4kdKEpwM7KY_QnXbq4,11922
20
- returnn/datasets/distrib_files.py,sha256=-WNVhtvdJFP3L9Meh33oTSYc0FJSvF40mJ5UI_vJbSE,30233
20
+ returnn/datasets/distrib_files.py,sha256=srTieLP02kCepAwZ6Y9p20cqB8nAlVJWbSAoOPna9ik,30567
21
21
  returnn/datasets/generating.py,sha256=Qb7V94N_GfL2pZPxWS5PmzszoVXXKzuUmsHuW3dmVbc,99556
22
22
  returnn/datasets/hdf.py,sha256=v5sjBenURR9Z-g7AQ9tsL84yDSye5RtbLpym3M6HSDE,67833
23
23
  returnn/datasets/lm.py,sha256=rQ3jV43lSnlGkKu7m5jTTH7aK0BOMXQocsHfJ8OGec8,99950
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
253
253
  returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
254
254
  returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
255
255
  returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
256
- returnn-1.20251006.114241.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
- returnn-1.20251006.114241.dist-info/METADATA,sha256=khjeqHtAYC68kPEh0ltnHLB14E2_UOObt3MGLvgeMTY,5215
258
- returnn-1.20251006.114241.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
- returnn-1.20251006.114241.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
- returnn-1.20251006.114241.dist-info/RECORD,,
256
+ returnn-1.20251007.115327.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
+ returnn-1.20251007.115327.dist-info/METADATA,sha256=nPg5FDephK9Q9lpTYYnNiyWNuWX4iqkbzK4OJytg_64,5215
258
+ returnn-1.20251007.115327.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
+ returnn-1.20251007.115327.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
+ returnn-1.20251007.115327.dist-info/RECORD,,