returnn 1.20250719.212120-py3-none-any.whl → 1.20250724.195711-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/basic.py +6 -8
- returnn/datasets/cached2.py +2 -10
- returnn/datasets/lm.py +0 -2
- returnn/frontend/decoder/transformer.py +9 -2
- {returnn-1.20250719.212120.dist-info → returnn-1.20250724.195711.dist-info}/METADATA +1 -1
- {returnn-1.20250719.212120.dist-info → returnn-1.20250724.195711.dist-info}/RECORD +11 -11
- {returnn-1.20250719.212120.dist-info → returnn-1.20250724.195711.dist-info}/LICENSE +0 -0
- {returnn-1.20250719.212120.dist-info → returnn-1.20250724.195711.dist-info}/WHEEL +0 -0
- {returnn-1.20250719.212120.dist-info → returnn-1.20250724.195711.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
@@ -1,2 +1,2 @@
-version = '1.
-long_version = '1.
+version = '1.20250724.195711'
+long_version = '1.20250724.195711+git.e0cf62f'
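Only the generated version strings change here. To confirm which of the two wheels is installed in a given environment, a stdlib-only check is enough; the sketch below assumes nothing about RETURNN's own attributes and uses importlib.metadata (Python 3.8+):

# Stdlib-only sketch: report the installed returnn distribution version.
# Prints e.g. "1.20250724.195711" when the newer wheel is installed.
from importlib.metadata import version

print(version("returnn"))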
returnn/datasets/basic.py
CHANGED
@@ -19,7 +19,6 @@ import os
 import math
 import numpy
 import functools
-import typing
 from typing import TYPE_CHECKING, Optional, Any, Set, Tuple, Union, Type, Dict, Sequence, List, Callable

 from returnn.log import log
@@ -428,9 +427,9 @@ class Dataset:
         """
         raise OptionalNotImplementedError

-    def get_num_timesteps(self):
+    def get_num_timesteps(self) -> Union[int, NumbersDict]:
         """
-        :
+        :return: how much frames we have in total.
         """
         assert self._num_timesteps > 0
         return self._num_timesteps
@@ -559,7 +558,7 @@ class Dataset:
             for i in range(1, num):
                 seq_index[i::num] += i * (num_seqs // num)
         elif seq_ordering_method == "reverse":
-            seq_index = range(num_seqs - 1, -1, -1)  # type: Union[range,
+            seq_index = range(num_seqs - 1, -1, -1)  # type: Union[range, Sequence[int]]
         elif seq_ordering_method in ["sorted", "sorted_reverse"]:
             assert get_seq_len
             reverse = -1 if seq_ordering_method == "sorted_reverse" else 1
@@ -748,12 +747,11 @@ class Dataset:
         """
         self.epoch = None

-    def get_current_seq_order(self):
+    def get_current_seq_order(self) -> Sequence[int]:
         """
         :return: many datasets use self.get_seq_order_for_epoch. this function would return the current seq order
             for the current epoch, after self.init_seq_order was called.
             Not all datasets implement this.
-        :rtype: typing.Sequence[int]
         """
         raise OptionalNotImplementedError

@@ -902,7 +900,7 @@ class Dataset:
         if self.seq_ordering == "default" and self.partition_epoch == 1:
             return seq_idx
         assert self.have_corpus_seq_idx()
-        raise
+        raise NotImplementedError

     def have_get_corpus_seq(self) -> bool:
         """
@@ -1061,7 +1059,7 @@ class Dataset:
         if key in self.num_outputs:
             if self.num_outputs[key][1] <= 1:
                 return []
-            res_shape = [None] * (self.num_outputs[key][1] - 1)
+            res_shape: List[Union[None, int]] = [None] * (self.num_outputs[key][1] - 1)
             if not self.is_data_sparse(key):
                 res_shape[-1] = self.get_data_dim(key)
             return res_shape
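The get_data_shape hunk above only adds a type annotation, but the surrounding logic is worth spelling out: the returned list has one entry per remaining axis, None for dynamic axes, and for dense data the feature dimension fills the last slot. Below is a standalone sketch of that logic with made-up num_outputs values; it is an illustrative stand-in, not the RETURNN API:

# Standalone stand-in for the Dataset.get_data_shape logic shown above.
# num_outputs maps a data key to (dim, ndim); sparse data carries no feature axis.
from typing import Dict, List, Tuple, Union

num_outputs: Dict[str, Tuple[int, int]] = {"data": (40, 2), "classes": (10, 1)}  # made-up values
is_sparse: Dict[str, bool] = {"data": False, "classes": True}


def get_data_shape(key: str) -> List[Union[None, int]]:
    if num_outputs[key][1] <= 1:
        return []
    res_shape: List[Union[None, int]] = [None] * (num_outputs[key][1] - 1)
    if not is_sparse[key]:
        res_shape[-1] = num_outputs[key][0]  # feature dim goes into the last slot
    return res_shape


print(get_data_shape("data"))     # [40]: dense, one feature axis of size 40
print(get_data_shape("classes"))  # []: sparse, no extra axes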
returnn/datasets/cached2.py
CHANGED
@@ -4,18 +4,10 @@ Provides :class:`CachedDataset2`.

 from __future__ import annotations
 import numpy
-import
-from typing import Optional
+from typing import Optional, List
 from threading import Condition
 from .basic import Dataset, DatasetSeq

-try:
-    # noinspection PyCompatibility
-    from _thread import interrupt_main
-except ImportError:
-    # noinspection PyUnresolvedReferences,PyCompatibility
-    from thread import interrupt_main
-

 class CachedDataset2(Dataset):
     """
@@ -36,7 +28,7 @@ class CachedDataset2(Dataset):
         self._num_timesteps = None
         self.epoch = None
         self.reached_final_seq = False
-        self.added_data
+        self.added_data: List[DatasetSeq] = []
         self.expected_load_seq_start = 0
         self._num_timesteps_accumulated = 0

returnn/datasets/lm.py
CHANGED
@@ -24,7 +24,6 @@ from typing import (
     cast,
     Generator,
 )
-import typing
 import os
 from io import IOBase
 import sys
@@ -1563,7 +1562,6 @@ class TranslationDataset(CachedDataset2):
         import returnn.util.better_exchook

         returnn.util.better_exchook.install()
-        from returnn.util.basic import AsyncThreadRun

         # First iterate once over the data to get the data len as fast as possible.
         data_len = 0
returnn/frontend/decoder/transformer.py
CHANGED
@@ -268,6 +268,7 @@ class TransformerDecoderLayer(rf.Module):
         ] = None,
         self_att_opts: Optional[Dict[str, Any]] = None,
         att_dropout: float = 0.1,
+        cross_att: Optional[Dict[str, Any]] = None,
         norm: Union[type, Dict[str, Any], rf.Module, Callable] = rf.LayerNorm,
     ):
         """
@@ -333,10 +334,10 @@ class TransformerDecoderLayer(rf.Module):
             raise TypeError(f"unexpected self_att type {self_att!r}")
         self.self_att_layer_norm = make_norm(norm, out_dim)

-        self.cross_att = None
+        self.cross_att: Optional[rf.CrossAttention] = None  # type might be inaccurate, but we expect this interface
         self.cross_att_layer_norm = None
         if encoder_dim is not None:
-
+            cross_att_opts = dict(
                 encoder_dim=self.encoder_dim,
                 query_in_dim=out_dim,
                 proj_dim=out_dim,
@@ -345,6 +346,12 @@ class TransformerDecoderLayer(rf.Module):
                 num_heads=num_heads,
                 att_dropout=att_dropout,
             )
+            if cross_att is None:
+                self.cross_att = rf.CrossAttention(**cross_att_opts)
+            elif isinstance(cross_att, dict):
+                self.cross_att: Optional[rf.CrossAttention] = rf.build_from_dict(cross_att, **cross_att_opts)
+            else:
+                raise TypeError(f"unexpected cross_att type {cross_att!r}")
             self.cross_att_layer_norm = make_norm(norm, out_dim)

     def default_initial_state(self, *, batch_dims: Sequence[Dim]) -> rf.State:
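This is the most substantial change in the release: TransformerDecoderLayer gains a cross_att argument, where None keeps the default rf.CrossAttention, a dict is resolved through rf.build_from_dict with the standard cross-attention options passed along, and any other type raises TypeError. Below is a minimal, self-contained sketch of that dispatch pattern; DefaultCrossAtt, build_from_dict and make_cross_att are illustrative stand-ins written for this diff, not RETURNN's API:

# Self-contained sketch of the None / dict / error dispatch added for `cross_att`.
# All names below are illustrative stand-ins, not RETURNN's actual classes.
from typing import Any, Dict, Optional


class DefaultCrossAtt:
    """Stand-in for the default cross-attention class (rf.CrossAttention in RETURNN)."""

    def __init__(self, **opts):
        self.opts = opts


def build_from_dict(d: Dict[str, Any], **common_opts) -> Any:
    """Stand-in for rf.build_from_dict: 'class' selects the type, remaining keys are extra kwargs."""
    d = dict(d)
    cls = d.pop("class")
    return cls(**common_opts, **d)


def make_cross_att(cross_att: Optional[Dict[str, Any]], **cross_att_opts) -> Any:
    """Mirrors the dispatch the diff adds to TransformerDecoderLayer.__init__."""
    if cross_att is None:
        return DefaultCrossAtt(**cross_att_opts)  # default behaviour, unchanged
    elif isinstance(cross_att, dict):
        return build_from_dict(cross_att, **cross_att_opts)  # user-provided config
    else:
        raise TypeError(f"unexpected cross_att type {cross_att!r}")


# Default vs. overridden configuration:
att = make_cross_att(None, num_heads=8)
custom = make_cross_att({"class": DefaultCrossAtt, "att_dropout": 0.0}, num_heads=8)
print(type(att).__name__, type(custom).__name__, custom.opts)

The practical effect is that a custom cross-attention module can be selected from a config dict, without subclassing the layer.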
{returnn-1.20250719.212120.dist-info → returnn-1.20250724.195711.dist-info}/RECORD
CHANGED
@@ -1,9 +1,9 @@
-returnn/PKG-INFO,sha256=
+returnn/PKG-INFO,sha256=fdsXa-aGlgSS206_rjbOuzyZOYFra39xWql2ejHX1ic,5215
 returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
 returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
 returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
 returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
-returnn/_setup_info_generated.py,sha256=
+returnn/_setup_info_generated.py,sha256=xJWrWKUD6b2uMW_UOzZMqg4UjzzWUtDTGUnDvJiHAOI,77
 returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
 returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
 returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -13,14 +13,14 @@ returnn/native_op.py,sha256=4_NnvfNxsM8GE_FsD6yOg6PZegqIdtJ3Sl1GdBWmFvg,244424
 returnn/pretrain.py,sha256=MHiXJZqkQFmDVyaYsGpd_Acv20wxl7Pr6s6qJzAT2FI,22648
 returnn/datasets/__init__.py,sha256=PvDlfDOaaopIeUIt0OSvHD2eHZkdkyE-sjMXf35EH5U,390
 returnn/datasets/audio.py,sha256=Gmj7a08dnvYh7Z-G1TNapz42L50AIcDE9JeIZaO1s1M,23334
-returnn/datasets/basic.py,sha256=
+returnn/datasets/basic.py,sha256=IJhytVPiQZi7BD8-JVziKKT__PE528FwLQmbeiVQHzc,72303
 returnn/datasets/bundle_file.py,sha256=KQNrS1MSf-4_idlK0c0KFwON-f5sEK0sWU15WpoMYpE,2380
 returnn/datasets/cached.py,sha256=RyefRjSDdp-HveK-2vLy2C6BIHcpqQ_lNvUKlIa4QAI,25412
-returnn/datasets/cached2.py,sha256=
+returnn/datasets/cached2.py,sha256=oJOq2lWRQpxm6kyUKW1w5qZBd4kdKEpwM7KY_QnXbq4,11922
 returnn/datasets/distrib_files.py,sha256=SJ2YkZEZmG9lu3MLTwSMyVNfsXzRHqbLNjUn9IDwVJM,30194
 returnn/datasets/generating.py,sha256=9U_w6URIrv-Rb-hDbPOzYW9qYXzJbw32N6G268IKyoM,99833
 returnn/datasets/hdf.py,sha256=v5sjBenURR9Z-g7AQ9tsL84yDSye5RtbLpym3M6HSDE,67833
-returnn/datasets/lm.py,sha256=
+returnn/datasets/lm.py,sha256=rQ3jV43lSnlGkKu7m5jTTH7aK0BOMXQocsHfJ8OGec8,99950
 returnn/datasets/map.py,sha256=kOBJVZmwDhLsOplzDNByIfa0NRSUaMo2Lsy36lBvxrM,10907
 returnn/datasets/meta.py,sha256=6XPPxhiNSxWw9Hu5Z6wG8dD9Zk82FqiI-k9HGQSTKgw,95658
 returnn/datasets/multi_proc.py,sha256=aVjsLt2qjHnHOrEYCgIPCwNYE-f1fiGP6eZ8NGAr3A4,22583
@@ -135,7 +135,7 @@ returnn/frontend/conversions/espnet_e_branchformer.py,sha256=Mmp3G6nySy0CqeHa-um
 returnn/frontend/conversions/hf_llama.py,sha256=1WQOhQyUWwkAznaRqK2zpThP8XZbaomkaE8qMG_bZPY,9662
 returnn/frontend/conversions/torch_nn.py,sha256=WAq_hs1tb5OC4iGmVemXvo3qba_e1MJXxRzG9pNK2HI,2204
 returnn/frontend/decoder/__init__.py,sha256=A-koKyPVlXp_V_2bk6GKZ1Xfv4rYIcfxGMXQHkHZiOQ,41
-returnn/frontend/decoder/transformer.py,sha256=
+returnn/frontend/decoder/transformer.py,sha256=20a37hMiPbQBHx3tSbOeiAbFPVRcX_KYpPuw8tmY6GU,23658
 returnn/frontend/encoder/__init__.py,sha256=0QGLlujRIKx3zBREeShza_-xhGIxj73zbd7t-g1m-ho,17
 returnn/frontend/encoder/base.py,sha256=A759EwCYAmSi-kzXz1vaTjR2l59TvNGQlzaNdp3UOKs,2109
 returnn/frontend/encoder/conformer.py,sha256=ro0uzEzDbAyNGYN5ff0KmiDl4HOYQluu64mJxYzuy-M,19972
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
 returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
 returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
 returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
-returnn-1.
-returnn-1.
-returnn-1.
-returnn-1.
-returnn-1.
+returnn-1.20250724.195711.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+returnn-1.20250724.195711.dist-info/METADATA,sha256=fdsXa-aGlgSS206_rjbOuzyZOYFra39xWql2ejHX1ic,5215
+returnn-1.20250724.195711.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+returnn-1.20250724.195711.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+returnn-1.20250724.195711.dist-info/RECORD,,
{returnn-1.20250719.212120.dist-info → returnn-1.20250724.195711.dist-info}/LICENSE
File without changes
{returnn-1.20250719.212120.dist-info → returnn-1.20250724.195711.dist-info}/WHEEL
File without changes
{returnn-1.20250719.212120.dist-info → returnn-1.20250724.195711.dist-info}/top_level.txt
File without changes