returnn-1.20251027.224345-py3-none-any.whl → returnn-1.20260109.93428-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +2 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/config.py +1 -1
- returnn/datasets/lm.py +20 -0
- returnn/datasets/meta.py +93 -43
- returnn/datasets/postprocessing.py +597 -108
- returnn/datasets/util/vocabulary.py +90 -0
- returnn/frontend/array_.py +46 -0
- returnn/frontend/attention.py +54 -20
- returnn/frontend/conv.py +273 -54
- returnn/frontend/device.py +14 -1
- returnn/frontend/encoder/conformer.py +20 -0
- returnn/frontend/encoder/transformer.py +2 -0
- returnn/frontend/loss.py +40 -1
- returnn/frontend/math_.py +54 -14
- returnn/native_op.cpp +80 -0
- returnn/sprint/cache.py +12 -13
- returnn/tensor/utils.py +7 -4
- returnn/tf/frontend_layers/_backend.py +4 -3
- returnn/tf/layers/basic.py +15 -39
- returnn/tf/native_op.py +11 -58
- returnn/tf/network.py +1 -1
- returnn/tf/util/basic.py +19 -0
- returnn/torch/engine.py +37 -3
- returnn/torch/frontend/_backend.py +135 -13
- returnn/torch/frontend/bridge.py +61 -0
- returnn/torch/util/exception_helper.py +7 -1
- returnn/util/basic.py +3 -6
- returnn/util/better_exchook.py +4 -0
- returnn/util/debug.py +11 -2
- returnn/util/file_cache.py +15 -1
- returnn/util/task_system.py +1 -1
- {returnn-1.20251027.224345.dist-info → returnn-1.20260109.93428.dist-info}/METADATA +2 -2
- {returnn-1.20251027.224345.dist-info → returnn-1.20260109.93428.dist-info}/RECORD +37 -37
- {returnn-1.20251027.224345.dist-info → returnn-1.20260109.93428.dist-info}/LICENSE +0 -0
- {returnn-1.20251027.224345.dist-info → returnn-1.20260109.93428.dist-info}/WHEEL +0 -0
- {returnn-1.20251027.224345.dist-info → returnn-1.20260109.93428.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20251027.224345
+Version: 1.20260109.93428
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
@@ -36,7 +36,7 @@ Welcome to RETURNN
 `RETURNN paper 2018 <https://arxiv.org/abs/1805.05225>`_.

 RETURNN - RWTH extensible training framework for universal recurrent neural networks,
-is a
+is a PyTorch/TensorFlow-based implementation of modern recurrent neural network architectures.
 It is optimized for fast and reliable training of recurrent neural networks in a multi-GPU environment.

 The high-level features and goals of RETURNN are:
returnn/_setup_info_generated.py
CHANGED
@@ -1,2 +1,2 @@
-version = '1.20251027.224345'
-long_version = '1.
+version = '1.20260109.093428'
+long_version = '1.20260109.093428+git.68426d7'
returnn/config.py
CHANGED
@@ -801,7 +801,7 @@ class SubProcCopyGlobalConfigPreInitFunc:
         from returnn.log import log
         from returnn import __old_mod_loader__

-        better_exchook.
+        better_exchook.setup_all()
        __old_mod_loader__.disable_lazy_mod_loads()

         if self.global_config:
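Note on the returnn/config.py change above: the subprocess pre-init hook now calls better_exchook.setup_all(). A minimal sketch of the intended difference, assuming the better_exchook module bundled with RETURNN as returnn/util/better_exchook.py (install() only replaces sys.excepthook, while setup_all() is the broader setup that also covers other threads):

# Hedged sketch; better_exchook ships with RETURNN as returnn.util.better_exchook.
from returnn.util import better_exchook

# Narrow setup: pretty tracebacks for uncaught exceptions via sys.excepthook only.
better_exchook.install()

# Broader setup, as now used in SubProcCopyGlobalConfigPreInitFunc: also hooks
# exceptions raised in other threads, useful for subprocess workers.
better_exchook.setup_all()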
returnn/datasets/lm.py
CHANGED
@@ -694,6 +694,26 @@ class LmDataset(CachedDataset2):
         self.next_seq_idx = seq_idx + 1
         return DatasetSeq(seq_idx=seq_idx, features=data, targets=targets, seq_tag=seq_tag)

+    def finish_epoch(self, *, free_resources: bool = False):
+        """finish epoch"""
+        super().finish_epoch(free_resources=free_resources)
+
+        if free_resources:
+            self._orths_offsets_and_lens = None
+            if self._orth_mmaps is not None:
+                for m in self._orth_mmaps:
+                    if m is not None:
+                        m.close()
+                self._orth_mmaps = None
+            if self._orth_files is not None:
+                for f in self._orth_files:
+                    if f is not None:
+                        f.close()
+                self._orth_files = None
+
+            self._seq_list = None
+            self._seq_index_by_tag = None
+

 def _is_bliss(filename):
     """
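The new LmDataset.finish_epoch above releases memory-mapped orthography storage when free_resources=True. A self-contained sketch of the same pattern, closing mmap views before their backing files; the helper names are illustrative, not RETURNN API:

# Illustrative helpers (hypothetical names), showing the close order used above:
# mmap views first, then the underlying file objects.
import mmap
from typing import IO, List, Optional, Tuple


def open_readonly_mmaps(paths: List[str]) -> Tuple[List[IO[bytes]], List[mmap.mmap]]:
    files = [open(p, "rb") for p in paths]
    # Length 0 means "map the whole file"; an empty file would raise ValueError.
    mmaps = [mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) for f in files]
    return files, mmaps


def free_mmaps(files: Optional[List[IO[bytes]]], mmaps: Optional[List[mmap.mmap]]) -> None:
    # Close mappings before the files backing them.
    if mmaps is not None:
        for m in mmaps:
            if m is not None:
                m.close()
    if files is not None:
        for f in files:
            if f is not None:
                f.close()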
returnn/datasets/meta.py
CHANGED
@@ -253,22 +253,12 @@ class MetaDataset(CachedDataset2):
         }

         self._seq_list_file = seq_list_file
-        self.seq_list_original =
-        self.
-        for key in self.dataset_keys:
-            assert len(self.seq_list_original[key]) == self.num_total_seqs
-
-        self.tag_idx = {tag: idx for (idx, tag) in enumerate(self.seq_list_original[self.default_dataset_key])}
+        self.seq_list_original: Optional[Dict[str, List[str]]] = None
+        self.tag_idx: Optional[Dict[str, int]] = None

         self._seq_lens: Optional[Dict[str, NumbersDict]] = None
         self._num_timesteps: Optional[NumbersDict] = None
         self._seq_lens_file = seq_lens_file
-        if seq_lens_file:
-            seq_lens = load_json(filename=seq_lens_file)
-            assert isinstance(seq_lens, dict)
-            # dict[str,NumbersDict], seq-tag -> data-key -> len
-            self._seq_lens = {tag: NumbersDict(l) for (tag, l) in seq_lens.items()}
-            self._num_timesteps = sum([self._seq_lens[s] for s in self.seq_list_original[self.default_dataset_key]])

         if data_dims:
             data_dims = convert_data_dims(data_dims)
@@ -290,19 +280,20 @@ class MetaDataset(CachedDataset2):
         self.num_outputs = self.data_dims

         self.orig_seq_order_is_initialized = False
+        self._current_seq_order: List[int] = []
         self.seq_list_ordered: Optional[Dict[str, List[str]]] = None

-    def
-
-
-
-
-        if not seq_list_file:
+    def _lazy_init_seq_list(self):
+        if self.seq_list_original is not None:
+            return
+
+        if not self._seq_list_file:
             # We create a sequence list from all the sequences of the default dataset
             # and hope that it also applies to the
             # other datasets.
             # This can only work if all datasets have the same tag format and the sequences in the other
             # datasets are a subset of those in the default dataset.
+            # (But the order does not matter.)
             default_dataset = self.datasets[self.default_dataset_key]
             assert isinstance(default_dataset, Dataset)
             print(
@@ -349,17 +340,18 @@ class MetaDataset(CachedDataset2):
                     break  # only print one
                 del seq_list_set
                 raise Exception("Dataset %r is missing seqs." % key)
-        elif isinstance(
-            seq_list = Dataset._load_seq_list_file(
-        elif isinstance(
+        elif isinstance(self._seq_list_file, str):
+            seq_list = Dataset._load_seq_list_file(self._seq_list_file, expect_list=False)
+        elif isinstance(self._seq_list_file, dict):
             for key in self.dataset_keys:
-                if key not in
+                if key not in self._seq_list_file:
                     raise ValueError(f"seq_list_file does not contain all datasets, missing {key}")
-            seq_list = {key: Dataset._load_seq_list_file(
+            seq_list = {key: Dataset._load_seq_list_file(self._seq_list_file[key]) for key in self.dataset_keys}
         else:
-            raise TypeError(f"unexpected seq_list_file type {type(
+            raise TypeError(f"unexpected seq_list_file type {type(self._seq_list_file)}")

         if isinstance(seq_list, list):
+            # Use same seq list for all datasets
             seq_list = {key: seq_list for key in self.dataset_keys}
         elif isinstance(seq_list, dict):
             for key in self.dataset_keys:
@@ -368,10 +360,29 @@ class MetaDataset(CachedDataset2):
         else:
             raise TypeError(f"unexpected seq_list type {type(seq_list)}")

-
+        for key in self.dataset_keys:
+            assert len(seq_list[key]) == len(seq_list[self.default_dataset_key])
+
+        self.seq_list_original = seq_list
+
+    def _lazy_init_tag_idx(self):
+        if self.tag_idx is not None:
+            return
+        self._lazy_init_seq_list()
+        self.tag_idx = {tag: idx for (idx, tag) in enumerate(self.seq_list_original[self.default_dataset_key])}
+
+    def _lazy_init_seq_lens(self):
+        if self._seq_lens is not None:
+            return
+        assert self._seq_lens_file
+        seq_lens = load_json(filename=self._seq_lens_file)
+        assert isinstance(seq_lens, dict)
+        # dict[str,NumbersDict], seq-tag -> data-key -> len
+        self._seq_lens = {tag: NumbersDict(lens) for (tag, lens) in seq_lens.items()}

     def _get_dataset_seq_length(self, seq_idx: int):
         if not self.orig_seq_order_is_initialized:
+            self._lazy_init_seq_list()
             # To use get_seq_length() we first have to init the sequence order once in original order.
             # If sequence lengths are not needed by get_seq_order_for_epoch this is never executed.
             self.datasets[self.default_dataset_key].init_seq_order(
@@ -379,6 +390,9 @@ class MetaDataset(CachedDataset2):
             )
             self.orig_seq_order_is_initialized = True

+        # Warning: This is not correct in the general case.
+        # get_seq_length needs to have load_seqs called beforehand per API contract.
+        # For some datasets, it might anyway work.
         return self.datasets[self.default_dataset_key].get_seq_length(seq_idx)["data"]

     def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
@@ -392,6 +406,7 @@ class MetaDataset(CachedDataset2):
             self.epoch is None
             or self.epoch != epoch
             or self.seq_list_ordered is None
+            or not self._current_seq_order
             or seq_list is not None
             or seq_order is not None
             or self.expected_load_seq_start > 0
@@ -401,16 +416,17 @@ class MetaDataset(CachedDataset2):
             # This is called via initialize() with epoch=None, just to init some other things.
             # We are not expected to have prepared any real epoch here.
             self._num_seqs = 0
+            self._current_seq_order = []
             return True

         if not need_reinit:
-            self._num_seqs = len(self.seq_list_ordered[self.default_dataset_key])
             return False

         seq_order_dataset = None
         if seq_order is not None:
             seq_index = seq_order
         elif seq_list is not None:
+            self._lazy_init_tag_idx()
             seq_index = [self.tag_idx[tag] for tag in seq_list]
         elif self.seq_order_control_dataset:
             seq_order_dataset = self.datasets[self.seq_order_control_dataset]
@@ -418,13 +434,15 @@ class MetaDataset(CachedDataset2):
             seq_order_dataset.init_seq_order(epoch=epoch)
             seq_index = seq_order_dataset.get_current_seq_order()
         else:
-            if self.
+            if self._seq_lens_file:

                 def get_seq_len(s):
                     """
                     :param int s:
                     :rtype: int
                     """
+                    self._lazy_init_seq_list()
+                    self._lazy_init_seq_lens()
                     return self._seq_lens[self.seq_list_original[self.default_dataset_key][s]]["data"]

             elif self._seq_order_seq_lens_file:
@@ -432,8 +450,10 @@ class MetaDataset(CachedDataset2):
             else:
                 self.orig_seq_order_is_initialized = False
                 get_seq_len = self._get_dataset_seq_length
-        seq_index = self.get_seq_order_for_epoch(epoch, self.
+        seq_index = self.get_seq_order_for_epoch(epoch, self.get_total_num_seqs(), get_seq_len)
         self._num_seqs = len(seq_index)
+        self._current_seq_order = seq_index
+        self._lazy_init_seq_list()
         self.seq_list_ordered = {key: [ls[s] for s in seq_index] for (key, ls) in self.seq_list_original.items()}

         for dataset_key, dataset in self.datasets.items():
@@ -447,7 +467,7 @@ class MetaDataset(CachedDataset2):
         """supports sorting"""
         if self.seq_order_control_dataset:
             return self.datasets[self.seq_order_control_dataset].supports_seq_order_sorting()
-        if self.
+        if self._seq_lens_file or self._seq_order_seq_lens_file:
             return True
         return False

@@ -464,20 +484,40 @@ class MetaDataset(CachedDataset2):
         :return: current seq order for the current epoch, after self.init_seq_order was called.
         :rtype: list[int]
         """
-        return
+        return self._current_seq_order

     def get_all_tags(self):
         """
         :return: list of all seq tags, of the whole dataset, without partition epoch
         :rtype: list[str]
         """
+        if self._seq_list_file is None:
+            return self.datasets[self.default_dataset_key].get_all_tags()
+        self._lazy_init_seq_list()
+        assert self.seq_list_original is not None
         return self.seq_list_original[self.default_dataset_key]

     def get_total_num_seqs(self, *, fast: bool = False) -> int:
         """
+        :param fast: if True, might raise an exception if not possible to get fast.
         :return: total number of seqs, without partition epoch
         """
-
+        if self._seq_list_file is None:
+            return self.datasets[self.default_dataset_key].get_total_num_seqs(fast=fast)
+        if fast and self.seq_list_original is None:
+            raise OptionalNotImplementedError(f"{self} get_total_num_seqs, seq list not loaded yet")
+        self._lazy_init_seq_list()
+        assert self.seq_list_original is not None
+        return len(self.seq_list_original[self.default_dataset_key])
+
+    def get_num_timesteps(self):
+        """num timesteps"""
+        if self._num_timesteps is None and self._seq_lens_file:
+            self._lazy_init_seq_lens()
+            self._num_timesteps = sum([self._seq_lens[s] for s in self.get_all_tags()], start=NumbersDict())
+        if self._seq_list_file is None:
+            return self.datasets[self.default_dataset_key].get_num_timesteps()
+        return super().get_num_timesteps()

     def finish_epoch(self, *, free_resources: bool = False):
         """
@@ -503,8 +543,9 @@ class MetaDataset(CachedDataset2):
         if start_ < end:
             for dataset_key in self.dataset_keys:
                 self.datasets[dataset_key].load_seqs(start_, end)
-
-
+                if self.seq_list_ordered is not None:
+                    for seq_idx in range(start_, end):
+                        self._check_dataset_seq(dataset_key, seq_idx)
         super(MetaDataset, self)._load_seqs(start=start, end=end)

     def _check_dataset_seq(self, dataset_key, seq_idx):
@@ -531,7 +572,7 @@ class MetaDataset(CachedDataset2):
         :type seq_idx: int
         :rtype: DatasetSeq
         """
-        seq_tag = self.
+        seq_tag = self.get_tag(seq_idx)
         features = {data_key: self._get_data(seq_idx, data_key) for data_key in self.data_keys}
         return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features)

@@ -540,8 +581,9 @@ class MetaDataset(CachedDataset2):
         :param int sorted_seq_idx:
         :rtype: NumbersDict
         """
-        if self.
-
+        if self._seq_lens_file:
+            self._lazy_init_seq_lens()
+            return self._seq_lens[self.get_tag(sorted_seq_idx)]
         return super(MetaDataset, self).get_seq_length(sorted_seq_idx)

     def get_tag(self, sorted_seq_idx):
@@ -549,7 +591,10 @@ class MetaDataset(CachedDataset2):
         :param int sorted_seq_idx:
         :rtype: str
         """
-
+        if self.seq_list_ordered is not None:
+            return self.seq_list_ordered[self.default_dataset_key][sorted_seq_idx]
+        else:
+            return self.datasets[self.default_dataset_key].get_tag(sorted_seq_idx)

     def get_complete_frac(self, sorted_seq_idx: int, **kwargs) -> Optional[float]:
         """
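The MetaDataset hunks above all implement one refactoring: sequence-list state that __init__ previously loaded eagerly (seq_list_original, tag_idx, _seq_lens) is now initialized to None and materialized on first use through _lazy_init_*() guards, so merely constructing the dataset stays cheap. A minimal standalone sketch of the idiom, with hypothetical names rather than the RETURNN classes:

# Hypothetical sketch of the lazy-init guard pattern from the hunks above.
from typing import Dict, List, Optional


class LazySeqList:
    def __init__(self, seq_list_file: str):
        self._seq_list_file = seq_list_file
        self.seq_list: Optional[List[str]] = None  # was loaded eagerly before
        self.tag_idx: Optional[Dict[str, int]] = None

    def _lazy_init_seq_list(self) -> None:
        if self.seq_list is not None:
            return  # guard: repeated calls are cheap
        with open(self._seq_list_file, encoding="utf8") as f:
            self.seq_list = [line.rstrip("\n") for line in f]

    def _lazy_init_tag_idx(self) -> None:
        if self.tag_idx is not None:
            return
        self._lazy_init_seq_list()  # dependent state pulls in its prerequisite first
        self.tag_idx = {tag: idx for (idx, tag) in enumerate(self.seq_list)}

The remaining hunks in this file touch CombinedDataset.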
@@ -961,6 +1006,7 @@ class CombinedDataset(CachedDataset2):
         super(CombinedDataset, self).__init__(**kwargs)
         assert self.shuffle_frames_of_nseqs == 0  # not implemented. anyway only for non-recurrent nets

+        self.data_map = data_map
         self.dataset_keys = set([m[0] for m in data_map.keys()])  # type: typing.Set[str]
         self.dataset_idx2key_map = dict(enumerate(sorted(self.dataset_keys)))  # idx -> dataset-key
         self.data_keys = set(data_map.values())  # type: typing.Set[str]
@@ -1248,6 +1294,10 @@ class CombinedDataset(CachedDataset2):
         # Cur meaning for the next sequence to be added to dataset_sorted_seq_idx_list.
         seq_idx = self.used_num_seqs_per_subset[dataset_idx]
         cur_start, cur_end = self._sub_dataset_cur_loaded_seq_range[dataset_idx]
+
+        if not self.datasets[self.dataset_idx2key_map[dataset_idx]].is_less_than_num_seqs(seq_idx):
+            return False
+
         if seq_idx >= cur_end:
             self._sub_dataset_load_seqs(dataset_idx, cur_start, seq_idx + 1)
         return True
@@ -1294,10 +1344,12 @@ class CombinedDataset(CachedDataset2):
         complete_fracs_and_ds_idx = [
             (
                 self.datasets[self.dataset_idx2key_map[j]].get_complete_frac(
-                    self.used_num_seqs_per_subset[j]
+                    self.used_num_seqs_per_subset[j], allow_only_lr_suitable=True
+                )
+                if self.datasets[self.dataset_idx2key_map[j]].is_less_than_num_seqs(
+                    self.used_num_seqs_per_subset[j]
                 )
-
-                else 0.0,
+                else float("inf"),
                 j,
             )
             for j in range(len(self.datasets))
@@ -1309,9 +1361,7 @@ class CombinedDataset(CachedDataset2):
         # Sort by complete frac, i.e. datasets with the lowest complete frac first.
         complete_fracs_and_ds_idx.sort()
         for complete_frac, dataset_idx in complete_fracs_and_ds_idx:
-            if
-                self.used_num_seqs_per_subset[dataset_idx]
-            ):
+            if complete_frac < float("inf"):
                 break
         else:
             return False  # No dataset has remaining data