returnn 1.20250901.123052-py3-none-any.whl → 1.20260105.192646-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +2 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/config.py +1 -1
- returnn/datasets/basic.py +29 -13
- returnn/datasets/distrib_files.py +61 -3
- returnn/datasets/generating.py +12 -21
- returnn/datasets/huggingface.py +434 -0
- returnn/datasets/lm.py +20 -0
- returnn/datasets/meta.py +179 -60
- returnn/datasets/multi_proc.py +1 -1
- returnn/datasets/postprocessing.py +597 -108
- returnn/datasets/text_dict.py +1 -1
- returnn/datasets/util/vocabulary.py +90 -0
- returnn/frontend/_backend.py +7 -0
- returnn/frontend/array_.py +54 -1
- returnn/frontend/attention.py +54 -20
- returnn/frontend/conv.py +273 -54
- returnn/frontend/decoder/transformer.py +36 -17
- returnn/frontend/encoder/conformer.py +1 -0
- returnn/frontend/encoder/transformer.py +2 -0
- returnn/frontend/loss.py +40 -1
- returnn/frontend/module.py +8 -1
- returnn/frontend/nested.py +9 -0
- returnn/native_op.cpp +80 -0
- returnn/sprint/cache.py +12 -13
- returnn/tensor/_dim_extra.py +51 -29
- returnn/tensor/_tensor_extra.py +6 -1
- returnn/tensor/utils.py +7 -4
- returnn/tf/frontend_layers/_backend.py +11 -2
- returnn/tf/frontend_low_level/_backend.py +15 -0
- returnn/tf/layers/basic.py +16 -38
- returnn/tf/native_op.py +11 -58
- returnn/tf/network.py +1 -1
- returnn/tf/util/basic.py +19 -0
- returnn/torch/data/returnn_dataset_wrapper.py +9 -3
- returnn/torch/engine.py +67 -2
- returnn/torch/frontend/_backend.py +119 -7
- returnn/torch/util/diagnose_gpu.py +65 -31
- returnn/torch/util/exception_helper.py +7 -1
- returnn/util/basic.py +6 -7
- returnn/util/better_exchook.py +4 -0
- returnn/util/collect_outputs_dict.py +79 -0
- returnn/util/debug.py +11 -2
- returnn/util/file_cache.py +42 -4
- returnn/util/task_system.py +1 -1
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/METADATA +2 -2
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/RECORD +50 -48
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/LICENSE +0 -0
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/WHEEL +0 -0
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/top_level.txt +0 -0
returnn/frontend/conv.py
CHANGED

@@ -3,7 +3,7 @@ Convolution, transposed convolution, pooling
 """

 from __future__ import annotations
-from typing import Optional, Sequence, Tuple,
+from typing import Optional, Union, TypeVar, Sequence, Tuple, List
 from returnn.util.basic import next_type_attrib_in_mro_chain
 from returnn.tensor import Tensor, Dim
 import returnn.frontend as rf
@@ -25,6 +25,9 @@ __all__ = [
     "pool2d",
     "pool3d",
     "make_conv_out_spatial_dims",
+    "calc_conv_out_length",
+    "make_transposed_conv_out_spatial_dims",
+    "calc_transposed_conv_out_length",
 ]


@@ -396,7 +399,11 @@ def transposed_conv(
     )
     if use_mask:
         source = source.copy_masked(0, dims=in_spatial_dims)
-    if
+    if (
+        padding == "same"
+        and any(s != 1 for s in (strides or [fs.dimension for fs in filter_size]))
+        and _should_use_consistent_same_padding()
+    ):
         # I don't really know what this should mean here... Investigate this further...
         raise NotImplementedError("consistent same padding not implemented for transposed conv")
     # noinspection PyProtectedMember
@@ -424,6 +431,39 @@ class TransposedConv1d(_TransposedConv):

     nd = 1

+    def __init__(
+        self,
+        in_dim: Dim,
+        out_dim: Dim,
+        filter_size: Union[int, Dim],
+        *,
+        padding: str,
+        remove_padding: int = 0,
+        output_padding: Optional[int] = None,
+        strides: Optional[int] = None,
+        with_bias: bool = True,
+    ):
+        """
+        :param in_dim:
+        :param out_dim:
+        :param filter_size:
+        :param strides: specifies the upscaling. by default, same as filter_size
+        :param padding: "same" or "valid"
+        :param remove_padding:
+        :param output_padding:
+        :param with_bias: whether to add a bias. enabled by default
+        """
+        super().__init__(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            filter_size=[filter_size],
+            padding=padding,
+            remove_padding=remove_padding,
+            output_padding=output_padding,
+            strides=[strides] if strides is not None else None,
+            with_bias=with_bias,
+        )
+
     __call__ = _ConvOrTransposedConv._call_nd1

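Usage note (not part of the diff): a minimal sketch of the now-explicit 1-D constructor. The dim sizes and names are made up, and the (source, in_spatial_dim=...) -> (output, out_spatial_dim) call convention of the 1-D module is assumed rather than shown in this hunk.

import returnn.frontend as rf
from returnn.tensor import Dim

in_dim = Dim(256, name="in")
out_dim = Dim(256, name="out")
# strides defaults to filter_size, so this is plain 2x upsampling along the spatial axis
upsample = rf.TransposedConv1d(in_dim, out_dim, filter_size=2, padding="valid")
# given x with dims [batch, time, in_dim] (assumed):
#   y, upsampled_time_dim = upsample(x, in_spatial_dim=time_dim)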
@@ -704,7 +744,7 @@ def make_conv_out_spatial_dims(
     strides: Union[Sequence[int], int] = 1,
     dilation_rate: Union[Sequence[int], int] = 1,
     description_prefix: Optional[str] = None,
-) ->
+) -> List[Dim]:
     """create out spatial dims from in spatial dims"""
     nd = len(in_spatial_dims)
     if isinstance(filter_size, (int, Dim)):
@@ -715,84 +755,263 @@ make_conv_out_spatial_dims(
         strides = [strides] * nd
     if isinstance(dilation_rate, int):
         dilation_rate = [dilation_rate] * nd
-    assert nd == len(in_spatial_dims) == len(filter_size) == len(strides) == len(dilation_rate)
     if isinstance(padding, (int, str)):
         padding = [padding] * nd
+    assert nd == len(in_spatial_dims) == len(filter_size) == len(strides) == len(dilation_rate) == len(padding)
     padding = [p.lower() if isinstance(p, str) else p for p in padding]
     out_spatial_dims = []
     for i in range(nd):
-
-
-
-        ):
-            out_spatial_dims.append(in_spatial_dim)
-        else:
-            out_spatial_dim = _calc_out_dim(
-                in_dim=in_spatial_dim,
+        out_spatial_dims.append(
+            calc_conv_out_length(
+                in_spatial_dims[i],
                 filter_size=filter_size[i],
+                padding=padding[i],
                 stride=strides[i],
                 dilation_rate=dilation_rate[i],
-
+                name=f"{description_prefix}:spatial{i}" if description_prefix else None,
             )
-
-            if description_prefix and out_spatial_dim != in_spatial_dim:
-                out_spatial_dim.name = f"{description_prefix}:spatial{i}"
-            if in_spatial_dim.dyn_size_ext is not None and out_spatial_dim.dyn_size_ext is None:
-                out_spatial_dim.dyn_size_ext = _calc_out_dim(
-                    in_dim=in_spatial_dim.dyn_size_ext,
-                    filter_size=filter_size[i],
-                    stride=strides[i],
-                    dilation_rate=dilation_rate[i],
-                    padding=padding[i],
-                )
-            out_spatial_dims.append(out_spatial_dim)
+        )
     return out_spatial_dims


-
+T = TypeVar("T", int, Dim, Tensor)
+
+
+def calc_conv_out_length(
+    in_length: Union[T, int, Dim, Tensor],
+    *,
+    filter_size: Union[T, int, Dim, Tensor],
+    stride: int,
+    padding: Union[str, int],
+    dilation_rate: int = 1,
+    name: Optional[str] = None,
+) -> T:
     """
     Copied and adapted from TF ConvLayer.calc_out_dim.

-    :param T
-    :param
-    :param
-    :param
-    :param
+    :param T in_length: dimension in some axis
+    :param filter_size: e.g. 2, for the corresponding axis
+    :param stride: e.g. 1, for the corresponding axis
+    :param dilation_rate: e.g. 1
+    :param padding: "valid" or "same" or int
+    :param name:
     :return: the output dimension
-    :rtype: T
     """
+    padding = padding.lower() if isinstance(padding, str) else padding
+    if isinstance(filter_size, int):
+        filter_size_int = filter_size
+    elif isinstance(filter_size, Dim):
+        filter_size_int = filter_size.dimension
+    else:
+        filter_size_int = None
+    filter_size_ = filter_size_int if isinstance(filter_size_int, int) else filter_size

-
-
-    :param T|int|Tensor|torch.Tensor|tensorflow.Tensor a:
-    :param T|int|Tensor|torch.Tensor|tensorflow.Tensor b:
-    :rtype: T
-    """
-    if isinstance(b, int) and b == 1:
-        return a
-    if isinstance(a, Tensor):
-        return rf.ceil_divide(a, b)
-    return -(-a // b)
+    if (filter_size_int == stride == 1 and padding in ("valid", "same", 0)) or (stride == 1 and padding == "same"):
+        return in_length

-    padding = padding.lower() if isinstance(padding, str) else padding
     # See tf.compat.v1.nn.convolution() documentation for more.
     if padding == "same":
-        if isinstance(
-
-
+        if isinstance(in_length, Dim):
+            out_length = in_length.ceildiv_right(stride)
+        else:
+            out_length = _ceildiv(in_length, stride)
     elif padding == "valid" or isinstance(padding, int):
         if isinstance(padding, int) and padding != 0:
             assert padding > 0
-
-
-
-
-
-
-
+            in_length = padding + in_length + padding
+
+        if filter_size_int == 1:
+            valid_part = in_length
+        elif isinstance(in_length, Dim):
+            filter_left_dilated = (filter_size_ - 1) * dilation_rate // 2
+            filter_right_dilated = (filter_size_ - 1) * dilation_rate - filter_left_dilated
+            valid_part = in_length.sub_left(filter_left_dilated).sub_right(filter_right_dilated)
+        else:
+            valid_part = in_length - (filter_size_ - 1) * dilation_rate
+
+        if isinstance(valid_part, Dim):
+            out_length = valid_part.ceildiv_right(stride)
+        else:
+            out_length = _ceildiv(valid_part, stride)
+
     else:
         raise ValueError(f"invalid padding {padding!r} (type {type(padding).__name__})")

+    if isinstance(in_length, Dim):
+        assert isinstance(out_length, Dim)
+        if name and out_length != in_length:
+            out_length.name = name
+        if in_length.dyn_size_ext is not None and out_length.dyn_size_ext is None:
+            out_dyn_size_ext = calc_conv_out_length(
+                in_length=in_length.dyn_size_ext,
+                filter_size=filter_size,
+                stride=stride,
+                dilation_rate=dilation_rate,
+                padding=padding,
+            )
+            assert isinstance(out_dyn_size_ext, Tensor)
+            out_length.dyn_size_ext = out_dyn_size_ext
+
+    return out_length
+
+
+def make_transposed_conv_out_spatial_dims(
+    in_spatial_dims: Sequence[Dim],
+    *,
+    filter_size: Union[Sequence[Union[int, Dim]], int, Dim],
+    padding: Union[str, int, Sequence[int]],
+    output_padding: Optional[Union[Sequence[Optional[int]], int]] = None,
+    strides: Union[Sequence[Optional[int]], None, int] = None,
+    dilation_rate: Union[Sequence[int], int] = 1,
+    description_prefix: Optional[str] = None,
+) -> List[Dim]:
+    """create out spatial dims from in spatial dims"""
+    nd = len(in_spatial_dims)
+    if isinstance(filter_size, (int, Dim)):
+        filter_size = [filter_size] * nd
+    filter_size = [d.dimension if isinstance(d, Dim) else d for d in filter_size]
+    assert all(isinstance(s, int) for s in filter_size)
+    if isinstance(strides, int) or strides is None:
+        strides = [strides] * nd
+    if isinstance(dilation_rate, int):
+        dilation_rate = [dilation_rate] * nd
+    if isinstance(padding, (int, str)):
+        padding = [padding] * nd
+    if isinstance(output_padding, int) or output_padding is None:
+        output_padding = [output_padding] * nd
+    assert (
+        nd
+        == len(in_spatial_dims)
+        == len(filter_size)
+        == len(strides)
+        == len(dilation_rate)
+        == len(padding)
+        == len(output_padding)
+    )
+    padding = [p.lower() if isinstance(p, str) else p for p in padding]
+    out_spatial_dims = []
+    for i in range(nd):
+        out_spatial_dims.append(
+            calc_transposed_conv_out_length(
+                in_spatial_dims[i],
+                filter_size=filter_size[i],
+                padding=padding[i],
+                stride=strides[i],
+                dilation_rate=dilation_rate[i],
+                name=f"{description_prefix}:spatial{i}" if description_prefix else None,
+            )
+        )
+    return out_spatial_dims
+
+
+def calc_transposed_conv_out_length(
+    in_length: Union[T, int, Dim, Tensor],
+    *,
+    filter_size: Union[int, Dim],
+    padding: Union[int, str],
+    output_padding: Optional[int] = None,
+    stride: Optional[int] = None,
+    dilation_rate: int = 1,
+    name: Optional[str] = None,
+) -> T:
+    """
+    Determines output length of a transposed convolution given input length.
+
+    Copied from TF/Keras conv_utils.deconv_output_length
+    (https://github.com/tensorflow/tensorflow/blob/5912f51d580551e5cee2cfde4cb882594b4d3e60/tensorflow/python/keras/utils/conv_utils.py#L140),
+    adapted with simplification.
+
+    Also see :func:`calc_conv_out_length`.
+
+    :param in_length:
+    :param filter_size:
+    :param padding: one of `"same"`, `"valid"`, `"full"`.
+    :param output_padding: amount of padding along the output dimension.
+        Can be set to `None` in which case the output length is inferred.
+    :param stride:
+    :param dilation_rate:
+    :param name:
+    :returns: The output length (integer)
+    """
+    assert padding in {"same", "valid", "full"} or isinstance(padding, int)
+
+    if isinstance(filter_size, int):
+        filter_size_int = filter_size
+    elif isinstance(filter_size, Dim):
+        filter_size_int = filter_size.dimension
+    else:
+        filter_size_int = None
+    filter_size_ = filter_size_int if isinstance(filter_size_int, int) else filter_size
+
+    # Get the dilated kernel size
+    if dilation_rate != 1 and filter_size_int != 1:
+        filter_size = filter_size + (filter_size_ - 1) * (dilation_rate - 1)
+
+    if stride is None:
+        assert filter_size_int is not None
+        stride = filter_size_int
+    if stride != 1:
+        in_length = in_length * stride
+
+    # Infer length if output padding is None, else compute the exact length
+    if output_padding is None:
+        if padding == "valid" or padding == 0:
+            if filter_size_int == stride:
+                out_length = in_length
+            elif filter_size_int is not None:
+                out_length = in_length + max(filter_size_int - stride, 0)
+            elif isinstance(filter_size, Tensor):
+                out_length = in_length + rf.relu(filter_size - stride)
+            elif isinstance(filter_size, Dim):
+                out_length = in_length + (filter_size - stride)
+            else:
+                raise ValueError(f"invalid filter_size {filter_size!r} type {type(filter_size)}")
+        elif padding == "full":
+            out_length = in_length - (stride + filter_size_ - 2)
+        elif padding == "same":
+            out_length = in_length
+        else:
+            raise ValueError(f"invalid padding {padding!r}")
+
+    else:  # output_padding
+        if padding == "same":
+            pad = filter_size // 2
+        elif padding == "valid":
+            pad = 0
+        elif padding == "full":
+            pad = filter_size - 1
+        elif isinstance(padding, int):
+            pad = padding
+        else:
+            raise ValueError(f"invalid padding {padding!r}")
+        out_length = in_length + (filter_size - stride - 2 * pad + output_padding)
+
+    if isinstance(in_length, Dim):
+        assert isinstance(out_length, Dim)
+        if name and out_length != in_length:
+            out_length.name = name
+        if in_length.dyn_size_ext is not None and out_length.dyn_size_ext is None:
+            out_dyn_size_ext = calc_transposed_conv_out_length(
+                in_length=in_length.dyn_size_ext,
+                filter_size=filter_size,
+                padding=padding,
+                output_padding=output_padding,
+                stride=stride,
+                dilation_rate=dilation_rate,
+            )
+            assert isinstance(out_dyn_size_ext, Tensor)
+            out_length.dyn_size_ext = out_dyn_size_ext
+
+    return out_length
+
+
+def _ceildiv(a: T, b: Union[T, int, Tensor]) -> T:
+    if isinstance(b, int) and b == 1:
+        return a
+    if isinstance(a, Tensor):
+        return rf.ceil_divide(a, b)
+    return -(-a // b)
+

 def _should_use_consistent_same_padding() -> bool:
     """
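Usage note (not part of the diff): the new length helpers accept plain ints as well as Dim or Tensor lengths. A quick sanity check with ints, following the formulas added above:

from returnn.frontend.conv import calc_conv_out_length, calc_transposed_conv_out_length

# "valid" conv, filter 3, stride 2, on a length-100 sequence: ceil((100 - 2) / 2) = 49
assert calc_conv_out_length(100, filter_size=3, stride=2, padding="valid") == 49
# transposed conv with filter_size == stride == 2 and "valid" padding: exact 2x upsampling
assert calc_transposed_conv_out_length(49, filter_size=2, stride=2, padding="valid") == 98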
returnn/frontend/decoder/transformer.py
CHANGED

@@ -49,6 +49,7 @@ class TransformerDecoder(rf.Module):
         layer_opts: Optional[Dict[str, Any]] = None,
         embed_dim: Optional[Dim] = None,
         share_embedding: bool = None,
+        input_embedding: bool = True,
         input_embedding_scale: float = None,
         input_dropout: float = None,
         logits_with_bias: bool = False,
@@ -72,6 +73,7 @@ class TransformerDecoder(rf.Module):
         :param layer_opts: options for the decoder layer
         :param embed_dim: if given, will first have an embedding [vocab,embed] and then a linear [embed,model].
         :param share_embedding:
+        :param input_embedding: whether to use input embedding. If False, you must provide input of dimension model_dim.
         :param input_embedding_scale:
         :param input_dropout:
         :param logits_with_bias:
@@ -103,7 +105,7 @@ class TransformerDecoder(rf.Module):

         # We could make this optional or configurable if we ever need to.
         # Or maybe you would just have another separate implementation of this module then...
-        self.input_embedding = rf.Embedding(vocab_dim, embed_dim or model_dim)
+        self.input_embedding = rf.Embedding(vocab_dim, embed_dim or model_dim) if input_embedding else None

         self.input_embedding_proj = None
         if embed_dim:
@@ -121,21 +123,31 @@ class TransformerDecoder(rf.Module):
             raise TypeError(f"unexpected pos_enc type {pos_enc!r}")
         self.pos_enc = pos_enc
         if share_embedding is None:
-            if
-
-
-
-                "
-
-
+            if embed_dim and embed_dim != model_dim:
+                share_embedding = False
+            elif input_embedding:
+                if BehaviorVersion.get() < 20:
+                    logging.getLogger("returnn.frontend").warning(
+                        "TransformerDecoder share_embedding default is False"
+                        f" with your behavior version {BehaviorVersion.get()}."
+                        " Explicitly set share_embedding or switch to a new behavior version >= 20."
+                    )
+                share_embedding = True if BehaviorVersion.get() >= 20 else False
+            else:  # not input_embedding
+                share_embedding = False
         if input_embedding_scale is None:
-            if
-
-                "
-
-
-
-
+            if input_embedding:
+                if BehaviorVersion.get() < 20:
+                    logging.getLogger("returnn.frontend").warning(
+                        "TransformerDecoder input_embedding_scale default is suboptimal"
+                        f" with your behavior version {BehaviorVersion.get()}."
+                        " Explicitly set input_embedding_scale or switch to a new behavior version >= 20."
+                    )
+                input_embedding_scale = model_dim.dimension**0.5 if BehaviorVersion.get() >= 20 else 1.0
+            elif pos_enc:
+                input_embedding_scale = model_dim.dimension**0.5
+            else:
+                input_embedding_scale = 1.0
         self.input_embedding_scale = input_embedding_scale
         if input_dropout is None:
             if dropout > 0 and BehaviorVersion.get() < 20:
@@ -179,7 +191,9 @@ class TransformerDecoder(rf.Module):
         self.logits = rf.Linear(model_dim, vocab_dim, with_bias=logits_with_bias)

         if share_embedding:
-            assert
+            assert input_embedding, "input_embedding=True required for share_embedding"
+            assert not embed_dim or embed_dim == model_dim, f"{embed_dim=} not supported with share_embedding"
+            assert not logits_with_bias, "logits_with_bias=True expected with share_embedding"
             self.logits.weight = self.input_embedding.weight

     def default_initial_state(self, *, batch_dims: Sequence[Dim]) -> rf.State:
@@ -219,7 +233,12 @@ class TransformerDecoder(rf.Module):
         """
         new_state = rf.State()

-
+        if self.input_embedding is not None:
+            decoded = self.input_embedding(source)
+        else:
+            decoded = source
+        if self.input_embedding_scale != 1:
+            decoded = decoded * self.input_embedding_scale
         if self.pos_enc is not None:
             decoded = decoded + self.pos_enc(spatial_dim=spatial_dim, offset=state.pos)
         decoded = rf.dropout(decoded, self.input_dropout)
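Usage note (not part of the diff): with the new input_embedding flag the decoder can consume already-embedded model_dim inputs instead of token ids. A sketch of such a setup; the encoder_dim/num_layers arguments and the decoder-only use with encoder_dim=None are assumptions here, not taken from this diff.

from returnn.tensor import Dim
from returnn.frontend.decoder.transformer import TransformerDecoder

vocab_dim = Dim(10_000, name="vocab")
model_dim = Dim(512, name="model")
decoder = TransformerDecoder(
    encoder_dim=None,           # assumed: decoder-only, no cross-attention
    vocab_dim=vocab_dim,
    model_dim=model_dim,
    num_layers=6,
    input_embedding=False,      # caller feeds model_dim-sized inputs directly
    share_embedding=False,      # cannot share weights without an input embedding
    input_embedding_scale=1.0,  # set explicitly, also silencing the behavior-version warning
)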
returnn/frontend/encoder/conformer.py
CHANGED

@@ -273,6 +273,7 @@ class ConformerEncoderLayer(rf.Module):
         x_mhsa = self.self_att(x_mhsa_ln, axis=spatial_dim)
         x_mhsa = rf.dropout(x_mhsa, self.dropout, axis=self.dropout_broadcast and self.out_dim)
         x_mhsa_out = x_mhsa + x_ffn1_out
+        del x_mhsa

         # Conv
         x_conv_ln = self.conv_layer_norm(x_mhsa_out)
returnn/frontend/encoder/transformer.py
CHANGED

@@ -79,6 +79,8 @@ class TransformerEncoder(rf.Module):
         self.model_dim = model_dim
         self.embed_dim = embed_dim

+        self.out_dim = self.model_dim  # alias. consistency, compatibility
+
         if input_embedding is None or isinstance(input_embedding, rf.Module):
             pass
         elif isinstance(input_embedding, type):
returnn/frontend/loss.py
CHANGED

@@ -3,11 +3,12 @@ Loss functions
 """

 from __future__ import annotations
+from typing import Optional, Tuple
 from returnn.tensor import Tensor, Dim
 import returnn.frontend as rf


-__all__ = ["cross_entropy", "ctc_loss", "edit_distance"]
+__all__ = ["cross_entropy", "ctc_loss", "ctc_greedy_decode", "edit_distance"]


 def cross_entropy(
@@ -93,6 +94,44 @@ def ctc_loss(
     )


+def ctc_greedy_decode(
+    logits: Tensor,
+    *,
+    in_spatial_dim: Dim,
+    blank_index: int,
+    out_spatial_dim: Optional[Dim] = None,
+    target_dim: Optional[Dim] = None,
+    wb_target_dim: Optional[Dim] = None,
+) -> Tuple[Tensor, Dim]:
+    """
+    Greedy CTC decode.
+
+    :return: (labels, out_spatial_dim)
+    """
+    if wb_target_dim is None:
+        assert logits.feature_dim
+        wb_target_dim = logits.feature_dim
+
+    labels = rf.reduce_argmax(logits, axis=wb_target_dim)
+    labels = rf.cast(labels, "int32")
+
+    labels_shifted = rf.shift_right(labels, axis=in_spatial_dim, pad_value=blank_index)
+    mask_repeat = labels != labels_shifted
+    labels, out_spatial_dim = rf.masked_select(
+        labels,
+        mask=(labels != blank_index) & mask_repeat,
+        dims=[in_spatial_dim],
+        out_dim=out_spatial_dim,
+    )
+
+    if target_dim:
+        # Set correct sparse_dim. Only currently implemented if blank comes after.
+        assert target_dim.dimension == blank_index
+        labels.sparse_dim = target_dim
+
+    return labels, out_spatial_dim
+
+
 def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim, *, dtype: str = "int32") -> Tensor:
     """
     :param a: [B,Ta]
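Usage note (not part of the diff): a minimal sketch of the new greedy decode helper. The random logits, dim sizes, and the Torch backend selection are only for illustration, and it assumes ctc_greedy_decode is re-exported on returnn.frontend like the other loss helpers.

import returnn.frontend as rf
from returnn.tensor import Dim

rf.select_backend_torch()  # assumption: eager Torch backend for a quick check

batch_dim = Dim(3, name="batch")
time_dim = Dim(20, name="time")
target_dim = Dim(5, name="classes")        # real labels 0..4
wb_target_dim = Dim(6, name="classes_wb")  # labels + blank, blank_index == 5

logits = rf.random_normal([batch_dim, time_dim, wb_target_dim])  # stand-in for model output
labels, out_spatial_dim = rf.ctc_greedy_decode(
    logits,
    in_spatial_dim=time_dim,
    blank_index=target_dim.dimension,  # blank comes after the real labels
    target_dim=target_dim,
    wb_target_dim=wb_target_dim,
)
# labels: sparse int32 tensor [batch, out_spatial_dim] with blanks and repeats collapsed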
returnn/frontend/module.py
CHANGED

@@ -274,10 +274,17 @@ class Functional(Module):
     (This is often not necessary, but sometimes useful.)
     """

-    def __init__(self, func):
+    def __init__(self, func, *, attribs: Optional[Dict[str, Any]] = None):
+        """
+        :param func: callable. you might want to use functools.partial if you want to fix some arguments.
+        :param attribs: optional dict of attributes to set on this module. e.g. ``out_dim``.
+        """
         super().__init__()
         assert callable(func)
         self.func = func
+        if attribs:
+            for k, v in attribs.items():
+                setattr(self, k, v)

     def __repr__(self):
         return f"{self.__class__.__name__}({self.func.__qualname__})"
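Usage note (not part of the diff): a sketch of the new attribs argument, assuming Functional is reachable as rf.Functional; the wrapped callable and the out_dim attribute are only illustrative.

import returnn.frontend as rf
from returnn.tensor import Dim

feat_dim = Dim(128, name="feat")
# Wrap a plain callable as a module and attach an out_dim attribute,
# so downstream code that inspects module.out_dim keeps working.
scale = rf.Functional(lambda x: x * 2.0, attribs={"out_dim": feat_dim})
assert scale.out_dim == feat_dim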
returnn/frontend/nested.py
CHANGED

@@ -275,6 +275,8 @@ def _masked_select(
             return s
         assert s in dim_map
         return dim_map[s]
+    if s is None:
+        return None
     raise TypeError(f"_masked_select: unexpected type ({type(s)})")


@@ -346,6 +348,7 @@ def _masked_scatter_merge_dims(
     merged_dim_map: Dict[Dim, Dim],
 ) -> T:
     if isinstance(s, Dim):
+        assert isinstance(backup, Dim)
         # This is slightly more complex than in the _masked_select case:
         # We need to merge the s and backup depending on the mask.
         if s in reverse_dim_map:
@@ -353,7 +356,10 @@ def _masked_scatter_merge_dims(
         if s == backup:
             return s
         if s in merged_dim_map:
+            # If this assert fails, see e.g. https://github.com/rwth-i6/returnn/pull/1759 for an example.
+            assert backup in merged_dim_map, f"nested masked_scatter: mismatch of s {s} vs backup {backup}"
             return merged_dim_map[s]
+        assert backup not in merged_dim_map, f"nested masked_scatter: mismatch of s {s} vs backup {backup}"
         # Note: s/backup might even be static dims.
         new_size = _masked_scatter(
             s.get_size_tensor(),
@@ -416,6 +422,9 @@ def _masked_scatter(
         if s in merged_dim_map:
             return merged_dim_map[s]
         return s
+    if s is None:
+        assert backup is None
+        return None
     raise TypeError(f"_masked_scatter: unexpected type ({type(s)})")

