returnn 1.20250826.115240-py3-none-any.whl → 1.20250828.2732-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of returnn might be problematic.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20250826.115240
+Version: 1.20250828.2732
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
returnn/_setup_info_generated.py CHANGED
@@ -1,2 +1,2 @@
-version = '1.20250826.115240'
-long_version = '1.20250826.115240+git.2baa52f'
+version = '1.20250828.002732'
+long_version = '1.20250828.002732+git.06c221e'
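
The generated source records the version as 1.20250828.002732 while the wheel and dist-info names use 1.20250828.2732. Both denote the same version: PEP 440 normalization drops leading zeros from release segments. A quick check with the third-party packaging library (assuming it is installed) confirms this:

from packaging.version import Version

v = Version("1.20250828.002732")
print(v)                                # 1.20250828.2732 (normalized form)
print(v == Version("1.20250828.2732"))  # True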
returnn/frontend/array_.py CHANGED
@@ -429,11 +429,27 @@ def concat(
             assert src.dims_set - {dim} == dims, f"concat {sources}, need allow_broadcast=True"
     if not out_dim:
         out_dim = sum(d for _, d in sources)
+    # noinspection PyProtectedMember
+    out = sources[0][0]._raw_backend.concat(*sources, allow_broadcast=allow_broadcast, out_dim=out_dim)
     if handle_dynamic_dims is None or handle_dynamic_dims:
+        need_to_handle = False
         for src, dim in sources[:-1]:
-            assert dim.is_static(), f"concat {sources}, dim {dim} is not static, not yet implemented..."
-    # noinspection PyProtectedMember
-    return sources[0][0]._raw_backend.concat(*sources, allow_broadcast=allow_broadcast, out_dim=out_dim), out_dim
+            if dim.need_masking():
+                need_to_handle = True
+        if need_to_handle:
+            masks = []
+            for _, dim in sources:
+                masks.append(
+                    dim.get_mask(dim_order=(dim,) + dim.dyn_size_ext.dims, device=out.device)
+                    if dim.need_masking()
+                    else rf.constant(True, dims=[dim], device=out.device)
+                )
+            # noinspection PyProtectedMember
+            mask_concat = sources[0][0]._raw_backend.concat(
+                *[(mask, dim) for (_, dim), mask in zip(sources, masks)], allow_broadcast=True, out_dim=out_dim
+            )
+            out, out_dim = rf.masked_select(out, mask=mask_concat, dims=[out_dim])
+    return out, out_dim
 
 
 def concat_features(*sources: Tensor, allow_broadcast=False) -> Tensor:
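
The hunk above lifts the old restriction that concat over a dynamic (per-sequence) dim failed with "dim ... is not static, not yet implemented": the backend concat now runs first on the padded tensors, a validity mask is built per source and concatenated the same way, and rf.masked_select compresses the padding out of each sequence. A rough standalone sketch of the same masking-and-repacking idea in plain PyTorch (illustrative only, not the RETURNN frontend API; all names here are made up):

import torch

# Two padded batches [B, T, F] with per-sequence lengths; sequence i of the
# result should be a[i, :lens_a[i]] followed by b[i, :lens_b[i]].
B, F = 2, 3
a, lens_a = torch.randn(B, 5, F), torch.tensor([3, 5])
b, lens_b = torch.randn(B, 4, F), torch.tensor([4, 2])

# Naive concat along time leaves the padding of `a` inside each sequence.
naive = torch.cat([a, b], dim=1)  # [B, 9, F]

# Build per-source validity masks and concatenate them the same way.
mask_a = torch.arange(a.shape[1])[None, :] < lens_a[:, None]  # [B, 5]
mask_b = torch.arange(b.shape[1])[None, :] < lens_b[:, None]  # [B, 4]
mask = torch.cat([mask_a, mask_b], dim=1)  # [B, 9]

# Repack each row so valid frames are contiguous, padded to the new lengths.
out_lens = lens_a + lens_b
out = torch.zeros(B, int(out_lens.max()), F)
for i in range(B):
    out[i, : out_lens[i]] = naive[i][mask[i]]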
returnn/frontend/encoder/conformer.py CHANGED
@@ -8,6 +8,8 @@ https://github.com/rwth-i6/returnn_common/issues/233
 
 from __future__ import annotations
 from typing import Optional, Union, Any, Tuple, List, Dict, Callable
+from types import FunctionType
+import functools
 import copy as _copy
 from returnn.tensor import Tensor, Dim
 import returnn.frontend as rf
@@ -298,7 +300,8 @@ class ConformerEncoder(ISeqDownsamplingEncoder):
         *,
         num_layers: int,
         input_layer: Optional[Union[ConformerConvSubsample, ISeqDownsamplingEncoder, rf.Module, Any]],
-        input_embedding_scale: float = 1.0,
+        input_embedding_scale: Optional[float] = None,
+        pos_enc: Union[None, Callable, Dict[str, Any], rf.Module] = None,
         input_dropout: float = 0.1,
         ff_dim: Dim = NotSpecified,
         ff_activation: Union[Callable[[Tensor], Tensor], Dict[str, Any], rf.Module] = NotSpecified,
@@ -317,8 +320,17 @@ class ConformerEncoder(ISeqDownsamplingEncoder):
         :param num_layers: the number of encoder layers
         :param input_layer: input/frontend/prenet with potential subsampling.
             (x, in_spatial_dim) -> (y, out_spatial_dim)
-        :param input_embedding_scale: applied after input_layer. 1.0 by default for historic reasons.
-            In std Transformer, also ESPnet E-Branchformer and Conformer, this is sqrt(out_dim).
+        :param input_embedding_scale: applied after input_layer.
+            1.0 by default for historic reasons if pos_enc is None,
+            else sqrt(out_dim) by default.
+            In std Transformer, also ESPnet E-Branchformer and Conformer, this is sqrt(out_dim),
+            which is relevant when you add positional encoding.
+        :param pos_enc: positional encoding, applied after input_embedding_scale.
+            None (no positional encoding) by default, unlike standard Transformer.
+            E.g. :func:`rf.sinusoidal_positional_encoding` for absolute pos enc.
+            Note, relative positional encoding is usually part of the attention layer,
+            e.g. :class:`rf.RelPosSelfAttention`,
+            and nothing needs to be set here.
         :param input_dropout: applied after input_projection(input_layer(x))
         :param ff_dim: the dimension of feed-forward layers. 2048 originally, or 4 times out_dim
         :param ff_activation: activation function for feed-forward network
@@ -352,12 +364,22 @@ class ConformerEncoder(ISeqDownsamplingEncoder):
         else:
             raise TypeError(f"unexpected input_layer {input_layer!r}")
         self.input_layer = input_layer
-        self.input_projection = (
-            rf.Linear(self.input_layer.out_dim if self.input_layer else self.in_dim, self.out_dim, with_bias=False)
-            if input_layer
-            else None
-        )
+        in_dim = self.input_layer.out_dim if self.input_layer else self.in_dim
+        self.input_projection = rf.Linear(in_dim, self.out_dim, with_bias=False) if in_dim != self.out_dim else None
+        if input_embedding_scale is None:
+            input_embedding_scale = (self.out_dim.dimension**0.5) if pos_enc is not None else 1.0
         self.input_embedding_scale = input_embedding_scale
+        if pos_enc is None:
+            pass
+        elif isinstance(pos_enc, dict):
+            pos_enc = rf.build_from_dict(pos_enc, feat_dim=self.out_dim)
+        elif isinstance(pos_enc, rf.Module):
+            pass
+        elif isinstance(pos_enc, FunctionType):
+            pos_enc = functools.partial(pos_enc, feat_dim=self.out_dim)
+        else:
+            raise TypeError(f"unexpected pos_enc type {pos_enc!r}")
+        self.pos_enc = pos_enc
         self.input_dropout = input_dropout
 
         if not encoder_layer or isinstance(encoder_layer, (dict, type)):
@@ -411,6 +433,8 @@ class ConformerEncoder(ISeqDownsamplingEncoder):
         x = self.input_projection(x_subsample) if self.input_projection else x_subsample
         if self.input_embedding_scale != 1.0:
             x = x * self.input_embedding_scale
+        if self.pos_enc is not None:
+            x = x + self.pos_enc(spatial_dim=out_spatial_dim)
         x = rf.dropout(x, self.input_dropout, axis=self.dropout_broadcast and self.out_dim)
         x = self.layers(x, spatial_dim=out_spatial_dim, collected_outputs=collected_outputs)
         return x, out_spatial_dim
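
Together these hunks add an opt-in absolute positional encoding to the Conformer front-end and, when pos_enc is set, change the input_embedding_scale default from 1.0 to sqrt(out_dim). A hedged construction sketch (the dims and hyperparameters are illustrative, and input_layer=None simply skips the subsampling front-end):

import returnn.frontend as rf
from returnn.tensor import Dim
from returnn.frontend.encoder.conformer import ConformerEncoder

rf.select_backend_torch()  # assuming the PyTorch backend is wanted here

enc = ConformerEncoder(
    Dim(80, name="mel"),
    Dim(512, name="enc"),
    num_layers=12,
    input_layer=None,  # no subsampling front-end, for brevity
    pos_enc=rf.sinusoidal_positional_encoding,  # new arg; default None keeps old behavior
)
# With pos_enc set and input_embedding_scale left at None, the scale defaults
# to sqrt(out_dim) = 512 ** 0.5; a plain function passed as pos_enc is bound
# to feat_dim=out_dim via functools.partial, per the hunk above.

Note also the behavioral change in the same hunks: self.input_projection is now created only when the front-end output dim differs from out_dim, rather than whenever input_layer is set.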
returnn/torch/updater.py CHANGED
@@ -440,8 +440,6 @@ class Updater:
         If it's a str, it must be the optimizer name.
         :return: tuple (optimizer, optional optimizer_param_groups_extra_opts).
         """
-        lr = self.learning_rate
-
         # If the parameter is already a valid optimizer, return it without further processing
         if isinstance(optimizer_opts, torch.optim.Optimizer):
             return optimizer_opts, None
@@ -467,8 +465,9 @@
             opt_kwargs["eps"] = opt_kwargs.pop("epsilon")
         if "learning_rate" in opt_kwargs or "lr" in opt_kwargs:
             raise ValueError("'learning_rate' should be set outside of the 'optimizer' dict.")
-        lr = lr * opt_kwargs.pop("learning_rate_multiplier", 1.0)
-        opt_kwargs["lr"] = lr
+        # lr will anyway be updated in set_current_train_step / _update_effective_learning_rate,
+        # so this value doesn't really matter here
+        opt_kwargs["lr"] = self.learning_rate
 
         param_groups = self._get_optimizer_param_groups(optim_class, opt_kwargs)
         param_groups = list(param_groups)
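
For config authors nothing changes in how the optimizer is declared: the learning rate still lives outside the optimizer dict, and the lr handed to the optimizer constructor is now only a placeholder that _update_effective_learning_rate overwrites each step. A sketch of a typical RETURNN config fragment (the optimizer class and kwargs are illustrative):

# learning_rate sits at the top level of the config; putting "learning_rate"
# or "lr" inside the optimizer dict raises a ValueError (see the hunk above).
learning_rate = 1e-3
optimizer = {
    "class": "adamw",     # resolved to a torch.optim optimizer by name
    "epsilon": 1e-8,      # mapped to the torch `eps` kwarg, per the code above
    "weight_decay": 0.01,
}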
returnn-1.20250826.115240.dist-info/METADATA → returnn-1.20250828.2732.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20250826.115240
+Version: 1.20250828.2732
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
returnn-1.20250826.115240.dist-info/RECORD → returnn-1.20250828.2732.dist-info/RECORD CHANGED
@@ -1,9 +1,9 @@
-returnn/PKG-INFO,sha256=FML_TBBHEcTCM4--oW_2wEb8Yk4ybXtAH6CQ7tSums0,5215
+returnn/PKG-INFO,sha256=40gq368ieKKYuTEYmkvPOiakDG6N6_RPtPd5pBGCoQ0,5213
 returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
 returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
 returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
 returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
-returnn/_setup_info_generated.py,sha256=BCJ_PJr520UrgVLYRMgDdj07KiDWctntLtC9aS8cTH0,77
+returnn/_setup_info_generated.py,sha256=b0sPsBSCJObrNcEG_PgvBTHyCMvourjvAfvgeL7uEAk,77
 returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
 returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
 returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -80,7 +80,7 @@ returnn/frontend/_cache.py,sha256=JAhi7L-raQ3A-NC3JUYDtdRTwT3BGJJGGZxrZ8MfEWQ,84
 returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
 returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
 returnn/frontend/_utils.py,sha256=uVQldGHyYKIyhSEmumJ04ix5eP5tjZw4CEC0w6-zhyQ,12074
-returnn/frontend/array_.py,sha256=o_NSq87pB5I2XvFUjk40Dobqx6tTfEY1wzgmaelujgM,51511
+returnn/frontend/array_.py,sha256=5aCU-BCBH035QZlpYSRGtVXhxz78tteZ75e57FxCIRw,52182
 returnn/frontend/attention.py,sha256=GKt-Xqnz8sIyXVrE0i4VCS7J2Wu7dmoH_BA0Cu8CrXQ,45769
 returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
 returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
@@ -138,7 +138,7 @@ returnn/frontend/decoder/__init__.py,sha256=A-koKyPVlXp_V_2bk6GKZ1Xfv4rYIcfxGMXQ
 returnn/frontend/decoder/transformer.py,sha256=20a37hMiPbQBHx3tSbOeiAbFPVRcX_KYpPuw8tmY6GU,23658
 returnn/frontend/encoder/__init__.py,sha256=0QGLlujRIKx3zBREeShza_-xhGIxj73zbd7t-g1m-ho,17
 returnn/frontend/encoder/base.py,sha256=A759EwCYAmSi-kzXz1vaTjR2l59TvNGQlzaNdp3UOKs,2109
-returnn/frontend/encoder/conformer.py,sha256=ro0uzEzDbAyNGYN5ff0KmiDl4HOYQluu64mJxYzuy-M,19972
+returnn/frontend/encoder/conformer.py,sha256=rWulygolesbYkLw9naSxwygaZhWqKpHKEVj-1AQbel0,21351
 returnn/frontend/encoder/conformer_v2.py,sha256=vAYdT8m2Zzg3IIZZafeccClFHU1_c9T-EgBOsHadQPA,7701
 returnn/frontend/encoder/e_branchformer.py,sha256=SZdhpb90FaQdpzgvSOtFPLbLCa0NdycbB5Z4vMoY4TM,12279
 returnn/frontend/encoder/transformer.py,sha256=Jj0mF1D2MohOk-9sGYdsLtVW_86fwoq4pKWCdPMvPR8,11580
@@ -208,7 +208,7 @@ returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
 returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
 returnn/torch/distributed.py,sha256=_lyJR71HIoCHpMi5GztGM7YwrX54Am8zSkjnDkE1Lbk,7524
 returnn/torch/engine.py,sha256=JSsQZZiVs9TxRyFEJuR3iH-YZb9sRw7TzoIAIqmplZY,78275
-returnn/torch/updater.py,sha256=7lMoA01Yzp18MY5jjIFncsajTjOD713pK38nU6r-jiE,31999
+returnn/torch/updater.py,sha256=nNd1mBPQyvIB096BEFi0KKmRI-U3jnRETzb743p2B9c,32064
 returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
 returnn/torch/data/extern_data.py,sha256=5al706ZaYtHWLp5VH2vS-rW69YXP3NHyOFRKY0WY714,7810
 returnn/torch/data/pipeline.py,sha256=HgIL0jQsPcgvh_SPC4wQ6BzclmrnpFja-UiboF_GPN4,29459
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
 returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
 returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
 returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
-returnn-1.20250826.115240.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
-returnn-1.20250826.115240.dist-info/METADATA,sha256=FML_TBBHEcTCM4--oW_2wEb8Yk4ybXtAH6CQ7tSums0,5215
-returnn-1.20250826.115240.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
-returnn-1.20250826.115240.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
-returnn-1.20250826.115240.dist-info/RECORD,,
+returnn-1.20250828.2732.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+returnn-1.20250828.2732.dist-info/METADATA,sha256=40gq368ieKKYuTEYmkvPOiakDG6N6_RPtPd5pBGCoQ0,5213
+returnn-1.20250828.2732.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+returnn-1.20250828.2732.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+returnn-1.20250828.2732.dist-info/RECORD,,