returnn 1.20250901.123052__py3-none-any.whl → 1.20260105.192646__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. returnn/PKG-INFO +2 -2
  2. returnn/_setup_info_generated.py +2 -2
  3. returnn/config.py +1 -1
  4. returnn/datasets/basic.py +29 -13
  5. returnn/datasets/distrib_files.py +61 -3
  6. returnn/datasets/generating.py +12 -21
  7. returnn/datasets/huggingface.py +434 -0
  8. returnn/datasets/lm.py +20 -0
  9. returnn/datasets/meta.py +179 -60
  10. returnn/datasets/multi_proc.py +1 -1
  11. returnn/datasets/postprocessing.py +597 -108
  12. returnn/datasets/text_dict.py +1 -1
  13. returnn/datasets/util/vocabulary.py +90 -0
  14. returnn/frontend/_backend.py +7 -0
  15. returnn/frontend/array_.py +54 -1
  16. returnn/frontend/attention.py +54 -20
  17. returnn/frontend/conv.py +273 -54
  18. returnn/frontend/decoder/transformer.py +36 -17
  19. returnn/frontend/encoder/conformer.py +1 -0
  20. returnn/frontend/encoder/transformer.py +2 -0
  21. returnn/frontend/loss.py +40 -1
  22. returnn/frontend/module.py +8 -1
  23. returnn/frontend/nested.py +9 -0
  24. returnn/native_op.cpp +80 -0
  25. returnn/sprint/cache.py +12 -13
  26. returnn/tensor/_dim_extra.py +51 -29
  27. returnn/tensor/_tensor_extra.py +6 -1
  28. returnn/tensor/utils.py +7 -4
  29. returnn/tf/frontend_layers/_backend.py +11 -2
  30. returnn/tf/frontend_low_level/_backend.py +15 -0
  31. returnn/tf/layers/basic.py +16 -38
  32. returnn/tf/native_op.py +11 -58
  33. returnn/tf/network.py +1 -1
  34. returnn/tf/util/basic.py +19 -0
  35. returnn/torch/data/returnn_dataset_wrapper.py +9 -3
  36. returnn/torch/engine.py +67 -2
  37. returnn/torch/frontend/_backend.py +119 -7
  38. returnn/torch/util/diagnose_gpu.py +65 -31
  39. returnn/torch/util/exception_helper.py +7 -1
  40. returnn/util/basic.py +6 -7
  41. returnn/util/better_exchook.py +4 -0
  42. returnn/util/collect_outputs_dict.py +79 -0
  43. returnn/util/debug.py +11 -2
  44. returnn/util/file_cache.py +42 -4
  45. returnn/util/task_system.py +1 -1
  46. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/METADATA +2 -2
  47. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/RECORD +50 -48
  48. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/LICENSE +0 -0
  49. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/WHEEL +0 -0
  50. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/top_level.txt +0 -0
returnn/datasets/meta.py CHANGED
@@ -253,22 +253,12 @@ class MetaDataset(CachedDataset2):
         }
 
         self._seq_list_file = seq_list_file
-        self.seq_list_original = self._load_seq_list(seq_list_file)
-        self.num_total_seqs = len(self.seq_list_original[self.default_dataset_key])
-        for key in self.dataset_keys:
-            assert len(self.seq_list_original[key]) == self.num_total_seqs
-
-        self.tag_idx = {tag: idx for (idx, tag) in enumerate(self.seq_list_original[self.default_dataset_key])}
+        self.seq_list_original: Optional[Dict[str, List[str]]] = None
+        self.tag_idx: Optional[Dict[str, int]] = None
 
         self._seq_lens: Optional[Dict[str, NumbersDict]] = None
         self._num_timesteps: Optional[NumbersDict] = None
         self._seq_lens_file = seq_lens_file
-        if seq_lens_file:
-            seq_lens = load_json(filename=seq_lens_file)
-            assert isinstance(seq_lens, dict)
-            # dict[str,NumbersDict], seq-tag -> data-key -> len
-            self._seq_lens = {tag: NumbersDict(l) for (tag, l) in seq_lens.items()}
-            self._num_timesteps = sum([self._seq_lens[s] for s in self.seq_list_original[self.default_dataset_key]])
 
         if data_dims:
             data_dims = convert_data_dims(data_dims)
@@ -290,19 +280,20 @@ class MetaDataset(CachedDataset2):
         self.num_outputs = self.data_dims
 
         self.orig_seq_order_is_initialized = False
+        self._current_seq_order: List[int] = []
         self.seq_list_ordered: Optional[Dict[str, List[str]]] = None
 
-    def _load_seq_list(self, seq_list_file: Optional[Union[str, Dict[str, str]]] = None) -> Dict[str, List[str]]:
-        """
-        :param seq_list_file:
-        :return: dict: dataset key -> seq list
-        """
-        if not seq_list_file:
+    def _lazy_init_seq_list(self):
+        if self.seq_list_original is not None:
+            return
+
+        if not self._seq_list_file:
             # We create a sequence list from all the sequences of the default dataset
             # and hope that it also applies to the
             # other datasets.
             # This can only work if all datasets have the same tag format and the sequences in the other
             # datasets are a subset of those in the default dataset.
+            # (But the order does not matter.)
             default_dataset = self.datasets[self.default_dataset_key]
             assert isinstance(default_dataset, Dataset)
             print(
@@ -349,17 +340,18 @@ class MetaDataset(CachedDataset2):
                         break  # only print one
                 del seq_list_set
                 raise Exception("Dataset %r is missing seqs." % key)
-        elif isinstance(seq_list_file, str):
-            seq_list = Dataset._load_seq_list_file(seq_list_file, expect_list=False)
-        elif isinstance(seq_list_file, dict):
+        elif isinstance(self._seq_list_file, str):
+            seq_list = Dataset._load_seq_list_file(self._seq_list_file, expect_list=False)
+        elif isinstance(self._seq_list_file, dict):
             for key in self.dataset_keys:
-                if key not in seq_list_file:
+                if key not in self._seq_list_file:
                     raise ValueError(f"seq_list_file does not contain all datasets, missing {key}")
-            seq_list = {key: Dataset._load_seq_list_file(seq_list_file[key]) for key in self.dataset_keys}
+            seq_list = {key: Dataset._load_seq_list_file(self._seq_list_file[key]) for key in self.dataset_keys}
         else:
-            raise TypeError(f"unexpected seq_list_file type {type(seq_list_file)}")
+            raise TypeError(f"unexpected seq_list_file type {type(self._seq_list_file)}")
 
         if isinstance(seq_list, list):
+            # Use same seq list for all datasets
             seq_list = {key: seq_list for key in self.dataset_keys}
         elif isinstance(seq_list, dict):
             for key in self.dataset_keys:
@@ -368,10 +360,29 @@ class MetaDataset(CachedDataset2):
         else:
             raise TypeError(f"unexpected seq_list type {type(seq_list)}")
 
-        return seq_list
+        for key in self.dataset_keys:
+            assert len(seq_list[key]) == len(seq_list[self.default_dataset_key])
+
+        self.seq_list_original = seq_list
+
+    def _lazy_init_tag_idx(self):
+        if self.tag_idx is not None:
+            return
+        self._lazy_init_seq_list()
+        self.tag_idx = {tag: idx for (idx, tag) in enumerate(self.seq_list_original[self.default_dataset_key])}
+
+    def _lazy_init_seq_lens(self):
+        if self._seq_lens is not None:
+            return
+        assert self._seq_lens_file
+        seq_lens = load_json(filename=self._seq_lens_file)
+        assert isinstance(seq_lens, dict)
+        # dict[str,NumbersDict], seq-tag -> data-key -> len
+        self._seq_lens = {tag: NumbersDict(lens) for (tag, lens) in seq_lens.items()}
 
     def _get_dataset_seq_length(self, seq_idx: int):
         if not self.orig_seq_order_is_initialized:
+            self._lazy_init_seq_list()
             # To use get_seq_length() we first have to init the sequence order once in original order.
             # If sequence lengths are not needed by get_seq_order_for_epoch this is never executed.
             self.datasets[self.default_dataset_key].init_seq_order(
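The hunks above replace the eager loading in __init__ with _lazy_init_seq_list / _lazy_init_tag_idx / _lazy_init_seq_lens, so seq_list_file and seq_lens_file are only read once something actually needs them. As rough orientation, a MetaDataset config using these options might look like the sketch below; the sub-dataset names, paths and data keys are made up here, only the option names come from the code above:

    train = {
        "class": "MetaDataset",
        "datasets": {"audio": {...}, "text": {...}},  # any sub-dataset configs
        # MetaDataset data key -> (sub-dataset key, sub-dataset data key)
        "data_map": {"data": ("audio", "data"), "classes": ("text", "data")},
        # Both files are now only loaded lazily, on first access:
        "seq_list_file": "train.seq_tags.txt",  # one seq tag per line, or a dict: dataset key -> file
        "seq_lens_file": "train.seq_lens.json",  # JSON dict: seq tag -> {data key: length}
    }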
@@ -379,6 +390,9 @@ class MetaDataset(CachedDataset2):
             )
             self.orig_seq_order_is_initialized = True
 
+        # Warning: This is not correct in the general case.
+        # get_seq_length needs to have load_seqs called beforehand per API contract.
+        # For some datasets, it might anyway work.
         return self.datasets[self.default_dataset_key].get_seq_length(seq_idx)["data"]
 
     def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
@@ -392,6 +406,7 @@ class MetaDataset(CachedDataset2):
             self.epoch is None
             or self.epoch != epoch
             or self.seq_list_ordered is None
+            or not self._current_seq_order
             or seq_list is not None
             or seq_order is not None
             or self.expected_load_seq_start > 0
@@ -401,16 +416,17 @@ class MetaDataset(CachedDataset2):
             # This is called via initialize() with epoch=None, just to init some other things.
             # We are not expected to have prepared any real epoch here.
             self._num_seqs = 0
+            self._current_seq_order = []
             return True
 
         if not need_reinit:
-            self._num_seqs = len(self.seq_list_ordered[self.default_dataset_key])
             return False
 
         seq_order_dataset = None
         if seq_order is not None:
             seq_index = seq_order
         elif seq_list is not None:
+            self._lazy_init_tag_idx()
             seq_index = [self.tag_idx[tag] for tag in seq_list]
         elif self.seq_order_control_dataset:
             seq_order_dataset = self.datasets[self.seq_order_control_dataset]
@@ -418,13 +434,15 @@ class MetaDataset(CachedDataset2):
             seq_order_dataset.init_seq_order(epoch=epoch)
             seq_index = seq_order_dataset.get_current_seq_order()
         else:
-            if self._seq_lens:
+            if self._seq_lens_file:
 
                 def get_seq_len(s):
                     """
                     :param int s:
                     :rtype: int
                     """
+                    self._lazy_init_seq_list()
+                    self._lazy_init_seq_lens()
                     return self._seq_lens[self.seq_list_original[self.default_dataset_key][s]]["data"]
 
             elif self._seq_order_seq_lens_file:
@@ -432,8 +450,10 @@ class MetaDataset(CachedDataset2):
             else:
                 self.orig_seq_order_is_initialized = False
                 get_seq_len = self._get_dataset_seq_length
-            seq_index = self.get_seq_order_for_epoch(epoch, self.num_total_seqs, get_seq_len)
+            seq_index = self.get_seq_order_for_epoch(epoch, self.get_total_num_seqs(), get_seq_len)
         self._num_seqs = len(seq_index)
+        self._current_seq_order = seq_index
+        self._lazy_init_seq_list()
         self.seq_list_ordered = {key: [ls[s] for s in seq_index] for (key, ls) in self.seq_list_original.items()}
 
         for dataset_key, dataset in self.datasets.items():
@@ -447,7 +467,7 @@ class MetaDataset(CachedDataset2):
         """supports sorting"""
         if self.seq_order_control_dataset:
             return self.datasets[self.seq_order_control_dataset].supports_seq_order_sorting()
-        if self._seq_lens or self._seq_order_seq_lens_file:
+        if self._seq_lens_file or self._seq_order_seq_lens_file:
             return True
         return False
 
@@ -464,20 +484,40 @@ class MetaDataset(CachedDataset2):
         :return: current seq order for the current epoch, after self.init_seq_order was called.
         :rtype: list[int]
         """
-        return [self.tag_idx[tag] for tag in self.seq_list_ordered[self.default_dataset_key]]
+        return self._current_seq_order
 
     def get_all_tags(self):
         """
         :return: list of all seq tags, of the whole dataset, without partition epoch
         :rtype: list[str]
         """
+        if self._seq_list_file is None:
+            return self.datasets[self.default_dataset_key].get_all_tags()
+        self._lazy_init_seq_list()
+        assert self.seq_list_original is not None
         return self.seq_list_original[self.default_dataset_key]
 
     def get_total_num_seqs(self, *, fast: bool = False) -> int:
         """
+        :param fast: if True, might raise an exception if not possible to get fast.
         :return: total number of seqs, without partition epoch
         """
-        return self.num_total_seqs
+        if self._seq_list_file is None:
+            return self.datasets[self.default_dataset_key].get_total_num_seqs(fast=fast)
+        if fast and self.seq_list_original is None:
+            raise OptionalNotImplementedError(f"{self} get_total_num_seqs, seq list not loaded yet")
+        self._lazy_init_seq_list()
+        assert self.seq_list_original is not None
+        return len(self.seq_list_original[self.default_dataset_key])
+
+    def get_num_timesteps(self):
+        """num timesteps"""
+        if self._num_timesteps is None and self._seq_lens_file:
+            self._lazy_init_seq_lens()
+            self._num_timesteps = sum([self._seq_lens[s] for s in self.get_all_tags()], start=NumbersDict())
+        if self._seq_list_file is None:
+            return self.datasets[self.default_dataset_key].get_num_timesteps()
+        return super().get_num_timesteps()
 
     def finish_epoch(self, *, free_resources: bool = False):
         """
@@ -503,8 +543,9 @@ class MetaDataset(CachedDataset2):
         if start_ < end:
             for dataset_key in self.dataset_keys:
                 self.datasets[dataset_key].load_seqs(start_, end)
-                for seq_idx in range(start_, end):
-                    self._check_dataset_seq(dataset_key, seq_idx)
+                if self.seq_list_ordered is not None:
+                    for seq_idx in range(start_, end):
+                        self._check_dataset_seq(dataset_key, seq_idx)
         super(MetaDataset, self)._load_seqs(start=start, end=end)
 
     def _check_dataset_seq(self, dataset_key, seq_idx):
@@ -531,7 +572,7 @@ class MetaDataset(CachedDataset2):
         :type seq_idx: int
         :rtype: DatasetSeq
         """
-        seq_tag = self.seq_list_ordered[self.default_dataset_key][seq_idx]
+        seq_tag = self.get_tag(seq_idx)
         features = {data_key: self._get_data(seq_idx, data_key) for data_key in self.data_keys}
         return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features)
 
@@ -540,8 +581,9 @@ class MetaDataset(CachedDataset2):
         :param int sorted_seq_idx:
         :rtype: NumbersDict
         """
-        if self._seq_lens:
-            return self._seq_lens[self.seq_list_ordered[self.default_dataset_key][sorted_seq_idx]]
+        if self._seq_lens_file:
+            self._lazy_init_seq_lens()
+            return self._seq_lens[self.get_tag(sorted_seq_idx)]
         return super(MetaDataset, self).get_seq_length(sorted_seq_idx)
 
     def get_tag(self, sorted_seq_idx):
@@ -549,7 +591,10 @@ class MetaDataset(CachedDataset2):
         :param int sorted_seq_idx:
         :rtype: str
         """
-        return self.seq_list_ordered[self.default_dataset_key][sorted_seq_idx]
+        if self.seq_list_ordered is not None:
+            return self.seq_list_ordered[self.default_dataset_key][sorted_seq_idx]
+        else:
+            return self.datasets[self.default_dataset_key].get_tag(sorted_seq_idx)
 
     def get_complete_frac(self, sorted_seq_idx: int, **kwargs) -> Optional[float]:
         """
@@ -961,10 +1006,10 @@ class CombinedDataset(CachedDataset2):
         super(CombinedDataset, self).__init__(**kwargs)
         assert self.shuffle_frames_of_nseqs == 0  # not implemented. anyway only for non-recurrent nets
 
+        self.data_map = data_map
         self.dataset_keys = set([m[0] for m in data_map.keys()])  # type: typing.Set[str]
         self.dataset_idx2key_map = dict(enumerate(sorted(self.dataset_keys)))  # idx -> dataset-key
         self.data_keys = set(data_map.values())  # type: typing.Set[str]
-        assert "data" in self.data_keys
         self.target_list = sorted(self.data_keys - {"data"})
 
         # Build target lookup table that maps from dataset_key and data_key (data key used by CombinedDataset)
@@ -994,8 +1039,7 @@ class CombinedDataset(CachedDataset2):
         if data_dims:
             data_dims = convert_data_dims(data_dims)
             self.data_dims = data_dims
-            assert "data" in data_dims
-            for key in self.target_list:
+            for key in self.data_keys:
                 assert key in data_dims
         else:
             self.data_dims = {}
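Both assertions on a "data" entry are dropped here, so a CombinedDataset no longer needs to expose a "data" key at all; together with the num_inputs fallback in the next hunk, a targets-only setup becomes possible. A hedged config sketch (sub-dataset names hypothetical, mapping direction as implied by self.data_keys = set(data_map.values()) above):

    train = {
        "class": "CombinedDataset",
        "datasets": {"lm1": {...}, "lm2": {...}},
        # (sub-dataset key, sub-dataset data key) -> combined data key; no "data" entry required anymore
        "data_map": {("lm1", "data"): "classes", ("lm2", "data"): "classes"},
    }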
@@ -1009,7 +1053,7 @@ class CombinedDataset(CachedDataset2):
             if dataset_data_key in dataset.labels:
                 self.labels[data_key] = dataset.labels[dataset_data_key]
 
-        self.num_inputs = self.data_dims["data"][0]
+        self.num_inputs = self.data_dims["data"][0] if "data" in self.data_dims else 0
         self.num_outputs = self.data_dims
 
         self.data_dtypes = {
@@ -1019,6 +1063,9 @@ class CombinedDataset(CachedDataset2):
 
         self.dataset_seq_idx_boundaries: Optional[List[int]] = None
         self.dataset_sorted_seq_idx_list: Optional[List[Tuple[int, int]]] = None
+        self._sub_dataset_cur_loaded_seq_range: Optional[List[Tuple[int, int]]] = None
+        # The usage is about the seqs already covered in dataset_sorted_seq_idx_list,
+        # in case we dynamically build up this list.
         self.used_num_seqs_per_subset: Optional[List[int]] = None
 
     def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
@@ -1030,7 +1077,7 @@ class CombinedDataset(CachedDataset2):
         """
 
         assert seq_list is None and seq_order is None, "seq_list and seq_order not supported for %s" % self.__class__
-        need_reinit = self.epoch is None or self.epoch != epoch
+        need_reinit = self.epoch is None or self.epoch != epoch or self.expected_load_seq_start > 0
         num_seqs_saved = self._num_seqs
         super(CombinedDataset, self).init_seq_order(
             epoch=epoch, seq_list=seq_list, seq_order=seq_order
@@ -1047,13 +1094,15 @@ class CombinedDataset(CachedDataset2):
         for dataset in self.datasets.values():
             dataset.init_seq_order(epoch=epoch)
 
+        self._sub_dataset_cur_loaded_seq_range = [(0, 0)] * len(self.datasets)
+
         # noinspection PyBroadException
         try:
            total_num_seqs = sum([self.datasets[k].num_seqs for k in sorted(self.datasets.keys())])
         except Exception:
            total_num_seqs = None
 
-        if total_num_seqs is not None:
+        if total_num_seqs is not None and self.seq_ordering != "interleave":
            self.dataset_seq_idx_boundaries = self._create_dataset_seq_idx_boundaries()
 
            if self.sampling_sizes:
@@ -1090,7 +1139,7 @@ class CombinedDataset(CachedDataset2):
 
            # Re-initialize sequence orders of sub-datasets with created sequence list.
            self.used_num_seqs_per_subset = []
-            for dataset_idx, dataset_key in self.dataset_idx2key_map.items():
+            for dataset_idx, dataset_key in sorted(self.dataset_idx2key_map.items()):
                assert self.datasets[dataset_key].have_corpus_seq_idx()
                self.datasets[dataset_key].init_seq_order(epoch=epoch, seq_order=seq_order_subdatasets[dataset_idx])
                self.used_num_seqs_per_subset.append(len(seq_order_subdatasets[dataset_idx]))
@@ -1098,6 +1147,11 @@ class CombinedDataset(CachedDataset2):
         else:
            self.dataset_sorted_seq_idx_list = []  # We will fill this as we go
            self.used_num_seqs_per_subset = [0] * len(self.datasets)
+            self._num_seqs = total_num_seqs
+
+            # These are currently not supported/implemented.
+            # All of these should just be done in the sub-datasets directly.
+            assert self.partition_epoch == 1 and self.repeat_epoch == 1 and self._num_shards == 1
 
         return True
 
@@ -1236,13 +1290,34 @@ class CombinedDataset(CachedDataset2):
 
         return dataset.get_estimated_seq_length(dataset_seq_idx)
 
-    def _expand_dataset_sec_idxs(self, num_values):
+    def _sub_dataset_make_cur_loaded(self, dataset_idx: int) -> bool:
+        # Cur meaning for the next sequence to be added to dataset_sorted_seq_idx_list.
+        seq_idx = self.used_num_seqs_per_subset[dataset_idx]
+        cur_start, cur_end = self._sub_dataset_cur_loaded_seq_range[dataset_idx]
+
+        if not self.datasets[self.dataset_idx2key_map[dataset_idx]].is_less_than_num_seqs(seq_idx):
+            return False
+
+        if seq_idx >= cur_end:
+            self._sub_dataset_load_seqs(dataset_idx, cur_start, seq_idx + 1)
+            return True
+        elif seq_idx < cur_start:
+            return False
+        else:
+            return True
+
+    def _expand_dataset_seq_idxs(self, num_values: int) -> bool:
         """
-        :param int num_values: Add num_values entries to the dataset-segment-idx mapping table
-        :return: something?
-        :rtype: bool
+        Try to extend dataset_sorted_seq_idx_list.
+        We expect that we have reached the end of it.
+
+        :param num_values: Add num_values entries to the dataset-segment-idx mapping table
+        :return: whether we added num_values entries
         """
-        for i in range(num_values):
+        for _ in range(num_values):
+            for j in range(len(self.datasets)):
+                self._sub_dataset_make_cur_loaded(j)
+
            if self.seq_ordering == "default":  # i.e. in order
                dataset_idx = 0
                while dataset_idx < len(self.datasets):
@@ -1265,6 +1340,32 @@ class CombinedDataset(CachedDataset2):
                else:
                    return False  # No dataset has remaining data
 
+            elif self.seq_ordering == "interleave":
+                complete_fracs_and_ds_idx = [
+                    (
+                        self.datasets[self.dataset_idx2key_map[j]].get_complete_frac(
+                            self.used_num_seqs_per_subset[j], allow_only_lr_suitable=True
+                        )
+                        if self.datasets[self.dataset_idx2key_map[j]].is_less_than_num_seqs(
+                            self.used_num_seqs_per_subset[j]
+                        )
+                        else float("inf"),
+                        j,
+                    )
+                    for j in range(len(self.datasets))
+                ]
+                assert all(frac is not None for frac, _ in complete_fracs_and_ds_idx), (
+                    f"{self}: Datasets must provide complete frac for interleave,"
+                    f" got {complete_fracs_and_ds_idx}, dataset idx2key map {self.dataset_idx2key_map}"
+                )
+                # Sort by complete frac, i.e. datasets with the lowest complete frac first.
+                complete_fracs_and_ds_idx.sort()
+                for complete_frac, dataset_idx in complete_fracs_and_ds_idx:
+                    if complete_frac < float("inf"):
+                        break
+                else:
+                    return False  # No dataset has remaining data
+
            elif self.seq_ordering == "random_dataset":
                while True:
                    # Build probability table
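The new "interleave" ordering always continues with the sub-dataset that currently reports the lowest complete_frac, so the combined stream advances through all sub-datasets roughly in proportion, without needing their total num_seqs up front. A hedged sketch of a config that would select this branch (dataset names hypothetical; note the added assertion above, which requires partition_epoch, repeat_epoch and sharding to stay at their defaults and be configured in the sub-datasets instead):

    train = {
        "class": "CombinedDataset",
        "datasets": {"small": {...}, "large": {...}},
        "data_map": {("small", "data"): "data", ("large", "data"): "data"},
        "seq_ordering": "interleave",  # new in this version
    }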
@@ -1323,19 +1424,23 @@ class CombinedDataset(CachedDataset2):
     def _load_seqs(self, start, end):
         # If the segment order is not yet known, fix the next few segments
         if end > len(self.dataset_sorted_seq_idx_list):
-            self._expand_dataset_sec_idxs(end - len(self.dataset_sorted_seq_idx_list))
+            self._expand_dataset_seq_idxs(end - len(self.dataset_sorted_seq_idx_list))
 
         requested_seqs = self.dataset_sorted_seq_idx_list[start:end]
 
         for dataset_idx in range(len(self.datasets)):
-            dataset = self.datasets[self.dataset_idx2key_map[dataset_idx]]
            sub_requested_seqs = [s[1] for s in requested_seqs if s[0] == dataset_idx]
            if not sub_requested_seqs:
                continue
            sub_start, sub_end = min(sub_requested_seqs), max(sub_requested_seqs)
-            dataset.load_seqs(sub_start, sub_end + 1)
+            self._sub_dataset_load_seqs(dataset_idx, sub_start, sub_end + 1)
         super(CombinedDataset, self)._load_seqs(start=start, end=end)
 
+    def _sub_dataset_load_seqs(self, dataset_idx: int, start: int, end: int):
+        self._sub_dataset_cur_loaded_seq_range[dataset_idx] = (start, end)
+        dataset = self.datasets[self.dataset_idx2key_map[dataset_idx]]
+        dataset.load_seqs(start, end)
+
     def _get_data(self, dataset_key, dataset_seq_idx, data_key):
         """
         :type dataset_seq_idx: int
@@ -1348,7 +1453,10 @@ class CombinedDataset(CachedDataset2):
         if dataset_data_key is not None:
            return dataset.get_data(dataset_seq_idx, dataset_data_key)
         else:
-            return numpy.array([], self.data_dtypes[data_key])
+            shape: List[int] = [0] * self.num_outputs[data_key][1]
+            if shape and not self.is_data_sparse(data_key):
+                shape[-1] = self.get_data_dim(data_key)
+            return numpy.zeros(shape, dtype=self.data_dtypes[data_key])
 
     def _collect_single_seq(self, seq_idx):
         """
@@ -1362,19 +1470,30 @@ class CombinedDataset(CachedDataset2):
         dataset = self.datasets[dataset_key]
 
         seq_tag = dataset.get_tag(dataset_seq_idx)
-        features = self._get_data(dataset_key, dataset_seq_idx, "data")
-        targets = {target: self._get_data(dataset_key, dataset_seq_idx, target) for target in self.target_list}
-        return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features, targets=targets)
+        features = {key: self._get_data(dataset_key, dataset_seq_idx, key) for key in self.data_keys}
+        complete_frac = None
+        if self.seq_ordering == "interleave":
+            # In the interleave case, by design, this should be monotonically increasing,
+            # as per how we select the next seq in _expand_dataset_seq_idxs.
+            complete_frac = dataset.get_complete_frac(dataset_seq_idx, allow_only_lr_suitable=True)
+        # In other cases, complete_frac is not so straightforward.
+        # In the case that the total num seqs is known, then it's anyway not necessary.
+        return DatasetSeq(seq_idx=seq_idx, complete_frac=complete_frac, seq_tag=seq_tag, features=features)
 
-    def is_less_than_num_seqs(self, n):
+    def is_less_than_num_seqs(self, n: int) -> bool:
         """
-        :param int n:
-        :rtype: bool
+        :param n:
         """
         if n < len(self.dataset_sorted_seq_idx_list):
            return True
         else:
-            return self._expand_dataset_sec_idxs(n - len(self.dataset_sorted_seq_idx_list) + 1)
+            return self._expand_dataset_seq_idxs(n - len(self.dataset_sorted_seq_idx_list) + 1)
+
+    def get_data_keys(self) -> List[str]:
+        """data keys"""
+        if "data" in self.data_keys:
+            return ["data"] + sorted(self.data_keys - {"data"})
+        return sorted(self.data_keys)
 
     def get_target_list(self):
         """
returnn/datasets/multi_proc.py CHANGED
@@ -42,7 +42,7 @@ class MultiProcDataset(CachedDataset2):
         """
         :param dataset: the dataset to use
         :param num_workers: number of workers to use
-        :param buffer_size: buffer size for each worker, amount of seqs to prefetch
+        :param buffer_size: buffer size for each worker, number of seqs to prefetch
         :param sharding_method: which method to use for sharding the data across the worker procs.
             The default is ``seq_order``, which fetches the full list of seq indices,
             and then distributes shards of that to the other workers.
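For context, buffer_size is the per-worker prefetch depth of MultiProcDataset. A hedged sketch of wrapping another dataset with it (the inner dataset config and the numbers are made up):

    train = {
        "class": "MultiProcDataset",
        "dataset": {...},  # e.g. the MetaDataset or CombinedDataset configs sketched above
        "num_workers": 4,
        "buffer_size": 10,  # number of seqs each worker prefetches
    }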