returnn 1.20250902.10950__py3-none-any.whl → 1.20250902.114352__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of returnn has been flagged as potentially problematic.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/frontend/decoder/transformer.py +36 -17
- {returnn-1.20250902.10950.dist-info → returnn-1.20250902.114352.dist-info}/METADATA +1 -1
- {returnn-1.20250902.10950.dist-info → returnn-1.20250902.114352.dist-info}/RECORD +8 -8
- {returnn-1.20250902.10950.dist-info → returnn-1.20250902.114352.dist-info}/LICENSE +0 -0
- {returnn-1.20250902.10950.dist-info → returnn-1.20250902.114352.dist-info}/WHEEL +0 -0
- {returnn-1.20250902.10950.dist-info → returnn-1.20250902.114352.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
@@ -1,2 +1,2 @@
-version = '1.20250902.10950'
-long_version = '1.20250902.10950…
+version = '1.20250902.114352'
+long_version = '1.20250902.114352+git.87030fa'
returnn/frontend/decoder/transformer.py
CHANGED
@@ -49,6 +49,7 @@ class TransformerDecoder(rf.Module):
         layer_opts: Optional[Dict[str, Any]] = None,
         embed_dim: Optional[Dim] = None,
         share_embedding: bool = None,
+        input_embedding: bool = True,
         input_embedding_scale: float = None,
         input_dropout: float = None,
         logits_with_bias: bool = False,
@@ -72,6 +73,7 @@ class TransformerDecoder(rf.Module):
         :param layer_opts: options for the decoder layer
         :param embed_dim: if given, will first have an embedding [vocab,embed] and then a linear [embed,model].
         :param share_embedding:
+        :param input_embedding: whether to use input embedding. If False, you must provide input of dimension model_dim.
         :param input_embedding_scale:
         :param input_dropout:
         :param logits_with_bias:
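The new input_embedding flag lets a caller skip the internal rf.Embedding entirely and feed vectors that already have model_dim features, e.g. outputs of an upstream network. A minimal construction sketch follows; only the flags visible in this diff are taken from the source, while the encoder_dim=None argument and the Dim shorthand are assumptions:

    from returnn.tensor import Dim
    from returnn.frontend.decoder.transformer import TransformerDecoder

    vocab_dim = Dim(10_000, name="vocab")  # target vocabulary
    model_dim = Dim(512, name="model")     # Transformer model dimension

    # With input_embedding=False no rf.Embedding is created inside the decoder,
    # so the `source` passed to it later must already have model_dim features.
    decoder = TransformerDecoder(
        encoder_dim=None,       # assumed: decoder-only / LM-style usage
        vocab_dim=vocab_dim,
        model_dim=model_dim,
        input_embedding=False,  # new flag introduced in this release
    )
    assert decoder.input_embedding is None  # matches the __init__ change below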
@@ -103,7 +105,7 @@ class TransformerDecoder(rf.Module):
 
         # We could make this optional or configurable if we ever need to.
         # Or maybe you would just have another separate implementation of this module then...
-        self.input_embedding = rf.Embedding(vocab_dim, embed_dim or model_dim)
+        self.input_embedding = rf.Embedding(vocab_dim, embed_dim or model_dim) if input_embedding else None
 
         self.input_embedding_proj = None
         if embed_dim:
@@ -121,21 +123,31 @@ class TransformerDecoder(rf.Module):
             raise TypeError(f"unexpected pos_enc type {pos_enc!r}")
         self.pos_enc = pos_enc
         if share_embedding is None:
-            if …
-            …
-            …
-            …
-            "…
-            …
-            …
+            if embed_dim and embed_dim != model_dim:
+                share_embedding = False
+            elif input_embedding:
+                if BehaviorVersion.get() < 20:
+                    logging.getLogger("returnn.frontend").warning(
+                        "TransformerDecoder share_embedding default is False"
+                        f" with your behavior version {BehaviorVersion.get()}."
+                        " Explicitly set share_embedding or switch to a new behavior version >= 20."
+                    )
+                share_embedding = True if BehaviorVersion.get() >= 20 else False
+            else:  # not input_embedding
+                share_embedding = False
         if input_embedding_scale is None:
-            if …
-            …
-            "…
-            …
-            …
-            …
-            …
+            if input_embedding:
+                if BehaviorVersion.get() < 20:
+                    logging.getLogger("returnn.frontend").warning(
+                        "TransformerDecoder input_embedding_scale default is suboptimal"
+                        f" with your behavior version {BehaviorVersion.get()}."
+                        " Explicitly set input_embedding_scale or switch to a new behavior version >= 20."
+                    )
+                input_embedding_scale = model_dim.dimension**0.5 if BehaviorVersion.get() >= 20 else 1.0
+            elif pos_enc:
+                input_embedding_scale = model_dim.dimension**0.5
+            else:
+                input_embedding_scale = 1.0
         self.input_embedding_scale = input_embedding_scale
         if input_dropout is None:
             if dropout > 0 and BehaviorVersion.get() < 20:
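Both defaults now depend on the global behavior version: below 20 the previous behavior is kept (no weight sharing, scale 1.0 when an input embedding is used) and a warning is logged; from version 20 on the embedding is shared with the output layer and scaled by sqrt(model_dim). The warning can be silenced either by raising behavior_version in the config or by pinning the values explicitly, as in this sketch (same assumptions as above):

    from returnn.tensor import Dim
    from returnn.frontend.decoder.transformer import TransformerDecoder

    vocab_dim = Dim(10_000, name="vocab")
    model_dim = Dim(512, name="model")

    # Explicit values make the result independent of BehaviorVersion, so no warning is logged.
    decoder = TransformerDecoder(
        encoder_dim=None,                                 # assumed decoder-only usage
        vocab_dim=vocab_dim,
        model_dim=model_dim,
        share_embedding=True,                             # tie input embedding and output projection
        input_embedding_scale=model_dim.dimension**0.5,   # sqrt(d_model), the behavior >= 20 default
    )
    # Alternatively, set behavior_version = 20 (or newer) in the RETURNN config to get these defaults.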
@@ -179,7 +191,9 @@ class TransformerDecoder(rf.Module):
         self.logits = rf.Linear(model_dim, vocab_dim, with_bias=logits_with_bias)
 
         if share_embedding:
-            assert …
+            assert input_embedding, "input_embedding=True required for share_embedding"
+            assert not embed_dim or embed_dim == model_dim, f"{embed_dim=} not supported with share_embedding"
+            assert not logits_with_bias, "logits_with_bias=True expected with share_embedding"
             self.logits.weight = self.input_embedding.weight
 
     def default_initial_state(self, *, batch_dims: Sequence[Dim]) -> rf.State:
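With share_embedding=True the output projection reuses the embedding matrix, and the expanded asserts spell out when that is possible: an input embedding must exist, embed_dim may not differ from model_dim, and the logits layer may not carry a bias. Continuing the sketch above, the tying is directly observable:

    # After weight tying, both attributes reference the same parameter object.
    assert decoder.logits.weight is decoder.input_embedding.weight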
@@ -219,7 +233,12 @@ class TransformerDecoder(rf.Module):
         """
         new_state = rf.State()
 
-        …
+        if self.input_embedding is not None:
+            decoded = self.input_embedding(source)
+        else:
+            decoded = source
+        if self.input_embedding_scale != 1:
+            decoded = decoded * self.input_embedding_scale
         if self.pos_enc is not None:
             decoded = decoded + self.pos_enc(spatial_dim=spatial_dim, offset=state.pos)
         decoded = rf.dropout(decoded, self.input_dropout)
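The step function honors the same flag: token indices are embedded only when an embedding layer exists, otherwise source is used as-is, and the multiplication is skipped when the scale is exactly 1. A condensed restatement of that branch as a standalone helper (sketch only, not the actual __call__):

    def embed_decoder_input(decoder, source):
        """Mirror of the new input handling in TransformerDecoder.__call__ (sketch)."""
        if decoder.input_embedding is not None:
            decoded = decoder.input_embedding(source)  # source: token indices over vocab_dim
        else:
            decoded = source  # source must already have model_dim features
        if decoder.input_embedding_scale != 1:
            decoded = decoded * decoder.input_embedding_scale  # usually sqrt(model_dim)
        return decoded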
{returnn-1.20250902.10950.dist-info → returnn-1.20250902.114352.dist-info}/RECORD
CHANGED
@@ -1,9 +1,9 @@
-returnn/PKG-INFO,sha256=…
+returnn/PKG-INFO,sha256=zCN-KDwCaMFI82phyc-dsc6Fo_thXN-UOBfvd93s0bU,5215
 returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
 returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
 returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
 returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
-returnn/_setup_info_generated.py,sha256=…
+returnn/_setup_info_generated.py,sha256=J31pQBS08nmbv7yxX4hOOWq1d__odaj7aX-8_sTiVXo,77
 returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
 returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
 returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -135,7 +135,7 @@ returnn/frontend/conversions/espnet_e_branchformer.py,sha256=Mmp3G6nySy0CqeHa-um
 returnn/frontend/conversions/hf_llama.py,sha256=1WQOhQyUWwkAznaRqK2zpThP8XZbaomkaE8qMG_bZPY,9662
 returnn/frontend/conversions/torch_nn.py,sha256=WAq_hs1tb5OC4iGmVemXvo3qba_e1MJXxRzG9pNK2HI,2204
 returnn/frontend/decoder/__init__.py,sha256=A-koKyPVlXp_V_2bk6GKZ1Xfv4rYIcfxGMXQHkHZiOQ,41
-returnn/frontend/decoder/transformer.py,sha256=…
+returnn/frontend/decoder/transformer.py,sha256=64Z1IY_WcDuj8Ti73BGwbT_grrEpxBl5mIsBZkqJzHQ,24650
 returnn/frontend/encoder/__init__.py,sha256=0QGLlujRIKx3zBREeShza_-xhGIxj73zbd7t-g1m-ho,17
 returnn/frontend/encoder/base.py,sha256=A759EwCYAmSi-kzXz1vaTjR2l59TvNGQlzaNdp3UOKs,2109
 returnn/frontend/encoder/conformer.py,sha256=rWulygolesbYkLw9naSxwygaZhWqKpHKEVj-1AQbel0,21351
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
 returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
 returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
 returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
-returnn-1.20250902.10950.dist-info/…
-returnn-1.20250902.10950.dist-info/…
-returnn-1.20250902.10950.dist-info/…
-returnn-1.20250902.10950.dist-info/…
-returnn-1.20250902.10950.dist-info/…
+returnn-1.20250902.114352.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+returnn-1.20250902.114352.dist-info/METADATA,sha256=zCN-KDwCaMFI82phyc-dsc6Fo_thXN-UOBfvd93s0bU,5215
+returnn-1.20250902.114352.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+returnn-1.20250902.114352.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+returnn-1.20250902.114352.dist-info/RECORD,,
{returnn-1.20250902.10950.dist-info → returnn-1.20250902.114352.dist-info}/LICENSE
File without changes
{returnn-1.20250902.10950.dist-info → returnn-1.20250902.114352.dist-info}/WHEEL
File without changes
{returnn-1.20250902.10950.dist-info → returnn-1.20250902.114352.dist-info}/top_level.txt
File without changes