returnn-1.20251027.117-py3-none-any.whl → returnn-1.20251027.232712-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20251027.117
+Version: 1.20251027.232712
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
returnn/_setup_info_generated.py CHANGED
@@ -1,2 +1,2 @@
-version = '1.20251027.000117'
-long_version = '1.20251027.000117+git.f3e7971'
+version = '1.20251027.232712'
+long_version = '1.20251027.232712+git.d3f28ed'
returnn/config.py CHANGED
@@ -801,7 +801,7 @@ class SubProcCopyGlobalConfigPreInitFunc:
         from returnn.log import log
         from returnn import __old_mod_loader__

-        better_exchook.install()
+        better_exchook.setup_all()
         __old_mod_loader__.disable_lazy_mod_loads()

         if self.global_config:
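For context, a hedged sketch of what this one-line change does. Only the two call names are confirmed by the diff itself; the behavioral difference below is an assumption based on the published better_exchook API:

import better_exchook

# Assumption: install() replaces sys.excepthook for the main thread only,
# while setup_all() additionally hooks uncaught exceptions in other threads,
# which matters for the subprocess pre-init path patched here.
better_exchook.setup_all()  # was: better_exchook.install()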
returnn/datasets/generating.py CHANGED
@@ -1164,11 +1164,9 @@ class StaticDataset(CachedDataset2):
        """supports sorting"""
        return True

-    def _collect_single_seq(self, seq_idx):
-        """
-        :param int seq_idx:
-        :rtype: DatasetSeq
-        """
+    def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]:
+        if seq_idx >= len(self._seq_order):
+            return None
        corpus_seq_idx = self._seq_order[seq_idx]
        data = self.data[corpus_seq_idx]
        return DatasetSeq(
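A minimal standalone sketch of the new guard. The function and argument names here are invented, and the convention that the caller treats None as end-of-data is inferred from the Optional return type in the diff, not verified:

from typing import List, Optional

def collect_single_seq(seq_order: List[int], data: List[dict], seq_idx: int) -> Optional[dict]:
    # Out-of-range seq_idx now yields None instead of raising IndexError
    # on seq_order[seq_idx]; None signals "no more sequences".
    if seq_idx >= len(seq_order):
        return None
    return data[seq_order[seq_idx]]

assert collect_single_seq([1, 0], [{"x": 1}, {"x": 2}], 5) is None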
returnn/datasets/meta.py CHANGED
@@ -964,7 +964,6 @@ class CombinedDataset(CachedDataset2):
         self.dataset_keys = set([m[0] for m in data_map.keys()])  # type: typing.Set[str]
         self.dataset_idx2key_map = dict(enumerate(sorted(self.dataset_keys)))  # idx -> dataset-key
         self.data_keys = set(data_map.values())  # type: typing.Set[str]
-        assert "data" in self.data_keys
         self.target_list = sorted(self.data_keys - {"data"})

         # Build target lookup table that maps from dataset_key and data_key (data key used by CombinedDataset)
@@ -994,8 +993,7 @@ class CombinedDataset(CachedDataset2):
         if data_dims:
             data_dims = convert_data_dims(data_dims)
             self.data_dims = data_dims
-            assert "data" in data_dims
-            for key in self.target_list:
+            for key in self.data_keys:
                 assert key in data_dims
         else:
             self.data_dims = {}
@@ -1009,7 +1007,7 @@ class CombinedDataset(CachedDataset2):
             if dataset_data_key in dataset.labels:
                 self.labels[data_key] = dataset.labels[dataset_data_key]

-        self.num_inputs = self.data_dims["data"][0]
+        self.num_inputs = self.data_dims["data"][0] if "data" in self.data_dims else 0
         self.num_outputs = self.data_dims

         self.data_dtypes = {
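Taken together, the three hunks above make the reserved "data" key optional in CombinedDataset. A hedged sketch with invented dataset and key names:

data_map = {
    ("ds1", "classes"): "classes",  # nothing maps onto "data"
    ("ds2", "classes"): "classes",
}
data_keys = set(data_map.values())
data_dims = {"classes": (1000, 1)}  # invented (dim, ndim) values
# num_inputs now falls back to 0 instead of asserting that "data" is present:
num_inputs = data_dims["data"][0] if "data" in data_dims else 0
assert num_inputs == 0 and "data" not in data_keys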
@@ -1019,6 +1017,9 @@

         self.dataset_seq_idx_boundaries: Optional[List[int]] = None
         self.dataset_sorted_seq_idx_list: Optional[List[Tuple[int, int]]] = None
+        self._sub_dataset_cur_loaded_seq_range: Optional[List[Tuple[int, int]]] = None
+        # The usage is about the seqs already covered in dataset_sorted_seq_idx_list,
+        # in case we dynamically build up this list.
         self.used_num_seqs_per_subset: Optional[List[int]] = None

     def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
@@ -1030,7 +1031,7 @@
         """

         assert seq_list is None and seq_order is None, "seq_list and seq_order not supported for %s" % self.__class__
-        need_reinit = self.epoch is None or self.epoch != epoch
+        need_reinit = self.epoch is None or self.epoch != epoch or self.expected_load_seq_start > 0
         num_seqs_saved = self._num_seqs
         super(CombinedDataset, self).init_seq_order(
             epoch=epoch, seq_list=seq_list, seq_order=seq_order
@@ -1047,13 +1048,15 @@
         for dataset in self.datasets.values():
             dataset.init_seq_order(epoch=epoch)

+        self._sub_dataset_cur_loaded_seq_range = [(0, 0)] * len(self.datasets)
+
         # noinspection PyBroadException
         try:
             total_num_seqs = sum([self.datasets[k].num_seqs for k in sorted(self.datasets.keys())])
         except Exception:
             total_num_seqs = None

-        if total_num_seqs is not None:
+        if total_num_seqs is not None and self.seq_ordering != "interleave":
             self.dataset_seq_idx_boundaries = self._create_dataset_seq_idx_boundaries()

             if self.sampling_sizes:
@@ -1090,7 +1093,7 @@

             # Re-initialize sequence orders of sub-datasets with created sequence list.
             self.used_num_seqs_per_subset = []
-            for dataset_idx, dataset_key in self.dataset_idx2key_map.items():
+            for dataset_idx, dataset_key in sorted(self.dataset_idx2key_map.items()):
                 assert self.datasets[dataset_key].have_corpus_seq_idx()
                 self.datasets[dataset_key].init_seq_order(epoch=epoch, seq_order=seq_order_subdatasets[dataset_idx])
                 self.used_num_seqs_per_subset.append(len(seq_order_subdatasets[dataset_idx]))
@@ -1098,6 +1101,11 @@
         else:
             self.dataset_sorted_seq_idx_list = []  # We will fill this as we go
             self.used_num_seqs_per_subset = [0] * len(self.datasets)
+            self._num_seqs = total_num_seqs
+
+            # These are currently not supported/implemented.
+            # All of these should just be done in the sub-datasets directly.
+            assert self.partition_epoch == 1 and self.repeat_epoch == 1 and self._num_shards == 1

         return True

@@ -1236,13 +1244,30 @@

        return dataset.get_estimated_seq_length(dataset_seq_idx)

-    def _expand_dataset_sec_idxs(self, num_values):
+    def _sub_dataset_make_cur_loaded(self, dataset_idx: int) -> bool:
+        # Cur meaning for the next sequence to be added to dataset_sorted_seq_idx_list.
+        seq_idx = self.used_num_seqs_per_subset[dataset_idx]
+        cur_start, cur_end = self._sub_dataset_cur_loaded_seq_range[dataset_idx]
+        if seq_idx >= cur_end:
+            self._sub_dataset_load_seqs(dataset_idx, cur_start, seq_idx + 1)
+            return True
+        elif seq_idx < cur_start:
+            return False
+        else:
+            return True
+
+    def _expand_dataset_seq_idxs(self, num_values: int) -> bool:
         """
-        :param int num_values: Add num_values entries to the dataset-segment-idx mapping table
-        :return: something?
-        :rtype: bool
+        Try to extend dataset_sorted_seq_idx_list.
+        We expect that we have reached the end of it.
+
+        :param num_values: Add num_values entries to the dataset-segment-idx mapping table
+        :return: whether we added num_values entries
         """
-        for i in range(num_values):
+        for _ in range(num_values):
+            for j in range(len(self.datasets)):
+                self._sub_dataset_make_cur_loaded(j)
+
             if self.seq_ordering == "default":  # i.e. in order
                 dataset_idx = 0
                 while dataset_idx < len(self.datasets):
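A hedged standalone sketch of the loaded-range bookkeeping that _sub_dataset_make_cur_loaded implements above. The helper names are invented, and the actual dataset.load_seqs call is stubbed out:

from typing import List, Tuple

def make_cur_loaded(loaded_ranges: List[Tuple[int, int]], used: List[int], idx: int) -> bool:
    """Ensure the next seq index of sub-dataset idx is inside its loaded [start, end) range."""
    seq_idx = used[idx]
    start, end = loaded_ranges[idx]
    if seq_idx >= end:
        # Stands in for dataset.load_seqs(start, seq_idx + 1).
        loaded_ranges[idx] = (start, seq_idx + 1)
        return True
    return seq_idx >= start  # False: the seq already fell out of the loaded window

ranges = [(0, 0), (0, 0)]
assert make_cur_loaded(ranges, [0, 0], 0) and ranges[0] == (0, 1)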
@@ -1265,6 +1290,32 @@
                 else:
                     return False  # No dataset has remaining data

+            elif self.seq_ordering == "interleave":
+                complete_fracs_and_ds_idx = [
+                    (
+                        self.datasets[self.dataset_idx2key_map[j]].get_complete_frac(
+                            self.used_num_seqs_per_subset[j] - 1, allow_only_lr_suitable=True
+                        )
+                        if self.used_num_seqs_per_subset[j] > 0
+                        else 0.0,
+                        j,
+                    )
+                    for j in range(len(self.datasets))
+                ]
+                assert all(frac is not None for frac, _ in complete_fracs_and_ds_idx), (
+                    f"{self}: Datasets must provide complete frac for interleave,"
+                    f" got {complete_fracs_and_ds_idx}, dataset idx2key map {self.dataset_idx2key_map}"
+                )
+                # Sort by complete frac, i.e. datasets with the lowest complete frac first.
+                complete_fracs_and_ds_idx.sort()
+                for complete_frac, dataset_idx in complete_fracs_and_ds_idx:
+                    if self.datasets[self.dataset_idx2key_map[dataset_idx]].is_less_than_num_seqs(
+                        self.used_num_seqs_per_subset[dataset_idx]
+                    ):
+                        break
+                else:
+                    return False  # No dataset has remaining data
+
             elif self.seq_ordering == "random_dataset":
                 while True:
                     # Build probability table
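The new "interleave" ordering above always advances the sub-dataset that is least far through its epoch, falling back to the next-least-complete one once a dataset runs out. A minimal sketch of that selection rule, with invented inputs:

from typing import List, Optional

def pick_next_dataset(complete_fracs: List[float], has_more: List[bool]) -> Optional[int]:
    # Sort dataset indices by completed fraction, lowest first,
    # and take the first one that still has sequences left.
    for j in sorted(range(len(complete_fracs)), key=lambda j: complete_fracs[j]):
        if has_more[j]:
            return j
    return None  # no dataset has remaining data

assert pick_next_dataset([0.6, 0.1], [True, True]) == 1  # least complete wins
assert pick_next_dataset([0.6, 0.1], [True, False]) == 0  # exhausted ones are skipped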
@@ -1323,19 +1374,23 @@
     def _load_seqs(self, start, end):
         # If the segment order is not yet known, fix the next few segments
         if end > len(self.dataset_sorted_seq_idx_list):
-            self._expand_dataset_sec_idxs(end - len(self.dataset_sorted_seq_idx_list))
+            self._expand_dataset_seq_idxs(end - len(self.dataset_sorted_seq_idx_list))

         requested_seqs = self.dataset_sorted_seq_idx_list[start:end]

         for dataset_idx in range(len(self.datasets)):
-            dataset = self.datasets[self.dataset_idx2key_map[dataset_idx]]
             sub_requested_seqs = [s[1] for s in requested_seqs if s[0] == dataset_idx]
             if not sub_requested_seqs:
                 continue
             sub_start, sub_end = min(sub_requested_seqs), max(sub_requested_seqs)
-            dataset.load_seqs(sub_start, sub_end + 1)
+            self._sub_dataset_load_seqs(dataset_idx, sub_start, sub_end + 1)
         super(CombinedDataset, self)._load_seqs(start=start, end=end)

+    def _sub_dataset_load_seqs(self, dataset_idx: int, start: int, end: int):
+        self._sub_dataset_cur_loaded_seq_range[dataset_idx] = (start, end)
+        dataset = self.datasets[self.dataset_idx2key_map[dataset_idx]]
+        dataset.load_seqs(start, end)
+
     def _get_data(self, dataset_key, dataset_seq_idx, data_key):
         """
         :type dataset_seq_idx: int
@@ -1365,19 +1420,30 @@
         dataset = self.datasets[dataset_key]

         seq_tag = dataset.get_tag(dataset_seq_idx)
-        features = self._get_data(dataset_key, dataset_seq_idx, "data")
-        targets = {target: self._get_data(dataset_key, dataset_seq_idx, target) for target in self.target_list}
-        return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features, targets=targets)
+        features = {key: self._get_data(dataset_key, dataset_seq_idx, key) for key in self.data_keys}
+        complete_frac = None
+        if self.seq_ordering == "interleave":
+            # In the interleave case, by design, this should be monotonically increasing,
+            # as per how we select the next seq in _expand_dataset_seq_idxs.
+            complete_frac = dataset.get_complete_frac(dataset_seq_idx, allow_only_lr_suitable=True)
+        # In other cases, complete_frac is not so straightforward.
+        # In the case that the total num seqs is known, then it's anyway not necessary.
+        return DatasetSeq(seq_idx=seq_idx, complete_frac=complete_frac, seq_tag=seq_tag, features=features)

-    def is_less_than_num_seqs(self, n):
+    def is_less_than_num_seqs(self, n: int) -> bool:
         """
-        :param int n:
-        :rtype: bool
+        :param n:
         """
         if n < len(self.dataset_sorted_seq_idx_list):
             return True
         else:
-            return self._expand_dataset_sec_idxs(n - len(self.dataset_sorted_seq_idx_list) + 1)
+            return self._expand_dataset_seq_idxs(n - len(self.dataset_sorted_seq_idx_list) + 1)
+
+    def get_data_keys(self) -> List[str]:
+        """data keys"""
+        if "data" in self.data_keys:
+            return ["data"] + sorted(self.data_keys - {"data"})
+        return sorted(self.data_keys)

     def get_target_list(self):
         """
returnn-1.20251027.117.dist-info/METADATA → returnn-1.20251027.232712.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20251027.117
+Version: 1.20251027.232712
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
returnn-1.20251027.117.dist-info/RECORD → returnn-1.20251027.232712.dist-info/RECORD CHANGED
@@ -1,10 +1,10 @@
-returnn/PKG-INFO,sha256=5Pb1iE8plEOp8u6YgK8RC_SSyBmyhFba_D-gcXEE8YI,5212
+returnn/PKG-INFO,sha256=XlAffW31FeRzj4iXwdobRyd-HUqyerhGuIjKnXR-eso,5215
 returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
 returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
 returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
 returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
-returnn/_setup_info_generated.py,sha256=KECmOblD-dsBEVI8f_tn-BVnMF4NTy5DuhuYtunMF1M,77
-returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
+returnn/_setup_info_generated.py,sha256=mh5Yk4VnansGboCO60Z0keWwnBbHaMw8ywduxfJ0gLM,77
+returnn/config.py,sha256=JK8EjDsUdyY2c90s0KY1rLD1kesVfz6vRT0gxy_AQ5I,29142
 returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
 returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
 returnn/log.py,sha256=WoTDv4XDovgvgXa7iiav-nA8pb25lOEzndbnVrDLfUo,12319
@@ -18,12 +18,12 @@ returnn/datasets/bundle_file.py,sha256=KQNrS1MSf-4_idlK0c0KFwON-f5sEK0sWU15WpoMY
 returnn/datasets/cached.py,sha256=RyefRjSDdp-HveK-2vLy2C6BIHcpqQ_lNvUKlIa4QAI,25412
 returnn/datasets/cached2.py,sha256=oJOq2lWRQpxm6kyUKW1w5qZBd4kdKEpwM7KY_QnXbq4,11922
 returnn/datasets/distrib_files.py,sha256=48edqdf7YpnPJ-TOis3Mz5U9A2DSxfiYT1HCMSti3zw,32718
-returnn/datasets/generating.py,sha256=Qb7V94N_GfL2pZPxWS5PmzszoVXXKzuUmsHuW3dmVbc,99556
+returnn/datasets/generating.py,sha256=o9-JZ2s5QKssux6GcSaM3oivf_PE6nhSOeytRyGB7pQ,99574
 returnn/datasets/hdf.py,sha256=v5sjBenURR9Z-g7AQ9tsL84yDSye5RtbLpym3M6HSDE,67833
 returnn/datasets/huggingface.py,sha256=ls9WMR6gUcMgGksl80g0An1az5Xjya_V3ojbbbsZqrU,20047
 returnn/datasets/lm.py,sha256=rQ3jV43lSnlGkKu7m5jTTH7aK0BOMXQocsHfJ8OGec8,99950
 returnn/datasets/map.py,sha256=kOBJVZmwDhLsOplzDNByIfa0NRSUaMo2Lsy36lBvxrM,10907
-returnn/datasets/meta.py,sha256=E1ZOlIMk4PiNMd5bUCnxdAU7K2hLYEY4Jn6GqbFjjMw,95850
+returnn/datasets/meta.py,sha256=VJ5bk8esq2-b9likNSrCsHQKiLC3Vvti5oBAxg-AsIk,99422
 returnn/datasets/multi_proc.py,sha256=BClXq0fActi1XQa4vcMhHmhYF0Q-fnnDzlIlbBM6_DM,22614
 returnn/datasets/normalization_data.py,sha256=J3njQCMvWAbIAVPepO2L_Xdau9eWYB7Zyd6STeGzTbc,14615
 returnn/datasets/numpy_dump.py,sha256=wl8bKIKAlff2HPJPtuu5wBg3TLOf16d2wLVB4lLAwTM,5158
@@ -255,8 +255,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
 returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
 returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
 returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
-returnn-1.20251027.117.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
-returnn-1.20251027.117.dist-info/METADATA,sha256=5Pb1iE8plEOp8u6YgK8RC_SSyBmyhFba_D-gcXEE8YI,5212
-returnn-1.20251027.117.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
-returnn-1.20251027.117.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
-returnn-1.20251027.117.dist-info/RECORD,,
+returnn-1.20251027.232712.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+returnn-1.20251027.232712.dist-info/METADATA,sha256=XlAffW31FeRzj4iXwdobRyd-HUqyerhGuIjKnXR-eso,5215
+returnn-1.20251027.232712.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+returnn-1.20251027.232712.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+returnn-1.20251027.232712.dist-info/RECORD,,