returnn 1.20250508.93313-py3-none-any.whl → 1.20250513.145447-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/basic.py +24 -25
- returnn/datasets/cached.py +4 -3
- returnn/datasets/distrib_files.py +1 -2
- returnn/datasets/generating.py +20 -20
- returnn/datasets/hdf.py +9 -9
- returnn/datasets/lm.py +25 -13
- returnn/datasets/meta.py +39 -38
- returnn/datasets/normalization_data.py +1 -1
- returnn/datasets/postprocessing.py +20 -13
- returnn/datasets/sprint.py +8 -7
- returnn/datasets/util/strings.py +0 -1
- returnn/datasets/util/vocabulary.py +3 -3
- returnn/extern/graph_editor/subgraph.py +1 -2
- returnn/extern/graph_editor/transform.py +1 -2
- returnn/extern/graph_editor/util.py +1 -2
- returnn/frontend/_backend.py +4 -3
- returnn/frontend/_utils.py +1 -1
- returnn/frontend/audio/mel.py +0 -1
- returnn/frontend/const.py +3 -3
- returnn/frontend/device.py +0 -1
- returnn/frontend/dropout.py +1 -1
- returnn/frontend/encoder/e_branchformer.py +1 -1
- returnn/frontend/loop.py +3 -3
- returnn/frontend/loss.py +0 -1
- returnn/frontend/matmul.py +0 -1
- returnn/frontend/run_ctx.py +9 -9
- returnn/frontend/signal.py +0 -1
- returnn/frontend/types.py +2 -4
- returnn/native_op.py +13 -0
- returnn/sprint/cache.py +2 -4
- returnn/sprint/interface.py +3 -4
- returnn/tensor/_dim_extra.py +9 -9
- returnn/tensor/_tensor_extra.py +20 -19
- returnn/tensor/_tensor_op_overloads.py +0 -1
- returnn/tensor/tensor.py +1 -1
- returnn/tensor/tensor_dict.py +9 -9
- returnn/tf/engine.py +60 -65
- returnn/tf/frontend_layers/_backend.py +3 -3
- returnn/tf/frontend_layers/cond.py +6 -6
- returnn/tf/frontend_layers/debug_eager_mode.py +0 -1
- returnn/tf/frontend_layers/layer.py +12 -12
- returnn/tf/frontend_layers/loop.py +3 -3
- returnn/tf/frontend_layers/make_layer.py +0 -1
- returnn/tf/layers/base.py +56 -49
- returnn/tf/layers/basic.py +60 -65
- returnn/tf/layers/rec.py +74 -74
- returnn/tf/native_op.py +1 -3
- returnn/tf/network.py +60 -57
- returnn/tf/updater.py +3 -3
- returnn/tf/util/basic.py +24 -23
- returnn/torch/data/extern_data.py +4 -5
- returnn/torch/data/pipeline.py +3 -4
- returnn/torch/engine.py +16 -16
- returnn/torch/frontend/_backend.py +15 -15
- returnn/torch/frontend/bridge.py +3 -3
- returnn/torch/updater.py +8 -9
- returnn/torch/util/debug_inf_nan.py +0 -2
- returnn/torch/util/exception_helper.py +1 -1
- returnn/torch/util/scaled_gradient.py +0 -1
- returnn/util/basic.py +1 -2
- {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/METADATA +1 -1
- {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/RECORD +67 -67
- {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/LICENSE +0 -0
- {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/WHEEL +0 -0
- {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
@@ -1,2 +1,2 @@
-version = '1.…
-long_version = '1.…
+version = '1.20250513.145447'
+long_version = '1.20250513.145447+git.9cdc2a4'
returnn/datasets/basic.py
CHANGED
@@ -20,7 +20,7 @@ import math
 import numpy
 import functools
 import typing
-from typing import TYPE_CHECKING, Optional, Any, Union, Type, Dict, Sequence, List, Callable
+from typing import TYPE_CHECKING, Optional, Any, Set, Tuple, Union, Type, Dict, Sequence, List, Callable

 from returnn.log import log
 from returnn.engine.batch import Batch, BatchSetGenerator

@@ -141,12 +141,10 @@ class Dataset:
         :param int _shard_index: local shard index, when sharding is enabled
         """
         self.name = name or ("dataset_id%s" % id(self))
-        self.lock…
-        self.rnd_seq_drop…
+        self.lock: Optional[RLock] = None  # Used when manipulating our data potentially from multiple threads.
+        self.rnd_seq_drop: Optional[Random] = None
         self.num_inputs = 0  # usually not used, but num_outputs instead, which is more generic
-        self.num_outputs = (
-            None
-        )  # type: typing.Optional[typing.Dict[str,typing.Tuple[int,int]]]  # tuple is num-classes, len(shape).  # nopep8
+        self.num_outputs: Optional[Dict[str, Tuple[int, int]]] = None  # tuple is num-classes, len(shape).
         self.window = window
         self.seq_ordering = seq_ordering  # "default", "sorted" or "random". See self.get_seq_order_for_epoch().
         self.fixed_random_seed = fixed_random_seed

@@ -159,10 +157,10 @@ class Dataset:
         self._seq_order_seq_lens_file = seq_order_seq_lens_file
         self._seq_order_seq_lens_by_idx = None
         # There is probably no use case for combining the two, so avoid potential misconfiguration.
-        assert (
-            …
-        )
-        self.labels…
+        assert self.partition_epoch == 1 or self.repeat_epoch == 1, (
+            "Combining partition_epoch and repeat_epoch is prohibited."
+        )
+        self.labels: Dict[str, List[str]] = {}
         self.weights = {}
         self._num_timesteps = 0
         self._num_seqs = 0

@@ -213,8 +211,8 @@ class Dataset:
             getattr(self, "epoch", "<unknown>"),
         )

-    _getnewargs_exclude_attrs = set()
-    _getnewargs_remap…
+    _getnewargs_exclude_attrs: Set[str] = set()
+    _getnewargs_remap: Dict[str, str] = {}

     @staticmethod
     def _create_from_reduce(cls, kwargs, state) -> Dataset:

@@ -660,12 +658,13 @@ class Dataset:
         )
         old_seq_index = seq_index
         seq_index = [i for i in seq_index if all_seq_tags[i] in self.seq_tags_filter]
-        assert (
-            …
-        )
+        assert seq_index, (
+            "%s: empty after applying seq_list_filter_file. Example filter tags: %r, used tags: %r"
+            % (
+                self,
+                sorted(self.seq_tags_filter)[:3],
+                [all_seq_tags[i] for i in old_seq_index[:3]],
+            )
         )
         return seq_index

@@ -736,9 +735,9 @@ class Dataset:
         """
         self.epoch = epoch
         self.rnd_seq_drop = Random(self._get_random_seed_for_epoch(epoch=epoch))
-        assert (
-            self…
-        )
+        assert self._num_shards == 1 or self.supports_sharding(), (
+            f"{self}: does not support sharding, but got num_shards == {self._num_shards}"
+        )
         return False

     def finish_epoch(self, *, free_resources: bool = False):

@@ -970,16 +969,16 @@ class Dataset:
         except Exception:  # also not always available
             num_seqs = None  # ignore

-        if math.isinf(num_seqs):
+        if num_seqs is not None and math.isinf(num_seqs):
             if allow_only_lr_suitable:
                 # cannot compute meaningful complete_frac for infinite num_seqs
                 return None
             else:
                 num_seqs = None

-        assert (
-            …
-        )
+        assert num_seqs is None or 0 <= sorted_seq_idx < num_seqs, (
+            f"{self}: invalid seq indices: 0 <= seq_idx ({sorted_seq_idx}) < num_seqs ({num_seqs}) violated"
+        )
         return self.generic_complete_frac(sorted_seq_idx, num_seqs)

     @property
returnn/datasets/cached.py
CHANGED
@@ -46,9 +46,10 @@ class CachedDataset(Dataset):
         self._index_map = range(len(self._seq_index))  # sorted seq idx -> seq_index idx
         self._tag_idx = {}  # type: typing.Dict[str,int]  # map of tag -> real-seq-idx. call _update_tag_idx
         self.targets = {}
-        …
-        …
-        …
+        # the keys for which we provide data;
+        # we may have labels for additional keys in self.labels
+        self.target_keys = []
+
         self.timestamps = None

     def initialize(self):
returnn/datasets/distrib_files.py
CHANGED
@@ -451,8 +451,7 @@ class DistributeFilesDataset(CachedDataset2):
         # We need to decide where to add this file, to the current or the next sub epoch.
         if not files_per_bin[bin_idx] or (
             # Better to add this file to the current sub epoch?
-            abs((size_taken + size) - avg_size_per_sub_epoch)
-            <= abs(size_taken - avg_size_per_sub_epoch)
+            abs((size_taken + size) - avg_size_per_sub_epoch) <= abs(size_taken - avg_size_per_sub_epoch)
         ):
             files_per_bin[bin_idx].append(f_tree)
             size_taken = 0
returnn/datasets/generating.py
CHANGED
@@ -46,12 +46,12 @@ class GeneratingDataset(Dataset):
         output_dim["data"] = (input_dim * self.window, 2)  # not sparse
         self.num_outputs = output_dim
         self.expected_load_seq_start = 0
-        self._seq_order…
+        self._seq_order: Optional[Sequence[int]] = None
         self._num_seqs = num_seqs
         self._total_num_seqs = num_seqs
         self.random = numpy.random.RandomState(1)
         self.reached_final_seq = False
-        self.added_data…
+        self.added_data: List[DatasetSeq] = []
         if self.seq_ordering in ("sorted", "sorted_reverse"):
             # For the dev/eval dataset, RETURNN automatically tries to sort them.
             # As this is not supported, just ignore it and reset it to the default order.

@@ -904,22 +904,24 @@ class DummyDatasetMultipleDataKeys(DummyDataset):
         seq_len = {}
         for key in self.data_keys:
             seq_len[key] = _seq_len
-        assert set(data_keys) == set(
-            seq_len…
+        assert set(data_keys) == set(seq_len.keys()), (
+            "%s: the keys of seq_len (%s) must match the keys in data_keys=%s."
+            % (
+                self,
+                str(seq_len.keys()),
+                str(data_keys),
+            )
+        )
+        assert isinstance(output_dim, dict), (
+            "%s: output_dim %r must be a dict containing a definition for each key in data_keys." % (self, output_dim)
         )
-        assert…
-        output_dim…
-            str(output_dim.keys()),
-            str(data_keys),
+        assert set(data_keys) == set(output_dim.keys()), (
+            "%s: the keys of output_dim (%s) must match the keys in data_keys=%s."
+            % (
+                self,
+                str(output_dim.keys()),
+                str(data_keys),
+            )
         )

         super(DummyDatasetMultipleDataKeys, self).__init__(

@@ -2134,9 +2136,7 @@ class LibriSpeechCorpus(CachedDataset2):
         import os
         import zipfile

-        transs = (
-            {}
-        )  # type: typing.Dict[typing.Tuple[str,int,int,int],str]  # (subdir, speaker-id, chapter-id, seq-id) -> transcription  # nopep8
+        transs: Dict[Tuple[str, int, int, int], str] = {}  # (subdir, speaker-id, chapter-id, seq-id) -> transcription
         if self.use_zip:
             for name, zip_file in self._zip_files.items():
                 assert isinstance(zip_file, zipfile.ZipFile)
returnn/datasets/hdf.py
CHANGED
@@ -37,9 +37,9 @@ class HDFDataset(CachedDataset):
         :param bool use_cache_manager: uses :func:`Util.cf` for files
         """
         super(HDFDataset, self).__init__(**kwargs)
-        assert (
-            …
-        )
+        assert self.partition_epoch == 1 or self.cache_byte_size_total_limit == 0, (
+            "To use partition_epoch in HDFDatasets, disable caching by setting cache_byte_size=0"
+        )
         self._use_cache_manager = use_cache_manager
         self.files = []  # type: typing.List[str]  # file names
         self.h5_files = []  # type: typing.List[h5py.File]

@@ -1246,9 +1246,9 @@ class SimpleHDFWriter:
         self._datasets[name].resize(old_shape[0] + raw_data.shape[0], axis=0)
         expected_shape = (raw_data.shape[0],) + old_shape[1:]
         # append raw data to dataset
-        assert (
-            expected_shape…
-        )
+        assert expected_shape == raw_data.shape, (
+            f"{self} insert: shape mismatch: expected {expected_shape}, got {raw_data.shape}"
+        )
         self._datasets[name][self._file.attrs["numTimesteps"] :] = raw_data
         self._file.attrs["numTimesteps"] += raw_data.shape[0]
         self._file.attrs["numSeqs"] += 1

@@ -1302,9 +1302,9 @@ class SimpleHDFWriter:

         offset = self._extra_num_time_steps[data_key] - raw_data.shape[0]
         expected_shape = (raw_data.shape[0],) + hdf_data.shape[1:]
-        assert (
-            expected_shape…
-        )
+        assert expected_shape == raw_data.shape, (
+            f"{self} insert other {data_key!r}: shape mismatch: expected {expected_shape}, got {raw_data.shape}"
+        )
         hdf_data[offset:] = raw_data

     def insert_batch(self, inputs, seq_len, seq_tag, extra=None):
returnn/datasets/lm.py
CHANGED
@@ -7,7 +7,22 @@ and some related helpers.

 from __future__ import annotations

-from typing import …
+from typing import (
+    Iterable,
+    Optional,
+    Sequence,
+    Union,
+    Any,
+    Callable,
+    Iterator,
+    List,
+    Tuple,
+    Set,
+    BinaryIO,
+    Dict,
+    cast,
+    Generator,
+)
 import typing
 import os
 from io import IOBase

@@ -1472,8 +1487,8 @@ class TranslationDataset(CachedDataset2):
         }

         self._data_keys = self._source_data_keys + self._target_data_keys
-        self._data = {data_key: [] for data_key in self._data_keys}
-        self._data_len…
+        self._data: Dict[str, List[numpy.ndarray]] = {data_key: [] for data_key in self._data_keys}
+        self._data_len: Optional[int] = None

         self._vocabs = self._get_vocabs()
         self.num_outputs = {k: [max(self._vocabs[k].values()) + 1, 1] for k in self._vocabs.keys()}  # all sparse

@@ -1489,7 +1504,7 @@ class TranslationDataset(CachedDataset2):
             unknown_label.setdefault(data_key, None)
         self._unknown_label = unknown_label

-        self._seq_order…
+        self._seq_order: Optional[Sequence[int]] = None  # seq_idx -> line_nr
         self._tag_prefix = "line-"  # sequence tag is "line-n", where n is the line number
         self._thread = Thread(name="%r reader" % self, target=self._thread_main)
         self._thread.daemon = True

@@ -1878,14 +1893,11 @@ class TranslationFactorsDataset(TranslationDataset):
             assert file_prefix == self.target_file_prefix
             data_keys = self._target_data_keys

-        data = [
+        data: List[List[numpy.ndarray]] = [
             self._factored_words_to_numpy(data_keys, s.decode("utf8").strip().split(), self._add_postfix[file_prefix])
             for s in data_strs
-        ]  # …
-
-        data = zip(
-            *data
-        )  # type: typing.Iterable[typing.Tuple[numpy.ndarray]]  # shape: (len(data_keys), len(data_strs))
+        ]  # shape: (len(data_strs), len(data_keys))
+        data: Iterable[Tuple[numpy.ndarray]] = zip(*data)  # shape: (len(data_keys), len(data_strs))

         with self._lock:
             for i, data_ in enumerate(data):

@@ -1908,9 +1920,9 @@ class TranslationFactorsDataset(TranslationDataset):
             words_per_factor = [[]] * len(data_keys)
         elif len(data_keys) > 1:
             factored_words = [word.split(self._factor_separator) for word in words]
-            assert all(
-                …
-            )
+            assert all(len(factors) == len(data_keys) for factors in factored_words), (
+                "All words must have all factors. Expected: " + self._factor_separator.join(data_keys)
+            )
             words_per_factor = zip(*factored_words)
             words_per_factor = [list(w) for w in words_per_factor]
         else:
returnn/datasets/meta.py
CHANGED
@@ -247,10 +247,10 @@ class MetaDataset(CachedDataset2):
         self.seq_order_control_dataset = seq_order_control_dataset

         # This will only initialize datasets needed for features occuring in data_map
-        self.datasets = {
+        self.datasets: Dict[str, Dataset] = {
             key: init_dataset(datasets[key], extra_kwargs={"name": "%s_%s" % (self.name, key)}, parent_dataset=self)
             for key in self.dataset_keys
-        }…
+        }

         self._seq_list_file = seq_list_file
         self.seq_list_original = self._load_seq_list(seq_list_file)

@@ -260,8 +260,8 @@ class MetaDataset(CachedDataset2):

         self.tag_idx = {tag: idx for (idx, tag) in enumerate(self.seq_list_original[self.default_dataset_key])}

-        self._seq_lens…
-        self._num_timesteps…
+        self._seq_lens: Optional[Dict[str, NumbersDict]] = None
+        self._num_timesteps: Optional[NumbersDict] = None
         self._seq_lens_file = seq_lens_file
         if seq_lens_file:
             seq_lens = load_json(filename=seq_lens_file)

@@ -290,7 +290,7 @@ class MetaDataset(CachedDataset2):
         self.num_outputs = self.data_dims

         self.orig_seq_order_is_initialized = False
-        self.seq_list_ordered…
+        self.seq_list_ordered: Optional[Dict[str, List[str]]] = None

     def _load_seq_list(self, seq_list_file: Optional[Union[str, Dict[str, str]]] = None) -> Dict[str, List[str]]:
         """

@@ -771,7 +771,7 @@ class ConcatDataset(CachedDataset2):
         for ds in self.datasets[1:]:
             assert ds.num_inputs == self.num_inputs
             assert ds.num_outputs == self.num_outputs
-        self.dataset_seq_idx_offsets…
+        self.dataset_seq_idx_offsets: Optional[List[int]] = None

     def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
         """

@@ -1017,9 +1017,9 @@ class CombinedDataset(CachedDataset2):
             for (dset_key, dset_data_key), data_key in data_map.items()
         }

-        self.dataset_seq_idx_boundaries…
-        self.dataset_sorted_seq_idx_list…
-        self.used_num_seqs_per_subset…
+        self.dataset_seq_idx_boundaries: Optional[List[int]] = None
+        self.dataset_sorted_seq_idx_list: Optional[List[Tuple[int, int]]] = None
+        self.used_num_seqs_per_subset: Optional[List[int]] = None

     def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
         """

@@ -1180,9 +1180,9 @@ class CombinedDataset(CachedDataset2):
        :rtype: list[int]
        """
        assert self.partition_epoch in [None, 1], "partition_epoch not supported in combination with sampling_sizes."
-       assert (
-           …
-       )
+       assert self._seq_order_seq_lens_file is None, (
+           "seq_order_seq_lens_file not supported in combination with sampling_sizes."
+       )
        assert not self.unique_seq_tags, "unique_seq_tags not supported in combination with sampling_sizes."
        assert self.seq_tags_filter is None, "seq_order_seq_lens_file in combination with sampling_sizes."

@@ -1445,7 +1445,7 @@ class ConcatSeqsDataset(CachedDataset2):
         self.repeat_in_between_last_frame_up_to_multiple_of = repeat_in_between_last_frame_up_to_multiple_of or {}
         self.pad_narrow_data_to_multiple_of_target_len = pad_narrow_data_to_multiple_of_target_len or {}
         if epoch_wise_filter is None:
-            self.epoch_wise_filter…
+            self.epoch_wise_filter: Optional[EpochWiseFilter] = None
         elif isinstance(epoch_wise_filter, dict):
             self.epoch_wise_filter = EpochWiseFilter(epoch_wise_filter)
         else:

@@ -1471,10 +1471,8 @@ class ConcatSeqsDataset(CachedDataset2):
         self.seq_lens = eval(open(seq_len_file).read())
         assert isinstance(self.seq_lens, dict)
         self.full_seq_len_list = self._get_full_seq_lens_list()
-        self.cur_seq_list…
-        self.cur_sub_seq_idxs = (
-            None
-        )  # type: typing.Optional[typing.List[typing.List[int]]]  # list of list of sub seq idxs
+        self.cur_seq_list: typing.Optional[typing.List[str]] = None  # list of seq tags
+        self.cur_sub_seq_idxs: typing.Optional[typing.List[typing.List[int]]] = None  # list of list of sub seq idxs

     def _get_full_seq_lens_list(self):
         """

@@ -1564,20 +1562,22 @@ class ConcatSeqsDataset(CachedDataset2):
         if seq_idx == 0:  # some extra check, but enough to do for first seq only
             sub_dataset_keys = self.dataset.get_data_keys()
             for key in self.remove_in_between_postfix:
-                assert (
-                    key in …
-                )
+                assert key in sub_dataset_keys, (
+                    "%s: remove_in_between_postfix key %r not in sub dataset data-keys %r"
+                    % (
+                        self,
+                        key,
+                        sub_dataset_keys,
+                    )
                 )
             for key in self.repeat_in_between_last_frame_up_to_multiple_of:
-                assert (
-                    key in …
-                )
+                assert key in sub_dataset_keys, (
+                    "%s: repeat_in_between_last_frame_up_to_multiple_of key %r not in sub dataset data-keys %r"
+                    % (
+                        self,
+                        key,
+                        sub_dataset_keys,
+                    )
                 )
             for key in self.pad_narrow_data_to_multiple_of_target_len:
                 assert key in sub_dataset_keys, (

@@ -1587,15 +1587,16 @@ class ConcatSeqsDataset(CachedDataset2):
         for sub_seq_idx, sub_seq_tag in zip(sub_seq_idxs, sub_seq_tags):
             self.dataset.load_seqs(sub_seq_idx, sub_seq_idx + 1)
             sub_dataset_tag = self.dataset.get_tag(sub_seq_idx)
-            assert (
-                …
-            )
+            assert sub_dataset_tag == sub_seq_tag, (
+                "%s: expected tag %r for sub seq idx %i but got %r, part of seq %i %r"
+                % (
+                    self,
+                    sub_seq_tag,
+                    sub_seq_idx,
+                    sub_dataset_tag,
+                    seq_idx,
+                    seq_tag,
+                )
             )
             for key in self.get_data_keys():
                 data = self.dataset.get_data(sub_seq_idx, key)
returnn/datasets/normalization_data.py
CHANGED
@@ -169,7 +169,7 @@ class NormalizationData:
         sumErr = np.sum(np.abs(newSum - oldSum - intermediateSum))
         if sumErr > NormalizationData.SUMMATION_PRECISION:
             raise FloatingPointError(
-                "sums have very different orders of magnitude.…
+                "sums have very different orders of magnitude. summation error = {}".format(sumErr)
             )
         return newSum
returnn/datasets/postprocessing.py
CHANGED
@@ -308,19 +308,26 @@ class PostprocessingDataset(CachedDataset2):
                 last_complete_frac = complete_frac
                 for data_key, out_t in self._out_tensor_dict_template.data.items():
                     in_t = t_dict.data[data_key]
-                    assert (
-                        in_t.ndim…
-                        and all(d.dimension in (d_, None) for (d, d_) in zip(in_t.dims, out_t.shape))
+                    assert in_t.ndim == out_t.batch_ndim, (
+                        f"Dim number mismatch for {data_key}: {in_t.ndim} != {out_t.batch_ndim}. "
+                        "Postprocessing data tensors must not have a batch dimension."
                     )
+                    assert in_t.dtype == out_t.dtype, (
+                        f"dtype mismatch for {data_key}: '{in_t.dtype}' != '{out_t.dtype}'"
+                    )
+                    for i, (in_dim, out_shape) in enumerate(zip(in_t.dims, out_t.shape)):
+                        assert in_dim.dimension is None or in_dim.dimension == out_shape, (
+                            f"Dim {i} mismatch on {data_key}: "
+                            f"{in_dim.dimension} must either be `None` or equal {out_shape}"
+                        )
                 yield t_dict

         data_iter = self._iterate_dataset()
         if self._map_seq_stream is not None:
             data_iter = self._map_seq_stream(data_iter, epoch=self.epoch, rng=self._rng, **util.get_fwd_compat_kwargs())
-            assert isinstance(
-                …
-            )
+            assert isinstance(data_iter, Iterator), (
+                f"map_seq_stream must produce an {Iterator.__name__}, but produced {type(data_iter).__name__}"
+            )
         return _validate_tensor_dict_iter(data_iter)

     def _iterate_dataset(self) -> Iterator[TensorDict]:

@@ -349,9 +356,9 @@ class PostprocessingDataset(CachedDataset2):
             tensor_dict = self._map_seq(
                 tensor_dict, epoch=self.epoch, seq_idx=seq_index, rng=self._rng, **util.get_fwd_compat_kwargs()
             )
-            assert isinstance(
-                …
-            )
+            assert isinstance(tensor_dict, TensorDict), (
+                f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}"
+            )

             # Re-adding the seq_tag/complete_frac here causes no harm in case they are dropped
             # since we don't add/drop any segments w/ the non-iterator postprocessing function.

@@ -367,9 +374,9 @@ class PostprocessingDataset(CachedDataset2):
             if self._seq_list_for_validation is not None:
                 seq_tag = self._seq_list_for_validation[seq_index]
                 tag_of_seq = tensor_dict.data["seq_tag"].raw_tensor.item()
-                assert (
-                    tag_of_seq…
-                )
+                assert tag_of_seq == seq_tag, (
+                    f"seq tag mismath: {tag_of_seq} != {seq_tag} for seq index {seq_index} when seq list is given"
+                )

             yield tensor_dict
             seq_index += 1
returnn/datasets/sprint.py
CHANGED
@@ -393,13 +393,14 @@ class SprintDatasetBase(Dataset):
             targets = {"classes": targets}
         if "classes" in targets:
             # 'classes' is always the alignment
-            assert targets["classes"].shape == (
-                …
-            )
+            assert targets["classes"].shape == (reduce_num_frames,), (
+                "Number of targets %s does not match number of features %s (reduce factor %d)"
+                % (
+                    # is in format (time,)
+                    targets["classes"].shape,
+                    (num_frames,),
+                    self.reduce_target_factor,
+                )
             )
         if "speaker_name" in targets:
             targets["speaker_name"] = targets["speaker_name"].strip()
returnn/datasets/util/strings.py
CHANGED
returnn/datasets/util/vocabulary.py
CHANGED
@@ -185,9 +185,9 @@ class Vocabulary:
             labels = file_content.splitlines()
             labels_from_idx = {i: line for (i, line) in enumerate(labels)}
             labels_to_idx = {line: i for (i, line) in enumerate(labels)}
-        assert isinstance(
-            …
-        )
+        assert isinstance(labels_to_idx, dict), (
+            f"{self}: expected dict, got {type(labels_to_idx).__name__} in {filename}"
+        )
         if labels_from_idx is None:
             labels_from_idx = {idx: label for (label, idx) in sorted(labels_to_idx.items())}
         min_label, max_label, num_labels = min(labels_from_idx), max(labels_from_idx), len(labels_from_idx)
returnn/extern/graph_editor/subgraph.py
CHANGED
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""SubGraphView: a subgraph view on an existing tf.Graph.
-"""
+"""SubGraphView: a subgraph view on an existing tf.Graph."""

 from __future__ import annotations

returnn/extern/graph_editor/transform.py
CHANGED
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Class to transform an subgraph into another.
-"""
+"""Class to transform an subgraph into another."""

 from __future__ import annotations

returnn/extern/graph_editor/util.py
CHANGED
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Utility functions for the graph_editor.
-"""
+"""Utility functions for the graph_editor."""

 from __future__ import annotations

returnn/frontend/_backend.py
CHANGED
@@ -1509,9 +1509,10 @@ def get_backend_by_raw_tensor_type(tensor_type: Type[T]) -> Union[Type[Backend[T
         else:
             continue

-    assert any(
-        …
-    )
+    assert any(issubclass(base_type, type_) for type_ in tensor_types), (
+        f"tensor type {tensor_type} base_type {base_type} not in {tensor_types}, "
+        f"expected for backend {backend_type}"
+    )
     for base_type_ in tensor_types:
         register_backend_by_tensor_type(base_type_, backend_type)
     return backend_type