returnn 1.20250207.143045__py3-none-any.whl → 1.20250211.210150__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic. Click here for more details.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/lm.py +112 -88
- returnn/torch/engine.py +13 -1
- returnn/torch/frontend/bridge.py +21 -12
- {returnn-1.20250207.143045.dist-info → returnn-1.20250211.210150.dist-info}/METADATA +1 -1
- {returnn-1.20250207.143045.dist-info → returnn-1.20250211.210150.dist-info}/RECORD +10 -10
- {returnn-1.20250207.143045.dist-info → returnn-1.20250211.210150.dist-info}/LICENSE +0 -0
- {returnn-1.20250207.143045.dist-info → returnn-1.20250211.210150.dist-info}/WHEEL +0 -0
- {returnn-1.20250207.143045.dist-info → returnn-1.20250211.210150.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
version = '1.
|
|
2
|
-
long_version = '1.
|
|
1
|
+
version = '1.20250211.210150'
|
|
2
|
+
long_version = '1.20250211.210150+git.074b83c'
|
returnn/datasets/lm.py
CHANGED
|
@@ -7,9 +7,10 @@ and some related helpers.
|
|
|
7
7
|
|
|
8
8
|
from __future__ import annotations
|
|
9
9
|
|
|
10
|
-
from typing import Optional, Union, Callable, Iterator, List, Tuple, BinaryIO, cast
|
|
10
|
+
from typing import Optional, Union, Any, Callable, Iterator, List, Tuple, Set, BinaryIO, Dict, cast, Generator
|
|
11
11
|
import typing
|
|
12
12
|
import os
|
|
13
|
+
from io import IOBase
|
|
13
14
|
import sys
|
|
14
15
|
import time
|
|
15
16
|
import re
|
|
@@ -1043,17 +1044,17 @@ class Lexicon:
|
|
|
1043
1044
|
Lexicon. Map of words to phoneme sequences (can have multiple pronunciations).
|
|
1044
1045
|
"""
|
|
1045
1046
|
|
|
1046
|
-
def __init__(self, filename):
|
|
1047
|
+
def __init__(self, filename: str):
|
|
1047
1048
|
"""
|
|
1048
|
-
:param
|
|
1049
|
+
:param filename:
|
|
1049
1050
|
"""
|
|
1050
1051
|
print("Loading lexicon", filename, file=log.v4)
|
|
1051
1052
|
lex_file = open(filename, "rb")
|
|
1052
1053
|
if filename.endswith(".gz"):
|
|
1053
1054
|
lex_file = gzip.GzipFile(fileobj=lex_file)
|
|
1054
|
-
self.phoneme_list
|
|
1055
|
-
self.phonemes
|
|
1056
|
-
self.lemmas
|
|
1055
|
+
self.phoneme_list: List[str] = []
|
|
1056
|
+
self.phonemes: Dict[str, Dict[str, Any]] = {} # phone -> {index, symbol, variation}
|
|
1057
|
+
self.lemmas: Dict[str, Dict[str, Any]] = {} # orth -> {orth, phons}
|
|
1057
1058
|
|
|
1058
1059
|
context = iter(ElementTree.iterparse(lex_file, events=("start", "end")))
|
|
1059
1060
|
_, root = next(context) # get root element
|
|
@@ -1086,8 +1087,17 @@ class Lexicon:
|
|
|
1086
1087
|
{"phon": e.text.strip(), "score": float(e.attrib.get("score", 0))}
|
|
1087
1088
|
for e in elem.findall("phon")
|
|
1088
1089
|
]
|
|
1089
|
-
|
|
1090
|
-
self.lemmas
|
|
1090
|
+
lemma = {"orth": orth, "phons": phons}
|
|
1091
|
+
if orth in self.lemmas: # unexpected, already exists?
|
|
1092
|
+
if self.lemmas[orth] == lemma:
|
|
1093
|
+
print(f"Warning: lemma {lemma} duplicated in lexicon {filename}", file=log.v4)
|
|
1094
|
+
else:
|
|
1095
|
+
raise Exception(
|
|
1096
|
+
f"orth {orth!r} lemma duplicated in lexicon {filename}."
|
|
1097
|
+
f" old: {self.lemmas[orth]}, new: {lemma}"
|
|
1098
|
+
)
|
|
1099
|
+
else: # lemma does not exist yet -- this is the expected case
|
|
1100
|
+
self.lemmas[orth] = lemma
|
|
1091
1101
|
root.clear() # free memory
|
|
1092
1102
|
print("Finished whole lexicon, %i lemmas" % len(self.lemmas), file=log.v4)
|
|
1093
1103
|
|
|
@@ -1097,12 +1107,12 @@ class StateTying:
|
|
|
1097
1107
|
Clustering of (allophone) states into classes.
|
|
1098
1108
|
"""
|
|
1099
1109
|
|
|
1100
|
-
def __init__(self, state_tying_file):
|
|
1110
|
+
def __init__(self, state_tying_file: str):
|
|
1101
1111
|
"""
|
|
1102
|
-
:param
|
|
1112
|
+
:param state_tying_file:
|
|
1103
1113
|
"""
|
|
1104
|
-
self.allo_map = {} # allophone-state-str -> class-idx
|
|
1105
|
-
self.class_map = {} # class-idx -> set(allophone-state-str)
|
|
1114
|
+
self.allo_map: Dict[str, int] = {} # allophone-state-str -> class-idx
|
|
1115
|
+
self.class_map: Dict[int, Set[str]] = {} # class-idx -> set(allophone-state-str)
|
|
1106
1116
|
lines = open(state_tying_file).read().splitlines()
|
|
1107
1117
|
for line in lines:
|
|
1108
1118
|
allo_str, class_idx_str = line.split()
|
|
@@ -1124,29 +1134,45 @@ class PhoneSeqGenerator:
|
|
|
1124
1134
|
|
|
1125
1135
|
def __init__(
|
|
1126
1136
|
self,
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1137
|
+
*,
|
|
1138
|
+
lexicon_file: str,
|
|
1139
|
+
phoneme_vocab_file: Optional[str] = None,
|
|
1140
|
+
allo_num_states: int = 3,
|
|
1141
|
+
allo_context_len: int = 1,
|
|
1142
|
+
state_tying_file: Optional[str] = None,
|
|
1143
|
+
add_silence_beginning: float = 0.1,
|
|
1144
|
+
add_silence_between_words: float = 0.1,
|
|
1145
|
+
add_silence_end: float = 0.1,
|
|
1146
|
+
repetition: float = 0.9,
|
|
1147
|
+
silence_repetition: float = 0.95,
|
|
1148
|
+
silence_lemma_orth: str = "[SILENCE]",
|
|
1149
|
+
extra_begin_lemma: Optional[Dict[str, Any]] = None,
|
|
1150
|
+
add_extra_begin_lemma: float = 1.0,
|
|
1151
|
+
extra_end_lemma: Optional[Dict[str, Any]] = None,
|
|
1152
|
+
add_extra_end_lemma: float = 1.0,
|
|
1136
1153
|
):
|
|
1137
1154
|
"""
|
|
1138
|
-
:param
|
|
1139
|
-
:param
|
|
1140
|
-
|
|
1141
|
-
:param
|
|
1142
|
-
:param
|
|
1143
|
-
:param
|
|
1144
|
-
:param
|
|
1145
|
-
:param
|
|
1146
|
-
:param
|
|
1155
|
+
:param lexicon_file: lexicon XML file
|
|
1156
|
+
:param phoneme_vocab_file: defines the vocab, label indices.
|
|
1157
|
+
If not given, automatically inferred via all (sorted) phonemes from the lexicon.
|
|
1158
|
+
:param allo_num_states: how much HMM states per allophone (all but silence)
|
|
1159
|
+
:param allo_context_len: how much context to store left and right. 1 -> triphone
|
|
1160
|
+
:param state_tying_file: for state-tying, if you want that
|
|
1161
|
+
:param add_silence_beginning: prob of adding silence at beginning
|
|
1162
|
+
:param add_silence_between_words: prob of adding silence between words
|
|
1163
|
+
:param add_silence_end: prob of adding silence at end
|
|
1164
|
+
:param repetition: prob of repeating an allophone
|
|
1165
|
+
:param silence_repetition: prob of repeating the silence allophone
|
|
1166
|
+
:param silence_lemma_orth: silence orth in the lexicon
|
|
1167
|
+
:param extra_begin_lemma: {"phons": [{"phon": "P1 P2 ...", ...}, ...], ...}.
|
|
1168
|
+
If given, then with prob add_extra_begin_lemma, this will be added at the beginning.
|
|
1169
|
+
:param add_extra_begin_lemma:
|
|
1170
|
+
:param extra_end_lemma: just like ``extra_begin_lemma``, but for the end
|
|
1171
|
+
:param add_extra_end_lemma:
|
|
1147
1172
|
"""
|
|
1148
1173
|
self.lexicon = Lexicon(lexicon_file)
|
|
1149
1174
|
self.phonemes = sorted(self.lexicon.phonemes.keys(), key=lambda s: self.lexicon.phonemes[s]["index"])
|
|
1175
|
+
self.phoneme_vocab = Vocabulary(phoneme_vocab_file, unknown_label=None) if phoneme_vocab_file else None
|
|
1150
1176
|
self.rnd = Random(0)
|
|
1151
1177
|
self.allo_num_states = allo_num_states
|
|
1152
1178
|
self.allo_context_len = allo_context_len
|
|
@@ -1155,40 +1181,42 @@ class PhoneSeqGenerator:
|
|
|
1155
1181
|
self.add_silence_end = add_silence_end
|
|
1156
1182
|
self.repetition = repetition
|
|
1157
1183
|
self.silence_repetition = silence_repetition
|
|
1158
|
-
self.si_lemma = self.lexicon.lemmas[
|
|
1159
|
-
self.si_phone = self.si_lemma["phons"][0]["phon"]
|
|
1160
|
-
if state_tying_file
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
|
|
1168
|
-
|
|
1184
|
+
self.si_lemma: Dict[str, Any] = self.lexicon.lemmas[silence_lemma_orth]
|
|
1185
|
+
self.si_phone: str = self.si_lemma["phons"][0]["phon"]
|
|
1186
|
+
self.state_tying = StateTying(state_tying_file) if state_tying_file else None
|
|
1187
|
+
if self.phoneme_vocab:
|
|
1188
|
+
assert not self.state_tying
|
|
1189
|
+
self.extra_begin_lemma = extra_begin_lemma
|
|
1190
|
+
self.add_extra_begin_lemma = add_extra_begin_lemma
|
|
1191
|
+
self.extra_end_lemma = extra_end_lemma
|
|
1192
|
+
self.add_extra_end_lemma = add_extra_end_lemma
|
|
1193
|
+
|
|
1194
|
+
def random_seed(self, seed: int):
|
|
1195
|
+
"""Reset RNG via given seed"""
|
|
1169
1196
|
self.rnd.seed(seed)
|
|
1170
1197
|
|
|
1171
|
-
def get_class_labels(self):
|
|
1172
|
-
"""
|
|
1173
|
-
:
|
|
1174
|
-
|
|
1175
|
-
|
|
1198
|
+
def get_class_labels(self) -> List[str]:
|
|
1199
|
+
""":return: class labels"""
|
|
1200
|
+
if self.phoneme_vocab:
|
|
1201
|
+
return self.phoneme_vocab.labels
|
|
1202
|
+
elif self.state_tying:
|
|
1176
1203
|
# State tying labels. Represented by some allophone state str.
|
|
1177
1204
|
return ["|".join(sorted(self.state_tying.class_map[i])) for i in range(self.state_tying.num_classes)]
|
|
1178
1205
|
else:
|
|
1179
1206
|
# The phonemes are the labels.
|
|
1180
1207
|
return self.phonemes
|
|
1181
1208
|
|
|
1182
|
-
def seq_to_class_idxs(self, phones, dtype=None):
|
|
1209
|
+
def seq_to_class_idxs(self, phones: List[AllophoneState], dtype: Optional[str] = None) -> numpy.ndarray:
|
|
1183
1210
|
"""
|
|
1184
|
-
:param
|
|
1185
|
-
:param
|
|
1186
|
-
:
|
|
1187
|
-
:returns 1D numpy array with the indices
|
|
1211
|
+
:param phones: list of allophone states
|
|
1212
|
+
:param dtype: eg "int32". "int32" by default
|
|
1213
|
+
:returns: 1D numpy array with the indices
|
|
1188
1214
|
"""
|
|
1189
1215
|
if dtype is None:
|
|
1190
1216
|
dtype = "int32"
|
|
1191
|
-
if self.
|
|
1217
|
+
if self.phoneme_vocab:
|
|
1218
|
+
return numpy.array([self.phoneme_vocab.label_to_id(a.id) for a in phones], dtype=dtype)
|
|
1219
|
+
elif self.state_tying:
|
|
1192
1220
|
# State tying indices.
|
|
1193
1221
|
return numpy.array([self.state_tying.allo_map[a.format()] for a in phones], dtype=dtype)
|
|
1194
1222
|
else:
|
|
@@ -1196,11 +1224,9 @@ class PhoneSeqGenerator:
|
|
|
1196
1224
|
# It should not happen that we don't have some phoneme. The lexicon should not be inconsistent.
|
|
1197
1225
|
return numpy.array([self.lexicon.phonemes[p.id]["index"] for p in phones], dtype=dtype)
|
|
1198
1226
|
|
|
1199
|
-
def
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
:rtype: typing.Iterator[typing.Dict[str]]
|
|
1203
|
-
"""
|
|
1227
|
+
def _iter_orth_lemmas(self, orth: str) -> Generator[Dict[str, Any], None, None]:
|
|
1228
|
+
if self.extra_begin_lemma and self.rnd.random() < self.add_extra_begin_lemma:
|
|
1229
|
+
yield self.extra_begin_lemma
|
|
1204
1230
|
if self.rnd.random() < self.add_silence_beginning:
|
|
1205
1231
|
yield self.si_lemma
|
|
1206
1232
|
symbols = list(orth.split())
|
|
@@ -1224,26 +1250,25 @@ class PhoneSeqGenerator:
|
|
|
1224
1250
|
yield self.si_lemma
|
|
1225
1251
|
if self.rnd.random() < self.add_silence_end:
|
|
1226
1252
|
yield self.si_lemma
|
|
1253
|
+
if self.extra_end_lemma and self.rnd.random() < self.add_extra_end_lemma:
|
|
1254
|
+
yield self.extra_end_lemma
|
|
1227
1255
|
|
|
1228
|
-
def orth_to_phones(self, orth):
|
|
1229
|
-
"""
|
|
1230
|
-
:param str orth:
|
|
1231
|
-
:rtype: str
|
|
1232
|
-
"""
|
|
1256
|
+
def orth_to_phones(self, orth: str) -> str:
|
|
1257
|
+
""":return: space-separated phones"""
|
|
1233
1258
|
phones = []
|
|
1234
|
-
for lemma in self.
|
|
1259
|
+
for lemma in self._iter_orth_lemmas(orth):
|
|
1235
1260
|
phon = self.rnd.choice(lemma["phons"])
|
|
1236
|
-
phones
|
|
1261
|
+
phones.append(phon["phon"])
|
|
1237
1262
|
return " ".join(phones)
|
|
1238
1263
|
|
|
1239
1264
|
# noinspection PyMethodMayBeStatic
|
|
1240
|
-
def _phones_to_allos(self, phones):
|
|
1265
|
+
def _phones_to_allos(self, phones: Iterator[str]) -> Generator[AllophoneState, None, None]:
|
|
1241
1266
|
for p in phones:
|
|
1242
1267
|
a = AllophoneState()
|
|
1243
1268
|
a.id = p
|
|
1244
1269
|
yield a
|
|
1245
1270
|
|
|
1246
|
-
def _random_allo_silence(self, phone=None):
|
|
1271
|
+
def _random_allo_silence(self, phone: Optional[str] = None) -> Generator[AllophoneState, None, None]:
|
|
1247
1272
|
if phone is None:
|
|
1248
1273
|
phone = self.si_phone
|
|
1249
1274
|
while True:
|
|
@@ -1256,7 +1281,7 @@ class PhoneSeqGenerator:
|
|
|
1256
1281
|
if self.rnd.random() >= self.silence_repetition:
|
|
1257
1282
|
break
|
|
1258
1283
|
|
|
1259
|
-
def _allos_add_states(self, allos):
|
|
1284
|
+
def _allos_add_states(self, allos: Iterator[AllophoneState]) -> Generator[AllophoneState, None, None]:
|
|
1260
1285
|
for _a in allos:
|
|
1261
1286
|
if _a.id == self.si_phone:
|
|
1262
1287
|
for a in self._random_allo_silence(_a.id):
|
|
@@ -1274,9 +1299,9 @@ class PhoneSeqGenerator:
|
|
|
1274
1299
|
if self.rnd.random() >= self.repetition:
|
|
1275
1300
|
break
|
|
1276
1301
|
|
|
1277
|
-
def _allos_set_context(self, allos):
|
|
1302
|
+
def _allos_set_context(self, allos: List[AllophoneState]) -> None:
|
|
1278
1303
|
"""
|
|
1279
|
-
:param
|
|
1304
|
+
:param allos: modify inplace, ``context_history``, ``context_future``
|
|
1280
1305
|
"""
|
|
1281
1306
|
if self.allo_context_len == 0:
|
|
1282
1307
|
return
|
|
@@ -1297,15 +1322,14 @@ class PhoneSeqGenerator:
|
|
|
1297
1322
|
else:
|
|
1298
1323
|
ctx = []
|
|
1299
1324
|
|
|
1300
|
-
def generate_seq(self, orth):
|
|
1325
|
+
def generate_seq(self, orth: str) -> List[AllophoneState]:
|
|
1301
1326
|
"""
|
|
1302
|
-
:param
|
|
1303
|
-
:
|
|
1304
|
-
:returns allophone state list. those will have repetitions etc
|
|
1327
|
+
:param orth: orthography as a str. orth.split() should give words in the lexicon
|
|
1328
|
+
:returns: allophone state list. those will have repetitions etc
|
|
1305
1329
|
"""
|
|
1306
|
-
allos
|
|
1307
|
-
for lemma in self.
|
|
1308
|
-
phon = self.rnd.choice(lemma["phons"])
|
|
1330
|
+
allos: List[AllophoneState] = []
|
|
1331
|
+
for lemma in self._iter_orth_lemmas(orth):
|
|
1332
|
+
phon = self.rnd.choice(lemma["phons"]) # space-separated phones in phon["phon"]
|
|
1309
1333
|
l_allos = list(self._phones_to_allos(phon["phon"].split()))
|
|
1310
1334
|
l_allos[0].mark_initial()
|
|
1311
1335
|
l_allos[-1].mark_final()
|
|
@@ -1314,13 +1338,13 @@ class PhoneSeqGenerator:
|
|
|
1314
1338
|
allos = list(self._allos_add_states(allos))
|
|
1315
1339
|
return allos
|
|
1316
1340
|
|
|
1317
|
-
def _random_phone_seq(self, prob_add=0.8):
|
|
1341
|
+
def _random_phone_seq(self, prob_add: float = 0.8) -> Generator[str, None, None]:
|
|
1318
1342
|
while True:
|
|
1319
1343
|
yield self.rnd.choice(self.phonemes)
|
|
1320
1344
|
if self.rnd.random() >= prob_add:
|
|
1321
1345
|
break
|
|
1322
1346
|
|
|
1323
|
-
def _random_allo_seq(self, prob_word_add=0.8):
|
|
1347
|
+
def _random_allo_seq(self, prob_word_add: float = 0.8) -> List[AllophoneState]:
|
|
1324
1348
|
allos = []
|
|
1325
1349
|
while True:
|
|
1326
1350
|
phones = self._random_phone_seq()
|
|
@@ -1333,15 +1357,14 @@ class PhoneSeqGenerator:
|
|
|
1333
1357
|
self._allos_set_context(allos)
|
|
1334
1358
|
return list(self._allos_add_states(allos))
|
|
1335
1359
|
|
|
1336
|
-
def generate_garbage_seq(self, target_len):
|
|
1360
|
+
def generate_garbage_seq(self, target_len: int) -> List[AllophoneState]:
|
|
1337
1361
|
"""
|
|
1338
|
-
:param
|
|
1339
|
-
:
|
|
1340
|
-
|
|
1341
|
-
|
|
1342
|
-
into a list of allophones in a similar way than generate_seq().
|
|
1362
|
+
:param target_len: len of the returned seq
|
|
1363
|
+
:returns: allophone state list. those will have repetitions etc.
|
|
1364
|
+
It will randomly generate a sequence of phonemes and transform that
|
|
1365
|
+
into a list of allophones in a similar way than generate_seq().
|
|
1343
1366
|
"""
|
|
1344
|
-
allos = []
|
|
1367
|
+
allos: List[AllophoneState] = []
|
|
1345
1368
|
while True:
|
|
1346
1369
|
allos += self._random_allo_seq()
|
|
1347
1370
|
# Add some silence so that left/right context is correct for further allophones.
|
|
@@ -1435,7 +1458,9 @@ class TranslationDataset(CachedDataset2):
|
|
|
1435
1458
|
for prefix in self._main_data_key_map.keys()
|
|
1436
1459
|
if not (prefix == self.target_file_prefix and search_without_reference)
|
|
1437
1460
|
]
|
|
1438
|
-
self._data_files
|
|
1461
|
+
self._data_files: Dict[str, Union[None, BinaryIO, IOBase]] = {
|
|
1462
|
+
prefix: self._get_data_file(prefix) for prefix in self._files_to_read
|
|
1463
|
+
}
|
|
1439
1464
|
|
|
1440
1465
|
self._data_keys = self._source_data_keys + self._target_data_keys
|
|
1441
1466
|
self._data = {data_key: [] for data_key in self._data_keys} # type: typing.Dict[str,typing.List[numpy.ndarray]]
|
|
@@ -1541,11 +1566,10 @@ class TranslationDataset(CachedDataset2):
|
|
|
1541
1566
|
filename = cf(filename)
|
|
1542
1567
|
return filename
|
|
1543
1568
|
|
|
1544
|
-
def _get_data_file(self, prefix):
|
|
1569
|
+
def _get_data_file(self, prefix) -> Union[BinaryIO, IOBase]:
|
|
1545
1570
|
"""
|
|
1546
1571
|
:param str prefix: e.g. "source" or "target"
|
|
1547
1572
|
:return: full filename
|
|
1548
|
-
:rtype: io.FileIO
|
|
1549
1573
|
"""
|
|
1550
1574
|
import os
|
|
1551
1575
|
|
returnn/torch/engine.py
CHANGED
|
@@ -980,6 +980,7 @@ class Engine(EngineBase):
|
|
|
980
980
|
missing_keys_preload, unexpected_keys_preload = self._pt_model.load_state_dict(
|
|
981
981
|
preload_model_state, strict=False
|
|
982
982
|
)
|
|
983
|
+
preload_model_state_keys = set(preload_model_state.keys())
|
|
983
984
|
loaded_state_keys.update(preload_model_state.keys())
|
|
984
985
|
missing_keys.difference_update(preload_model_state.keys())
|
|
985
986
|
del preload_model_state
|
|
@@ -987,6 +988,11 @@ class Engine(EngineBase):
|
|
|
987
988
|
|
|
988
989
|
if opts.get("prefix", ""):
|
|
989
990
|
prefix_keys = [key for key in self._pt_model.state_dict() if key.startswith(opts.get("prefix", ""))]
|
|
991
|
+
if not prefix_keys:
|
|
992
|
+
raise Exception(
|
|
993
|
+
"No keys with prefix %r found in model.\nModel params:\n%s"
|
|
994
|
+
% (opts.get("prefix", ""), ", ".join(name for name, _ in self._pt_model.named_parameters()))
|
|
995
|
+
)
|
|
990
996
|
else:
|
|
991
997
|
prefix_keys = model_state_keys_set
|
|
992
998
|
missing_keys_preload = (
|
|
@@ -995,6 +1001,12 @@ class Engine(EngineBase):
|
|
|
995
1001
|
unexpected_keys_preload = (
|
|
996
1002
|
set(prefix_keys).intersection(set(unexpected_keys_preload)).difference(loaded_state_keys)
|
|
997
1003
|
)
|
|
1004
|
+
if not preload_model_state_keys.intersection(prefix_keys):
|
|
1005
|
+
raise Exception(
|
|
1006
|
+
f"No keys with prefix {opts.get('prefix', '')!r} found in preload model state.\n"
|
|
1007
|
+
f"Preload model state keys: {preload_model_state_keys}\n"
|
|
1008
|
+
f"Model state keys: {model_state_keys_set}"
|
|
1009
|
+
)
|
|
998
1010
|
if missing_keys_preload and not opts.get("ignore_missing", False):
|
|
999
1011
|
missing_keys.update(missing_keys_preload)
|
|
1000
1012
|
if missing_keys_preload:
|
|
@@ -1077,7 +1089,7 @@ class Engine(EngineBase):
|
|
|
1077
1089
|
get_model_func = self.config.typed_value("get_model")
|
|
1078
1090
|
assert get_model_func, "get_model not defined in config"
|
|
1079
1091
|
sentinel_kw = util.get_fwd_compat_kwargs()
|
|
1080
|
-
model = get_model_func(epoch=epoch, step=step, **sentinel_kw)
|
|
1092
|
+
model = get_model_func(epoch=epoch, step=step, device=self._device, **sentinel_kw)
|
|
1081
1093
|
self._orig_model = model
|
|
1082
1094
|
if isinstance(model, rf.Module):
|
|
1083
1095
|
self._pt_model = rf_module_to_pt_module(model)
|
returnn/torch/frontend/bridge.py
CHANGED
|
@@ -109,18 +109,27 @@ class RFModuleAsPTModule(torch.nn.Module):
|
|
|
109
109
|
self._aux_params_as_buffers = aux_params_as_buffers
|
|
110
110
|
self._is_initializing = True
|
|
111
111
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
112
|
+
for name, value in vars(rf_module).items():
|
|
113
|
+
if isinstance(value, rf.Parameter):
|
|
114
|
+
pt_param = value.raw_tensor
|
|
115
|
+
assert isinstance(pt_param, torch.nn.Parameter)
|
|
116
|
+
if value.auxiliary and aux_params_as_buffers:
|
|
117
|
+
self.register_buffer(name, pt_param)
|
|
118
|
+
else:
|
|
119
|
+
self.register_parameter(name, pt_param)
|
|
120
|
+
|
|
121
|
+
elif isinstance(value, rf.Module):
|
|
122
|
+
pt_mod = rf_module_to_pt_module(value, aux_params_as_buffers=aux_params_as_buffers)
|
|
123
|
+
self.add_module(name, pt_mod)
|
|
124
|
+
|
|
125
|
+
elif isinstance(value, torch.nn.Parameter):
|
|
126
|
+
self.register_parameter(name, value)
|
|
127
|
+
|
|
128
|
+
elif isinstance(value, torch.Tensor): # make sure this check is after torch.nn.Parameter
|
|
129
|
+
self.register_buffer(name, value)
|
|
130
|
+
|
|
131
|
+
elif isinstance(value, torch.nn.Module):
|
|
132
|
+
self.add_module(name, value)
|
|
124
133
|
|
|
125
134
|
self._is_initializing = False
|
|
126
135
|
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
returnn/PKG-INFO,sha256=
|
|
1
|
+
returnn/PKG-INFO,sha256=FNB_wClTNbduk7TOGSMIHj-XEM6fy5ef7mjbIWilbh4,5215
|
|
2
2
|
returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
|
|
3
3
|
returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
|
|
4
4
|
returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
|
|
5
5
|
returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
|
|
6
|
-
returnn/_setup_info_generated.py,sha256=
|
|
6
|
+
returnn/_setup_info_generated.py,sha256=Po_jlRBUgzs773OdWo6EfZJpD78f3rXABybtsIWUZWk,77
|
|
7
7
|
returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
|
|
8
8
|
returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
|
|
9
9
|
returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
|
|
@@ -20,7 +20,7 @@ returnn/datasets/cached2.py,sha256=STojLL2Ivvd0xMfZRlYgzsHKlikYKL-caZCIDCgc_9g,1
|
|
|
20
20
|
returnn/datasets/distrib_files.py,sha256=kyqIQILDPAO2TXr39hjslmDxIAc3pkY1UOoj8nuiFXo,27534
|
|
21
21
|
returnn/datasets/generating.py,sha256=e2-SXcax7xQ4fkVW_Q5MgOLP6KlB7EQXJi_v64gVAWI,99805
|
|
22
22
|
returnn/datasets/hdf.py,sha256=shif0aQqWWNJ0b6YnycpPjIVNsxjLrA41Y66-_SluGI,66993
|
|
23
|
-
returnn/datasets/lm.py,sha256=
|
|
23
|
+
returnn/datasets/lm.py,sha256=h0IHUbze87njKrcD5eT1FRxde7elIio05n-BWiqmjFE,98805
|
|
24
24
|
returnn/datasets/map.py,sha256=kOBJVZmwDhLsOplzDNByIfa0NRSUaMo2Lsy36lBvxrM,10907
|
|
25
25
|
returnn/datasets/meta.py,sha256=wHquywF1C7-YWhcSFSAdDNc0nEHRjE-ks7YIEuDFMIE,94731
|
|
26
26
|
returnn/datasets/multi_proc.py,sha256=7kppiXGiel824HM3GvHegluIxtiNAHafm-e6qh6W7YU,21948
|
|
@@ -207,7 +207,7 @@ returnn/tf/util/open_fst.py,sha256=sZRDw4TbxvhGqpGdUJWy1ebvlZm4_RPhygpRw9uLAOQ,1
|
|
|
207
207
|
returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
|
|
208
208
|
returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
|
|
209
209
|
returnn/torch/distributed.py,sha256=i13cUVjI7GxpO0TAresrNyCM0ZBAaf-cXNr09Fmg_2k,6266
|
|
210
|
-
returnn/torch/engine.py,sha256=
|
|
210
|
+
returnn/torch/engine.py,sha256=8BIpdcrpbJL9HrvCX-hISh-14zW9aSrHGvRWT9s0zOk,77103
|
|
211
211
|
returnn/torch/updater.py,sha256=GqtBvZpElPVMm0lq84JPl4NVLFFETZAzAbR0rTomSao,28249
|
|
212
212
|
returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
|
|
213
213
|
returnn/torch/data/extern_data.py,sha256=_uT_9_gd5HIh1IoRsrebVG-nufSnb7fgC5jyU05GxJg,7580
|
|
@@ -218,7 +218,7 @@ returnn/torch/data/tensor_utils.py,sha256=-Teqi--LLbt6q_5mDRdoHZHmPgSdC83W706uki
|
|
|
218
218
|
returnn/torch/frontend/__init__.py,sha256=AA48HZnC17ASuKA0EWy8loZ-Bib_yUtqF4T1wYvjst4,62
|
|
219
219
|
returnn/torch/frontend/_backend.py,sha256=h_rUhBPxLRgpZSqX4C8vX8q4dHWMhZpwPmGbKN6MsZo,99995
|
|
220
220
|
returnn/torch/frontend/_rand.py,sha256=1JgIkV2XmpgJD86zXZ-NCAe-QuoP2swr6NaS1oz3Qa8,1830
|
|
221
|
-
returnn/torch/frontend/bridge.py,sha256=
|
|
221
|
+
returnn/torch/frontend/bridge.py,sha256=Z2_UW8AagezC7zsXDc5PKcd8G9WwisV7j9SWGHU0m4U,7840
|
|
222
222
|
returnn/torch/frontend/raw_ops.py,sha256=lF0h-KtYYsdaaqQADylVZp9qzPskOOXA4MfmYDyx5IU,296
|
|
223
223
|
returnn/torch/optim/README.md,sha256=0iH5FiKb7iDrVK5n8V6yCh4ciCFG2YSbyh7lPneT5ik,360
|
|
224
224
|
returnn/torch/optim/__init__.py,sha256=yxdbnOkXAHzZ_t6cHi6zn5x_DQNlLZJ-KxZByHTIg1U,29
|
|
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
|
|
|
253
253
|
returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
|
|
254
254
|
returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
|
|
255
255
|
returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
|
|
256
|
-
returnn-1.
|
|
257
|
-
returnn-1.
|
|
258
|
-
returnn-1.
|
|
259
|
-
returnn-1.
|
|
260
|
-
returnn-1.
|
|
256
|
+
returnn-1.20250211.210150.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
+
returnn-1.20250211.210150.dist-info/METADATA,sha256=FNB_wClTNbduk7TOGSMIHj-XEM6fy5ef7mjbIWilbh4,5215
|
|
258
|
+
returnn-1.20250211.210150.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
|
259
|
+
returnn-1.20250211.210150.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
+
returnn-1.20250211.210150.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|