returnn-1.20250826.115240-py3-none-any.whl → returnn-1.20250828.2732-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/frontend/array_.py +19 -3
- returnn/frontend/encoder/conformer.py +32 -8
- returnn/torch/updater.py +3 -4
- {returnn-1.20250826.115240.dist-info → returnn-1.20250828.2732.dist-info}/METADATA +1 -1
- {returnn-1.20250826.115240.dist-info → returnn-1.20250828.2732.dist-info}/RECORD +10 -10
- {returnn-1.20250826.115240.dist-info → returnn-1.20250828.2732.dist-info}/LICENSE +0 -0
- {returnn-1.20250826.115240.dist-info → returnn-1.20250828.2732.dist-info}/WHEEL +0 -0
- {returnn-1.20250826.115240.dist-info → returnn-1.20250828.2732.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
@@ -1,2 +1,2 @@
-version = '1.
-long_version = '1.
+version = '1.20250828.002732'
+long_version = '1.20250828.002732+git.06c221e'
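Note on the version strings: the generated file keeps the zero-padded time component (002732), while the wheel tag in the filename uses the PEP 440 normalized form (2732), since leading zeros in numeric release segments are dropped during normalization. A small sketch with the third-party packaging library (assumed to be installed; it is not part of returnn) shows the mapping:

from packaging.version import Version

v = Version("1.20250828.002732")
print(str(v))      # "1.20250828.2732": leading zeros in the last segment are dropped
print(v.release)   # (1, 20250828, 2732)
print(Version("1.20250828.002732+git.06c221e").local)   # "git.06c221e"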
returnn/frontend/array_.py
CHANGED
@@ -429,11 +429,27 @@ def concat(
             assert src.dims_set - {dim} == dims, f"concat {sources}, need allow_broadcast=True"
     if not out_dim:
         out_dim = sum(d for _, d in sources)
+    # noinspection PyProtectedMember
+    out = sources[0][0]._raw_backend.concat(*sources, allow_broadcast=allow_broadcast, out_dim=out_dim)
     if handle_dynamic_dims is None or handle_dynamic_dims:
+        need_to_handle = False
         for src, dim in sources[:-1]:
-
-
-
+            if dim.need_masking():
+                need_to_handle = True
+        if need_to_handle:
+            masks = []
+            for _, dim in sources:
+                masks.append(
+                    dim.get_mask(dim_order=(dim,) + dim.dyn_size_ext.dims, device=out.device)
+                    if dim.need_masking()
+                    else rf.constant(True, dims=[dim], device=out.device)
+                )
+            # noinspection PyProtectedMember
+            mask_concat = sources[0][0]._raw_backend.concat(
+                *[(mask, dim) for (_, dim), mask in zip(sources, masks)], allow_broadcast=True, out_dim=out_dim
+            )
+            out, out_dim = rf.masked_select(out, mask=mask_concat, dims=[out_dim])
+    return out, out_dim
 
 
 def concat_features(*sources: Tensor, allow_broadcast=False) -> Tensor:
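The new handle_dynamic_dims path first concatenates the raw tensors, then builds a per-source mask over the dynamic dims, concatenates the masks the same way, and finally drops the padded frames with rf.masked_select, so padding of an earlier source no longer sits between sources. A minimal stand-alone illustration of that idea in plain PyTorch (not the RETURNN frontend API; the tensors and lengths are made up for the example):

import torch

# Two padded batches; 0 marks padding. Sequence lengths: a -> [2, 1], b -> [2, 1].
a = torch.tensor([[1, 2, 0], [3, 0, 0]])
b = torch.tensor([[7, 8], [9, 0]])
a_mask = torch.tensor([[True, True, False], [True, False, False]])
b_mask = torch.tensor([[True, True], [True, False]])

naive = torch.cat([a, b], dim=1)           # padding of `a` ends up in front of `b`
mask = torch.cat([a_mask, b_mask], dim=1)  # plays the role of mask_concat in the diff
packed = naive[mask]                       # plays the role of rf.masked_select: keep real frames only
print(naive.tolist())   # [[1, 2, 0, 7, 8], [3, 0, 0, 9, 0]]
print(packed.tolist())  # [1, 2, 7, 8, 3, 9], i.e. per sequence [1, 2, 7, 8] and [3, 9]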
returnn/frontend/encoder/conformer.py
CHANGED
@@ -8,6 +8,8 @@ https://github.com/rwth-i6/returnn_common/issues/233
 
 from __future__ import annotations
 from typing import Optional, Union, Any, Tuple, List, Dict, Callable
+from types import FunctionType
+import functools
 import copy as _copy
 from returnn.tensor import Tensor, Dim
 import returnn.frontend as rf
@@ -298,7 +300,8 @@ class ConformerEncoder(ISeqDownsamplingEncoder):
         *,
         num_layers: int,
         input_layer: Optional[Union[ConformerConvSubsample, ISeqDownsamplingEncoder, rf.Module, Any]],
-        input_embedding_scale: float =
+        input_embedding_scale: Optional[float] = None,
+        pos_enc: Union[None, Callable, Dict[str, Any], rf.Module] = None,
         input_dropout: float = 0.1,
         ff_dim: Dim = NotSpecified,
         ff_activation: Union[Callable[[Tensor], Tensor], Dict[str, Any], rf.Module] = NotSpecified,
@@ -317,8 +320,17 @@ class ConformerEncoder(ISeqDownsamplingEncoder):
         :param num_layers: the number of encoder layers
         :param input_layer: input/frontend/prenet with potential subsampling.
             (x, in_spatial_dim) -> (y, out_spatial_dim)
-        :param input_embedding_scale: applied after input_layer.
-
+        :param input_embedding_scale: applied after input_layer.
+            1.0 by default for historic reasons if pos_enc is None,
+            else sqrt(out_dim) by default.
+            In std Transformer, also ESPnet E-Branchformer and Conformer, this is sqrt(out_dim),
+            which is relevant when you add positional encoding.
+        :param pos_enc: positional encoding, applied after input_embedding_scale.
+            None (no positional encoding) by default, unlike standard Transformer.
+            E.g. :func:`rf.sinusoidal_positional_encoding` for absolute pos enc.
+            Note, relative positional encoding is usually part of the attention layer,
+            e.g. :class:`rf.RelPosSelfAttention`,
+            and nothing needs to be set here.
         :param input_dropout: applied after input_projection(input_layer(x))
         :param ff_dim: the dimension of feed-forward layers. 2048 originally, or 4 times out_dim
         :param ff_activation: activation function for feed-forward network
@@ -352,12 +364,22 @@ class ConformerEncoder(ISeqDownsamplingEncoder):
         else:
             raise TypeError(f"unexpected input_layer {input_layer!r}")
         self.input_layer = input_layer
-        self.
-
-
-        else
-        )
+        in_dim = self.input_layer.out_dim if self.input_layer else self.in_dim
+        self.input_projection = rf.Linear(in_dim, self.out_dim, with_bias=False) if in_dim != self.out_dim else None
+        if input_embedding_scale is None:
+            input_embedding_scale = (self.out_dim.dimension**0.5) if pos_enc is not None else 1.0
         self.input_embedding_scale = input_embedding_scale
+        if pos_enc is None:
+            pass
+        elif isinstance(pos_enc, dict):
+            pos_enc = rf.build_from_dict(pos_enc, feat_dim=self.out_dim)
+        elif isinstance(pos_enc, rf.Module):
+            pass
+        elif isinstance(pos_enc, FunctionType):
+            pos_enc = functools.partial(pos_enc, feat_dim=self.out_dim)
+        else:
+            raise TypeError(f"unexpected pos_enc type {pos_enc!r}")
+        self.pos_enc = pos_enc
         self.input_dropout = input_dropout
 
         if not encoder_layer or isinstance(encoder_layer, (dict, type)):
@@ -411,6 +433,8 @@ class ConformerEncoder(ISeqDownsamplingEncoder):
         x = self.input_projection(x_subsample) if self.input_projection else x_subsample
         if self.input_embedding_scale != 1.0:
             x = x * self.input_embedding_scale
+        if self.pos_enc is not None:
+            x = x + self.pos_enc(spatial_dim=out_spatial_dim)
         x = rf.dropout(x, self.input_dropout, axis=self.dropout_broadcast and self.out_dim)
         x = self.layers(x, spatial_dim=out_spatial_dim, collected_outputs=collected_outputs)
         return x, out_spatial_dim
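With these changes, absolute positional encoding can be enabled directly on the Conformer input frontend, and input_embedding_scale then defaults to sqrt(out_dim) instead of 1.0. A rough construction sketch based on the new docstring; the dim sizes, the subsampling settings, and the rf.select_backend_torch() call are illustrative assumptions, not something this diff prescribes:

import returnn.frontend as rf
from returnn.tensor import Dim
from returnn.frontend.encoder.conformer import ConformerEncoder, ConformerConvSubsample

rf.select_backend_torch()  # eager PyTorch backend, just for a standalone check

in_dim = Dim(80, name="mel")
out_dim = Dim(512, name="model")
encoder = ConformerEncoder(
    in_dim,
    out_dim,
    num_layers=12,
    input_layer=ConformerConvSubsample(
        in_dim,
        out_dims=[Dim(32, name="conv1"), Dim(64, name="conv2"), Dim(64, name="conv3")],
        filter_sizes=[(3, 3), (3, 3), (3, 3)],
        pool_sizes=[(1, 2)],
        strides=[(1, 1), (3, 1), (2, 1)],
    ),
    # New in this version: absolute sinusoidal pos enc applied after the input frontend.
    pos_enc=rf.sinusoidal_positional_encoding,
)

Relative positional encoding inside the attention layers (e.g. rf.RelPosSelfAttention) still needs nothing set here, as the docstring notes.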
returnn/torch/updater.py
CHANGED
@@ -440,8 +440,6 @@ class Updater:
         If it's a str, it must be the optimizer name.
         :return: tuple (optimizer, optional optimizer_param_groups_extra_opts).
         """
-        lr = self.learning_rate
-
         # If the parameter is already a valid optimizer, return it without further processing
         if isinstance(optimizer_opts, torch.optim.Optimizer):
             return optimizer_opts, None
@@ -467,8 +465,9 @@
             opt_kwargs["eps"] = opt_kwargs.pop("epsilon")
         if "learning_rate" in opt_kwargs or "lr" in opt_kwargs:
             raise ValueError("'learning_rate' should be set outside of the 'optimizer' dict.")
-        lr
-
+        # lr will anyway be updated in set_current_train_step / _update_effective_learning_rate,
+        # so this value doesn't really matter here
+        opt_kwargs["lr"] = self.learning_rate
 
         param_groups = self._get_optimizer_param_groups(optim_class, opt_kwargs)
         param_groups = list(param_groups)
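For configs this means the learning rate stays a top-level option; putting it inside the optimizer dict still raises the ValueError shown above, and the value passed as lr to the torch optimizer is only a placeholder that gets overwritten per train step. A hedged config sketch (the optimizer settings are only an example, not taken from this diff):

# RETURNN config excerpt
learning_rate = 1e-3   # picked up by the Updater; the effective lr is set per train step
optimizer = {
    "class": "adamw",
    "weight_decay": 1e-2,
    "epsilon": 1e-8,   # mapped to the torch `eps` kwarg, see the hunk above
    # "learning_rate" / "lr" must not appear here
}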
{returnn-1.20250826.115240.dist-info → returnn-1.20250828.2732.dist-info}/RECORD
CHANGED
@@ -1,9 +1,9 @@
-returnn/PKG-INFO,sha256=
+returnn/PKG-INFO,sha256=40gq368ieKKYuTEYmkvPOiakDG6N6_RPtPd5pBGCoQ0,5213
 returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
 returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
 returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
 returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
-returnn/_setup_info_generated.py,sha256=
+returnn/_setup_info_generated.py,sha256=b0sPsBSCJObrNcEG_PgvBTHyCMvourjvAfvgeL7uEAk,77
 returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
 returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
 returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -80,7 +80,7 @@ returnn/frontend/_cache.py,sha256=JAhi7L-raQ3A-NC3JUYDtdRTwT3BGJJGGZxrZ8MfEWQ,84
 returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
 returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
 returnn/frontend/_utils.py,sha256=uVQldGHyYKIyhSEmumJ04ix5eP5tjZw4CEC0w6-zhyQ,12074
-returnn/frontend/array_.py,sha256=
+returnn/frontend/array_.py,sha256=5aCU-BCBH035QZlpYSRGtVXhxz78tteZ75e57FxCIRw,52182
 returnn/frontend/attention.py,sha256=GKt-Xqnz8sIyXVrE0i4VCS7J2Wu7dmoH_BA0Cu8CrXQ,45769
 returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
 returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
@@ -138,7 +138,7 @@ returnn/frontend/decoder/__init__.py,sha256=A-koKyPVlXp_V_2bk6GKZ1Xfv4rYIcfxGMXQ
 returnn/frontend/decoder/transformer.py,sha256=20a37hMiPbQBHx3tSbOeiAbFPVRcX_KYpPuw8tmY6GU,23658
 returnn/frontend/encoder/__init__.py,sha256=0QGLlujRIKx3zBREeShza_-xhGIxj73zbd7t-g1m-ho,17
 returnn/frontend/encoder/base.py,sha256=A759EwCYAmSi-kzXz1vaTjR2l59TvNGQlzaNdp3UOKs,2109
-returnn/frontend/encoder/conformer.py,sha256=
+returnn/frontend/encoder/conformer.py,sha256=rWulygolesbYkLw9naSxwygaZhWqKpHKEVj-1AQbel0,21351
 returnn/frontend/encoder/conformer_v2.py,sha256=vAYdT8m2Zzg3IIZZafeccClFHU1_c9T-EgBOsHadQPA,7701
 returnn/frontend/encoder/e_branchformer.py,sha256=SZdhpb90FaQdpzgvSOtFPLbLCa0NdycbB5Z4vMoY4TM,12279
 returnn/frontend/encoder/transformer.py,sha256=Jj0mF1D2MohOk-9sGYdsLtVW_86fwoq4pKWCdPMvPR8,11580
@@ -208,7 +208,7 @@ returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
 returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
 returnn/torch/distributed.py,sha256=_lyJR71HIoCHpMi5GztGM7YwrX54Am8zSkjnDkE1Lbk,7524
 returnn/torch/engine.py,sha256=JSsQZZiVs9TxRyFEJuR3iH-YZb9sRw7TzoIAIqmplZY,78275
-returnn/torch/updater.py,sha256=
+returnn/torch/updater.py,sha256=nNd1mBPQyvIB096BEFi0KKmRI-U3jnRETzb743p2B9c,32064
 returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
 returnn/torch/data/extern_data.py,sha256=5al706ZaYtHWLp5VH2vS-rW69YXP3NHyOFRKY0WY714,7810
 returnn/torch/data/pipeline.py,sha256=HgIL0jQsPcgvh_SPC4wQ6BzclmrnpFja-UiboF_GPN4,29459
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
 returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
 returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
 returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
-returnn-1.
-returnn-1.
-returnn-1.
-returnn-1.
-returnn-1.
+returnn-1.20250828.2732.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+returnn-1.20250828.2732.dist-info/METADATA,sha256=40gq368ieKKYuTEYmkvPOiakDG6N6_RPtPd5pBGCoQ0,5213
+returnn-1.20250828.2732.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+returnn-1.20250828.2732.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+returnn-1.20250828.2732.dist-info/RECORD,,
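Each RECORD entry has the form path,sha256=digest,size, where the digest is the urlsafe-base64-encoded SHA-256 of the file with the trailing "=" padding stripped, as defined by the wheel/RECORD format. A small sketch to recompute such a line for a local file (the path is just an example):

import base64
import hashlib

def record_entry(path: str) -> str:
    # Read the file, hash it, and format the entry the way RECORD expects it.
    data = open(path, "rb").read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return f"{path},sha256={digest},{len(data)}"

# e.g. record_entry("returnn/_setup_info_generated.py") should reproduce the
# 77-byte entry listed above when run against an unpacked copy of this wheel.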
{returnn-1.20250826.115240.dist-info → returnn-1.20250828.2732.dist-info}/LICENSE
File without changes
{returnn-1.20250826.115240.dist-info → returnn-1.20250828.2732.dist-info}/WHEEL
File without changes
{returnn-1.20250826.115240.dist-info → returnn-1.20250828.2732.dist-info}/top_level.txt
File without changes