returnn 1.20250902.10950__py3-none-any.whl → 1.20250902.114352__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of returnn might be problematic.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20250902.10950
+Version: 1.20250902.114352
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer

returnn/_setup_info_generated.py CHANGED
@@ -1,2 +1,2 @@
-version = '1.20250902.010950'
-long_version = '1.20250902.010950+git.9d5debf'
+version = '1.20250902.114352'
+long_version = '1.20250902.114352+git.87030fa'
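
Note that the two files spell the old version differently: _setup_info_generated.py has '1.20250902.010950' while PKG-INFO reports '1.20250902.10950'. Both denote the same release; PEP 440 normalization strips the leading zero from the HHMMSS-style segment. A quick check, assuming the third-party packaging library is installed:

    from packaging.version import Version

    # PEP 440 treats each release segment as an integer, so the leading
    # zero in the time segment is dropped under normalization.
    assert str(Version("1.20250902.010950")) == "1.20250902.10950"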

returnn/frontend/decoder/transformer.py CHANGED
@@ -49,6 +49,7 @@ class TransformerDecoder(rf.Module):
         layer_opts: Optional[Dict[str, Any]] = None,
         embed_dim: Optional[Dim] = None,
         share_embedding: bool = None,
+        input_embedding: bool = True,
         input_embedding_scale: float = None,
         input_dropout: float = None,
         logits_with_bias: bool = False,
@@ -72,6 +73,7 @@ class TransformerDecoder(rf.Module):
         :param layer_opts: options for the decoder layer
         :param embed_dim: if given, will first have an embedding [vocab,embed] and then a linear [embed,model].
         :param share_embedding:
+        :param input_embedding: whether to use input embedding. If False, you must provide input of dimension model_dim.
         :param input_embedding_scale:
         :param input_dropout:
         :param logits_with_bias:
@@ -103,7 +105,7 @@ class TransformerDecoder(rf.Module):

         # We could make this optional or configurable if we ever need to.
         # Or maybe you would just have another separate implementation of this module then...
-        self.input_embedding = rf.Embedding(vocab_dim, embed_dim or model_dim)
+        self.input_embedding = rf.Embedding(vocab_dim, embed_dim or model_dim) if input_embedding else None

         self.input_embedding_proj = None
         if embed_dim:
@@ -121,21 +123,31 @@ class TransformerDecoder(rf.Module):
             raise TypeError(f"unexpected pos_enc type {pos_enc!r}")
         self.pos_enc = pos_enc
         if share_embedding is None:
-            if BehaviorVersion.get() < 20:
-                logging.getLogger("returnn.frontend").warning(
-                    "TransformerDecoder share_embedding default is False"
-                    f" with your behavior version {BehaviorVersion.get()}."
-                    " Explicitly set share_embedding or switch to a new behavior version >= 20."
-                )
-            share_embedding = True if BehaviorVersion.get() >= 20 else False
+            if embed_dim and embed_dim != model_dim:
+                share_embedding = False
+            elif input_embedding:
+                if BehaviorVersion.get() < 20:
+                    logging.getLogger("returnn.frontend").warning(
+                        "TransformerDecoder share_embedding default is False"
+                        f" with your behavior version {BehaviorVersion.get()}."
+                        " Explicitly set share_embedding or switch to a new behavior version >= 20."
+                    )
+                share_embedding = True if BehaviorVersion.get() >= 20 else False
+            else:  # not input_embedding
+                share_embedding = False
         if input_embedding_scale is None:
-            if BehaviorVersion.get() < 20:
-                logging.getLogger("returnn.frontend").warning(
-                    "TransformerDecoder input_embedding_scale default is suboptimal"
-                    f" with your behavior version {BehaviorVersion.get()}."
-                    " Explicitly set input_embedding_scale or switch to a new behavior version >= 20."
-                )
-            input_embedding_scale = model_dim.dimension**0.5 if BehaviorVersion.get() >= 20 else 1.0
+            if input_embedding:
+                if BehaviorVersion.get() < 20:
+                    logging.getLogger("returnn.frontend").warning(
+                        "TransformerDecoder input_embedding_scale default is suboptimal"
+                        f" with your behavior version {BehaviorVersion.get()}."
+                        " Explicitly set input_embedding_scale or switch to a new behavior version >= 20."
+                    )
+                input_embedding_scale = model_dim.dimension**0.5 if BehaviorVersion.get() >= 20 else 1.0
+            elif pos_enc:
+                input_embedding_scale = model_dim.dimension**0.5
+            else:
+                input_embedding_scale = 1.0
         self.input_embedding_scale = input_embedding_scale
         if input_dropout is None:
             if dropout > 0 and BehaviorVersion.get() < 20:
@@ -179,7 +191,9 @@ class TransformerDecoder(rf.Module):
         self.logits = rf.Linear(model_dim, vocab_dim, with_bias=logits_with_bias)

         if share_embedding:
-            assert not embed_dim and not logits_with_bias, "not supported together with share_embedding"
+            assert input_embedding, "input_embedding=True required for share_embedding"
+            assert not embed_dim or embed_dim == model_dim, f"{embed_dim=} not supported with share_embedding"
+            assert not logits_with_bias, "logits_with_bias=True expected with share_embedding"
             self.logits.weight = self.input_embedding.weight

     def default_initial_state(self, *, batch_dims: Sequence[Dim]) -> rf.State:
@@ -219,7 +233,12 @@ class TransformerDecoder(rf.Module):
         """
         new_state = rf.State()

-        decoded = self.input_embedding(source) * self.input_embedding_scale
+        if self.input_embedding is not None:
+            decoded = self.input_embedding(source)
+        else:
+            decoded = source
+        if self.input_embedding_scale != 1:
+            decoded = decoded * self.input_embedding_scale
         if self.pos_enc is not None:
             decoded = decoded + self.pos_enc(spatial_dim=spatial_dim, offset=state.pos)
         decoded = rf.dropout(decoded, self.input_dropout)
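
The net effect of the transformer.py changes: TransformerDecoder gains an input_embedding flag (default True). With input_embedding=False, self.input_embedding is None and the decoder consumes model_dim-sized vectors directly instead of token IDs, with the defaults for share_embedding and input_embedding_scale adapting as shown above. A minimal usage sketch; constructor arguments other than vocab_dim, model_dim, and input_embedding (num_layers, encoder_dim) are assumptions based on the RETURNN frontend, not shown in this diff:

    from returnn.tensor import Dim
    from returnn.frontend.decoder.transformer import TransformerDecoder

    vocab_dim = Dim(10_025, name="vocab")
    model_dim = Dim(512, name="model")

    # Previous behavior: the decoder always embeds its integer `source` internally.
    dec = TransformerDecoder(num_layers=6, encoder_dim=None, vocab_dim=vocab_dim, model_dim=model_dim)

    # New in this release: disable the internal embedding; `source` must then
    # already have dimension model_dim (e.g. from an external embedding).
    # Per the diff above, share_embedding then defaults to False, and
    # input_embedding_scale defaults to model_dim**0.5 if pos_enc is set, else 1.0.
    dec_vec = TransformerDecoder(
        num_layers=6,
        encoder_dim=None,  # assumption: decoder-only usage without cross-attention
        vocab_dim=vocab_dim,
        model_dim=model_dim,
        input_embedding=False,
    )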

returnn-1.20250902.114352.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20250902.10950
+Version: 1.20250902.114352
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer

returnn-1.20250902.114352.dist-info/RECORD CHANGED
@@ -1,9 +1,9 @@
-returnn/PKG-INFO,sha256=GVal7eVN_obo9mfdhPK2WvH2MzSm51cFZJChHEsF2XU,5214
+returnn/PKG-INFO,sha256=zCN-KDwCaMFI82phyc-dsc6Fo_thXN-UOBfvd93s0bU,5215
 returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
 returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
 returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
 returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
-returnn/_setup_info_generated.py,sha256=jTlsQFAqLqFgm0UJ0uWltcnLf69QwqOK0yV4Slt-2Is,77
+returnn/_setup_info_generated.py,sha256=J31pQBS08nmbv7yxX4hOOWq1d__odaj7aX-8_sTiVXo,77
 returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
 returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
 returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -135,7 +135,7 @@ returnn/frontend/conversions/espnet_e_branchformer.py,sha256=Mmp3G6nySy0CqeHa-um
 returnn/frontend/conversions/hf_llama.py,sha256=1WQOhQyUWwkAznaRqK2zpThP8XZbaomkaE8qMG_bZPY,9662
 returnn/frontend/conversions/torch_nn.py,sha256=WAq_hs1tb5OC4iGmVemXvo3qba_e1MJXxRzG9pNK2HI,2204
 returnn/frontend/decoder/__init__.py,sha256=A-koKyPVlXp_V_2bk6GKZ1Xfv4rYIcfxGMXQHkHZiOQ,41
-returnn/frontend/decoder/transformer.py,sha256=20a37hMiPbQBHx3tSbOeiAbFPVRcX_KYpPuw8tmY6GU,23658
+returnn/frontend/decoder/transformer.py,sha256=64Z1IY_WcDuj8Ti73BGwbT_grrEpxBl5mIsBZkqJzHQ,24650
 returnn/frontend/encoder/__init__.py,sha256=0QGLlujRIKx3zBREeShza_-xhGIxj73zbd7t-g1m-ho,17
 returnn/frontend/encoder/base.py,sha256=A759EwCYAmSi-kzXz1vaTjR2l59TvNGQlzaNdp3UOKs,2109
 returnn/frontend/encoder/conformer.py,sha256=rWulygolesbYkLw9naSxwygaZhWqKpHKEVj-1AQbel0,21351
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
 returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
 returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
 returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
-returnn-1.20250902.10950.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
-returnn-1.20250902.10950.dist-info/METADATA,sha256=GVal7eVN_obo9mfdhPK2WvH2MzSm51cFZJChHEsF2XU,5214
-returnn-1.20250902.10950.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
-returnn-1.20250902.10950.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
-returnn-1.20250902.10950.dist-info/RECORD,,
+returnn-1.20250902.114352.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+returnn-1.20250902.114352.dist-info/METADATA,sha256=zCN-KDwCaMFI82phyc-dsc6Fo_thXN-UOBfvd93s0bU,5215
+returnn-1.20250902.114352.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+returnn-1.20250902.114352.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+returnn-1.20250902.114352.dist-info/RECORD,,
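
For anyone verifying the RECORD entries above: each sha256= value is the urlsafe-base64-encoded SHA-256 digest of the file with the trailing '=' padding stripped, per the wheel RECORD format (PEP 376). A small self-contained sketch to recompute one from an unpacked wheel:

    import base64
    import hashlib
    from pathlib import Path

    def record_hash(path: str) -> str:
        # RECORD format: "sha256=" + urlsafe base64 of the digest, '=' padding stripped.
        digest = hashlib.sha256(Path(path).read_bytes()).digest()
        return "sha256=" + base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")

    # E.g., on the unpacked new wheel, record_hash("returnn/frontend/decoder/transformer.py")
    # should yield "sha256=64Z1IY_WcDuj8Ti73BGwbT_grrEpxBl5mIsBZkqJzHQ".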