returnn 1.20251013.113026__py3-none-any.whl → 1.20260109.93428__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +2 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/config.py +1 -1
- returnn/datasets/distrib_files.py +53 -1
- returnn/datasets/generating.py +3 -5
- returnn/datasets/lm.py +20 -0
- returnn/datasets/meta.py +179 -60
- returnn/datasets/postprocessing.py +597 -108
- returnn/datasets/util/vocabulary.py +90 -0
- returnn/frontend/array_.py +46 -0
- returnn/frontend/attention.py +54 -20
- returnn/frontend/conv.py +273 -54
- returnn/frontend/device.py +14 -1
- returnn/frontend/encoder/conformer.py +20 -0
- returnn/frontend/encoder/transformer.py +2 -0
- returnn/frontend/loss.py +40 -1
- returnn/frontend/math_.py +54 -14
- returnn/frontend/module.py +8 -1
- returnn/frontend/nested.py +5 -0
- returnn/native_op.cpp +80 -0
- returnn/sprint/cache.py +12 -13
- returnn/tensor/_dim_extra.py +39 -24
- returnn/tensor/utils.py +7 -4
- returnn/tf/frontend_layers/_backend.py +4 -3
- returnn/tf/layers/basic.py +15 -39
- returnn/tf/native_op.py +11 -58
- returnn/tf/network.py +1 -1
- returnn/tf/util/basic.py +19 -0
- returnn/torch/engine.py +67 -2
- returnn/torch/frontend/_backend.py +135 -13
- returnn/torch/frontend/bridge.py +61 -0
- returnn/torch/util/exception_helper.py +7 -1
- returnn/util/basic.py +6 -7
- returnn/util/better_exchook.py +4 -0
- returnn/util/collect_outputs_dict.py +79 -0
- returnn/util/debug.py +11 -2
- returnn/util/file_cache.py +15 -1
- returnn/util/task_system.py +1 -1
- {returnn-1.20251013.113026.dist-info → returnn-1.20260109.93428.dist-info}/METADATA +2 -2
- {returnn-1.20251013.113026.dist-info → returnn-1.20260109.93428.dist-info}/RECORD +43 -42
- {returnn-1.20251013.113026.dist-info → returnn-1.20260109.93428.dist-info}/LICENSE +0 -0
- {returnn-1.20251013.113026.dist-info → returnn-1.20260109.93428.dist-info}/WHEEL +0 -0
- {returnn-1.20251013.113026.dist-info → returnn-1.20260109.93428.dist-info}/top_level.txt +0 -0
returnn/frontend/conv.py
CHANGED
@@ -3,7 +3,7 @@ Convolution, transposed convolution, pooling
 """
 
 from __future__ import annotations
-from typing import Optional, Sequence, Tuple,
+from typing import Optional, Union, TypeVar, Sequence, Tuple, List
 from returnn.util.basic import next_type_attrib_in_mro_chain
 from returnn.tensor import Tensor, Dim
 import returnn.frontend as rf
@@ -25,6 +25,9 @@ __all__ = [
     "pool2d",
     "pool3d",
     "make_conv_out_spatial_dims",
+    "calc_conv_out_length",
+    "make_transposed_conv_out_spatial_dims",
+    "calc_transposed_conv_out_length",
 ]
 
 
@@ -396,7 +399,11 @@ def transposed_conv(
     )
     if use_mask:
         source = source.copy_masked(0, dims=in_spatial_dims)
-    if
+    if (
+        padding == "same"
+        and any(s != 1 for s in (strides or [fs.dimension for fs in filter_size]))
+        and _should_use_consistent_same_padding()
+    ):
         # I don't really know what this should mean here... Investigate this further...
         raise NotImplementedError("consistent same padding not implemented for transposed conv")
     # noinspection PyProtectedMember
@@ -424,6 +431,39 @@ class TransposedConv1d(_TransposedConv):
 
     nd = 1
 
+    def __init__(
+        self,
+        in_dim: Dim,
+        out_dim: Dim,
+        filter_size: Union[int, Dim],
+        *,
+        padding: str,
+        remove_padding: int = 0,
+        output_padding: Optional[int] = None,
+        strides: Optional[int] = None,
+        with_bias: bool = True,
+    ):
+        """
+        :param in_dim:
+        :param out_dim:
+        :param filter_size:
+        :param strides: specifies the upscaling. by default, same as filter_size
+        :param padding: "same" or "valid"
+        :param remove_padding:
+        :param output_padding:
+        :param with_bias: whether to add a bias. enabled by default
+        """
+        super().__init__(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            filter_size=[filter_size],
+            padding=padding,
+            remove_padding=remove_padding,
+            output_padding=output_padding,
+            strides=[strides] if strides is not None else None,
+            with_bias=with_bias,
+        )
+
     __call__ = _ConvOrTransposedConv._call_nd1
 
 
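Illustrative usage sketch (not part of the diff; the dim names and sizes are made up) for the new explicit TransposedConv1d signature — strides defaults to filter_size, so filter_size=2 roughly doubles the spatial length (see calc_transposed_conv_out_length further below):

import returnn.frontend as rf
from returnn.tensor import Dim

in_dim = Dim(128, name="in_feat")   # assumed feature dims
out_dim = Dim(128, name="out_feat")
# strides=None defaults to filter_size, i.e. 2x upsampling along the spatial dim
upsample = rf.TransposedConv1d(in_dim, out_dim, filter_size=2, padding="valid")
# applying it would look roughly like: out, out_time_dim = upsample(x, in_spatial_dim=time_dim)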
@@ -704,7 +744,7 @@ def make_conv_out_spatial_dims(
     strides: Union[Sequence[int], int] = 1,
     dilation_rate: Union[Sequence[int], int] = 1,
     description_prefix: Optional[str] = None,
-) ->
+) -> List[Dim]:
     """create out spatial dims from in spatial dims"""
     nd = len(in_spatial_dims)
     if isinstance(filter_size, (int, Dim)):
@@ -715,84 +755,263 @@ def make_conv_out_spatial_dims(
         strides = [strides] * nd
     if isinstance(dilation_rate, int):
         dilation_rate = [dilation_rate] * nd
-    assert nd == len(in_spatial_dims) == len(filter_size) == len(strides) == len(dilation_rate)
     if isinstance(padding, (int, str)):
         padding = [padding] * nd
+    assert nd == len(in_spatial_dims) == len(filter_size) == len(strides) == len(dilation_rate) == len(padding)
     padding = [p.lower() if isinstance(p, str) else p for p in padding]
     out_spatial_dims = []
     for i in range(nd):
-
-
-
-        ):
-            out_spatial_dims.append(in_spatial_dim)
-        else:
-            out_spatial_dim = _calc_out_dim(
-                in_dim=in_spatial_dim,
+        out_spatial_dims.append(
+            calc_conv_out_length(
+                in_spatial_dims[i],
                 filter_size=filter_size[i],
+                padding=padding[i],
                 stride=strides[i],
                 dilation_rate=dilation_rate[i],
-
+                name=f"{description_prefix}:spatial{i}" if description_prefix else None,
             )
-
-            if description_prefix and out_spatial_dim != in_spatial_dim:
-                out_spatial_dim.name = f"{description_prefix}:spatial{i}"
-            if in_spatial_dim.dyn_size_ext is not None and out_spatial_dim.dyn_size_ext is None:
-                out_spatial_dim.dyn_size_ext = _calc_out_dim(
-                    in_dim=in_spatial_dim.dyn_size_ext,
-                    filter_size=filter_size[i],
-                    stride=strides[i],
-                    dilation_rate=dilation_rate[i],
-                    padding=padding[i],
-                )
-            out_spatial_dims.append(out_spatial_dim)
+        )
     return out_spatial_dims
 
 
-
+T = TypeVar("T", int, Dim, Tensor)
+
+
+def calc_conv_out_length(
+    in_length: Union[T, int, Dim, Tensor],
+    *,
+    filter_size: Union[T, int, Dim, Tensor],
+    stride: int,
+    padding: Union[str, int],
+    dilation_rate: int = 1,
+    name: Optional[str] = None,
+) -> T:
     """
     Copied and adapted from TF ConvLayer.calc_out_dim.
 
-    :param T
-    :param
-    :param
-    :param
-    :param
+    :param T in_length: dimension in some axis
+    :param filter_size: e.g. 2, for the corresponding axis
+    :param stride: e.g. 1, for the corresponding axis
+    :param dilation_rate: e.g. 1
+    :param padding: "valid" or "same" or int
+    :param name:
     :return: the output dimension
-    :rtype: T
     """
+    padding = padding.lower() if isinstance(padding, str) else padding
+    if isinstance(filter_size, int):
+        filter_size_int = filter_size
+    elif isinstance(filter_size, Dim):
+        filter_size_int = filter_size.dimension
+    else:
+        filter_size_int = None
+    filter_size_ = filter_size_int if isinstance(filter_size_int, int) else filter_size
 
-
-
-    :param T|int|Tensor|torch.Tensor|tensorflow.Tensor a:
-    :param T|int|Tensor|torch.Tensor|tensorflow.Tensor b:
-    :rtype: T
-    """
-    if isinstance(b, int) and b == 1:
-        return a
-    if isinstance(a, Tensor):
-        return rf.ceil_divide(a, b)
-    return -(-a // b)
+    if (filter_size_int == stride == 1 and padding in ("valid", "same", 0)) or (stride == 1 and padding == "same"):
+        return in_length
 
-    padding = padding.lower() if isinstance(padding, str) else padding
     # See tf.compat.v1.nn.convolution() documentation for more.
     if padding == "same":
-        if isinstance(
-
-
+        if isinstance(in_length, Dim):
+            out_length = in_length.ceildiv_right(stride)
+        else:
+            out_length = _ceildiv(in_length, stride)
     elif padding == "valid" or isinstance(padding, int):
         if isinstance(padding, int) and padding != 0:
             assert padding > 0
-
-
-
-
-
-
-
+            in_length = padding + in_length + padding
+
+        if filter_size_int == 1:
+            valid_part = in_length
+        elif isinstance(in_length, Dim):
+            filter_left_dilated = (filter_size_ - 1) * dilation_rate // 2
+            filter_right_dilated = (filter_size_ - 1) * dilation_rate - filter_left_dilated
+            valid_part = in_length.sub_left(filter_left_dilated).sub_right(filter_right_dilated)
+        else:
+            valid_part = in_length - (filter_size_ - 1) * dilation_rate
+
+        if isinstance(valid_part, Dim):
+            out_length = valid_part.ceildiv_right(stride)
+        else:
+            out_length = _ceildiv(valid_part, stride)
+
     else:
         raise ValueError(f"invalid padding {padding!r} (type {type(padding).__name__})")
 
+    if isinstance(in_length, Dim):
+        assert isinstance(out_length, Dim)
+        if name and out_length != in_length:
+            out_length.name = name
+        if in_length.dyn_size_ext is not None and out_length.dyn_size_ext is None:
+            out_dyn_size_ext = calc_conv_out_length(
+                in_length=in_length.dyn_size_ext,
+                filter_size=filter_size,
+                stride=stride,
+                dilation_rate=dilation_rate,
+                padding=padding,
+            )
+            assert isinstance(out_dyn_size_ext, Tensor)
+            out_length.dyn_size_ext = out_dyn_size_ext
+
+    return out_length
+
+
+def make_transposed_conv_out_spatial_dims(
+    in_spatial_dims: Sequence[Dim],
+    *,
+    filter_size: Union[Sequence[Union[int, Dim]], int, Dim],
+    padding: Union[str, int, Sequence[int]],
+    output_padding: Optional[Union[Sequence[Optional[int]], int]] = None,
+    strides: Union[Sequence[Optional[int]], None, int] = None,
+    dilation_rate: Union[Sequence[int], int] = 1,
+    description_prefix: Optional[str] = None,
+) -> List[Dim]:
+    """create out spatial dims from in spatial dims"""
+    nd = len(in_spatial_dims)
+    if isinstance(filter_size, (int, Dim)):
+        filter_size = [filter_size] * nd
+    filter_size = [d.dimension if isinstance(d, Dim) else d for d in filter_size]
+    assert all(isinstance(s, int) for s in filter_size)
+    if isinstance(strides, int) or strides is None:
+        strides = [strides] * nd
+    if isinstance(dilation_rate, int):
+        dilation_rate = [dilation_rate] * nd
+    if isinstance(padding, (int, str)):
+        padding = [padding] * nd
+    if isinstance(output_padding, int) or output_padding is None:
+        output_padding = [output_padding] * nd
+    assert (
+        nd
+        == len(in_spatial_dims)
+        == len(filter_size)
+        == len(strides)
+        == len(dilation_rate)
+        == len(padding)
+        == len(output_padding)
+    )
+    padding = [p.lower() if isinstance(p, str) else p for p in padding]
+    out_spatial_dims = []
+    for i in range(nd):
+        out_spatial_dims.append(
+            calc_transposed_conv_out_length(
+                in_spatial_dims[i],
+                filter_size=filter_size[i],
+                padding=padding[i],
+                stride=strides[i],
+                dilation_rate=dilation_rate[i],
+                name=f"{description_prefix}:spatial{i}" if description_prefix else None,
+            )
+        )
+    return out_spatial_dims
+
+
+def calc_transposed_conv_out_length(
+    in_length: Union[T, int, Dim, Tensor],
+    *,
+    filter_size: Union[int, Dim],
+    padding: Union[int, str],
+    output_padding: Optional[int] = None,
+    stride: Optional[int] = None,
+    dilation_rate: int = 1,
+    name: Optional[str] = None,
+) -> T:
+    """
+    Determines output length of a transposed convolution given input length.
+
+    Copied from TF/Keras conv_utils.deconv_output_length
+    (https://github.com/tensorflow/tensorflow/blob/5912f51d580551e5cee2cfde4cb882594b4d3e60/tensorflow/python/keras/utils/conv_utils.py#L140),
+    adapted with simplification.
+
+    Also see :func:`calc_conv_out_length`.
+
+    :param in_length:
+    :param filter_size:
+    :param padding: one of `"same"`, `"valid"`, `"full"`.
+    :param output_padding: amount of padding along the output dimension.
+        Can be set to `None` in which case the output length is inferred.
+    :param stride:
+    :param dilation_rate:
+    :param name:
+    :returns: The output length (integer)
+    """
+    assert padding in {"same", "valid", "full"} or isinstance(padding, int)
+
+    if isinstance(filter_size, int):
+        filter_size_int = filter_size
+    elif isinstance(filter_size, Dim):
+        filter_size_int = filter_size.dimension
+    else:
+        filter_size_int = None
+    filter_size_ = filter_size_int if isinstance(filter_size_int, int) else filter_size
+
+    # Get the dilated kernel size
+    if dilation_rate != 1 and filter_size_int != 1:
+        filter_size = filter_size + (filter_size_ - 1) * (dilation_rate - 1)
+
+    if stride is None:
+        assert filter_size_int is not None
+        stride = filter_size_int
+    if stride != 1:
+        in_length = in_length * stride
+
+    # Infer length if output padding is None, else compute the exact length
+    if output_padding is None:
+        if padding == "valid" or padding == 0:
+            if filter_size_int == stride:
+                out_length = in_length
+            elif filter_size_int is not None:
+                out_length = in_length + max(filter_size_int - stride, 0)
+            elif isinstance(filter_size, Tensor):
+                out_length = in_length + rf.relu(filter_size - stride)
+            elif isinstance(filter_size, Dim):
+                out_length = in_length + (filter_size - stride)
+            else:
+                raise ValueError(f"invalid filter_size {filter_size!r} type {type(filter_size)}")
+        elif padding == "full":
+            out_length = in_length - (stride + filter_size_ - 2)
+        elif padding == "same":
+            out_length = in_length
+        else:
+            raise ValueError(f"invalid padding {padding!r}")
+
+    else:  # output_padding
+        if padding == "same":
+            pad = filter_size // 2
+        elif padding == "valid":
+            pad = 0
+        elif padding == "full":
+            pad = filter_size - 1
+        elif isinstance(padding, int):
+            pad = padding
+        else:
+            raise ValueError(f"invalid padding {padding!r}")
+        out_length = in_length + (filter_size - stride - 2 * pad + output_padding)
+
+    if isinstance(in_length, Dim):
+        assert isinstance(out_length, Dim)
+        if name and out_length != in_length:
+            out_length.name = name
+        if in_length.dyn_size_ext is not None and out_length.dyn_size_ext is None:
+            out_dyn_size_ext = calc_transposed_conv_out_length(
+                in_length=in_length.dyn_size_ext,
+                filter_size=filter_size,
+                padding=padding,
+                output_padding=output_padding,
+                stride=stride,
+                dilation_rate=dilation_rate,
+            )
+            assert isinstance(out_dyn_size_ext, Tensor)
+            out_length.dyn_size_ext = out_dyn_size_ext
+
+    return out_length
+
+
+def _ceildiv(a: T, b: Union[T, int, Tensor]) -> T:
+    if isinstance(b, int) and b == 1:
+        return a
+    if isinstance(a, Tensor):
+        return rf.ceil_divide(a, b)
+    return -(-a // b)
+
 
 def _should_use_consistent_same_padding() -> bool:
     """
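A quick illustration of the new calc_conv_out_length helper (a sketch, not from the diff; it assumes the function is re-exported as rf.calc_conv_out_length like the rest of conv.py's __all__): with plain ints it follows the usual convolution arithmetic, ceil(T / stride) for "same" and ceil((T - (filter_size - 1) * dilation_rate) / stride) for "valid".

import returnn.frontend as rf

# "same" padding: the output length only depends on the stride
assert rf.calc_conv_out_length(100, filter_size=3, stride=2, padding="same") == 50
# "valid" padding: the (dilated) filter must fit fully inside the input
assert rf.calc_conv_out_length(100, filter_size=3, stride=2, padding="valid") == 49
# an int padding is applied symmetrically before the "valid" formula
assert rf.calc_conv_out_length(100, filter_size=3, stride=1, padding=1) == 100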
returnn/frontend/device.py
CHANGED
@@ -8,7 +8,13 @@ from contextlib import contextmanager
 from returnn.tensor import Tensor
 
 
-__all__ = [
+__all__ = [
+    "copy_to_device",
+    "get_default_device",
+    "set_default_device",
+    "set_default_device_ctx",
+    "get_default_dim_size_device",
+]
 
 
 _default_device: Optional[str] = None
@@ -61,3 +67,10 @@ def set_default_device_ctx(device: Optional[str]):
         yield
     finally:
         _default_device = old_device
+
+
+def get_default_dim_size_device() -> Optional[str]:
+    """
+    :return: default device, where to put new tensors for dim sizes (Dim.dyn_size_ext)
+    """
+    return "cpu"
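For context, an illustrative sketch (device strings assumed, not from the diff) of how the new helper relates to the existing default-device API: dim-size tensors (Dim.dyn_size_ext) stay on CPU regardless of the compute device.

import returnn.frontend as rf

with rf.set_default_device_ctx("cuda"):  # assumed device string
    assert rf.get_default_device() == "cuda"
    # sizes / sequence lengths for new dims still go to CPU
    assert rf.get_default_dim_size_device() == "cpu"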
returnn/frontend/encoder/conformer.py
CHANGED
@@ -167,6 +167,25 @@ class ConformerConvSubsample(ISeqDownsamplingEncoder):
         out, _ = rf.merge_dims(x, dims=[self._final_second_spatial_dim, in_dim])
         return out, in_spatial_dims[0]
 
+    def get_out_spatial_dim(self, in_spatial_dim: Dim) -> Dim:
+        """Get output spatial dimension given input spatial dimension."""
+        out_spatial_dim = in_spatial_dim
+        for i, conv_layer in enumerate(self.conv_layers):
+            (out_spatial_dim,) = rf.make_conv_out_spatial_dims(
+                [out_spatial_dim],
+                filter_size=conv_layer.filter_size[0],
+                strides=conv_layer.strides[0],
+                padding=conv_layer.padding,
+            )
+            if self.pool_sizes and i < len(self.pool_sizes):
+                (out_spatial_dim,) = rf.make_conv_out_spatial_dims(
+                    [out_spatial_dim],
+                    filter_size=self.pool_sizes[i][0],
+                    strides=self.pool_sizes[i][0],
+                    padding="same",
+                )
+        return out_spatial_dim
+
 
 class ConformerEncoderLayer(rf.Module):
     """
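An illustrative sketch of what the new ConformerConvSubsample.get_out_spatial_dim enables (the subsample instance and time_dim are assumed): computing the downsampled spatial dim without running the layers.

# assuming `subsample` is a ConformerConvSubsample and `time_dim` is the input spatial Dim
out_time_dim = subsample.get_out_spatial_dim(time_dim)
# this should match the out spatial dim that calling the module on real data would produce,
# derived purely from the conv strides and pool sizes via rf.make_conv_out_spatial_dims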
@@ -273,6 +292,7 @@ class ConformerEncoderLayer(rf.Module):
         x_mhsa = self.self_att(x_mhsa_ln, axis=spatial_dim)
         x_mhsa = rf.dropout(x_mhsa, self.dropout, axis=self.dropout_broadcast and self.out_dim)
         x_mhsa_out = x_mhsa + x_ffn1_out
+        del x_mhsa
 
         # Conv
         x_conv_ln = self.conv_layer_norm(x_mhsa_out)
returnn/frontend/encoder/transformer.py
CHANGED
@@ -79,6 +79,8 @@ class TransformerEncoder(rf.Module):
         self.model_dim = model_dim
         self.embed_dim = embed_dim
 
+        self.out_dim = self.model_dim  # alias. consistency, compatibility
+
         if input_embedding is None or isinstance(input_embedding, rf.Module):
             pass
         elif isinstance(input_embedding, type):
returnn/frontend/loss.py
CHANGED
@@ -3,11 +3,12 @@ Loss functions
 """
 
 from __future__ import annotations
+from typing import Optional, Tuple
 from returnn.tensor import Tensor, Dim
 import returnn.frontend as rf
 
 
-__all__ = ["cross_entropy", "ctc_loss", "edit_distance"]
+__all__ = ["cross_entropy", "ctc_loss", "ctc_greedy_decode", "edit_distance"]
 
 
 def cross_entropy(
@@ -93,6 +94,44 @@ def ctc_loss(
     )
 
 
+def ctc_greedy_decode(
+    logits: Tensor,
+    *,
+    in_spatial_dim: Dim,
+    blank_index: int,
+    out_spatial_dim: Optional[Dim] = None,
+    target_dim: Optional[Dim] = None,
+    wb_target_dim: Optional[Dim] = None,
+) -> Tuple[Tensor, Dim]:
+    """
+    Greedy CTC decode.
+
+    :return: (labels, out_spatial_dim)
+    """
+    if wb_target_dim is None:
+        assert logits.feature_dim
+        wb_target_dim = logits.feature_dim
+
+    labels = rf.reduce_argmax(logits, axis=wb_target_dim)
+    labels = rf.cast(labels, "int32")
+
+    labels_shifted = rf.shift_right(labels, axis=in_spatial_dim, pad_value=blank_index)
+    mask_repeat = labels != labels_shifted
+    labels, out_spatial_dim = rf.masked_select(
+        labels,
+        mask=(labels != blank_index) & mask_repeat,
+        dims=[in_spatial_dim],
+        out_dim=out_spatial_dim,
+    )
+
+    if target_dim:
+        # Set correct sparse_dim. Only currently implemented if blank comes after.
+        assert target_dim.dimension == blank_index
+        labels.sparse_dim = target_dim
+
+    return labels, out_spatial_dim
+
+
 def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim, *, dtype: str = "int32") -> Tensor:
     """
     :param a: [B,Ta]
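A usage sketch for the new ctc_greedy_decode (illustrative; logits, time_dim and vocab_dim are assumed, with the blank label appended after the vocabulary): it takes the per-frame argmax, collapses repeats and drops blanks along in_spatial_dim.

import returnn.frontend as rf

# logits: Tensor [B, time_dim, vocab + blank]; its feature_dim is used as wb_target_dim
labels, label_time_dim = rf.ctc_greedy_decode(
    logits,
    in_spatial_dim=time_dim,
    blank_index=vocab_dim.dimension,  # blank comes after the vocab labels
    target_dim=vocab_dim,  # sets labels.sparse_dim (only valid when blank is last)
)
# labels: [B, label_time_dim], sparse over vocab_dim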
returnn/frontend/math_.py
CHANGED
@@ -3,7 +3,6 @@ Math ops
 """
 
 from __future__ import annotations
-import typing
 from typing import Optional, Sequence, Union, Tuple, overload
 import numpy
 from returnn.tensor import Tensor, Dim
@@ -77,7 +76,7 @@
 ]
 
 
-@
+@overload
 def compare(
     a: Tensor,
     kind: str,
@@ -86,7 +85,19 @@ def compare(
     allow_broadcast_all_sources: Optional[bool] = None,
     dim_order: Optional[Sequence[Dim]] = None,
 ) -> Tensor:
-    """compare
+    """compare"""
+
+
+@overload
+def compare(
+    a: Union[Tensor, _RawTensorTypes],
+    kind: str,
+    b: Union[Tensor, _RawTensorTypes],
+    *,
+    allow_broadcast_all_sources: Optional[bool] = None,
+    dim_order: Optional[Sequence[Dim]] = None,
+) -> Tensor:
+    """compare"""
 
 
 _CompareMap = {
@@ -138,7 +149,7 @@ def compare_bc(
     return compare(a, kind, b, allow_broadcast_all_sources=True, dim_order=dim_order)
 
 
-@
+@overload
 def combine(
     a: Tensor,
     kind: str,
@@ -147,7 +158,19 @@ def combine(
     allow_broadcast_all_sources: Optional[bool] = None,
     dim_order: Optional[Sequence[Dim]] = None,
 ) -> Tensor:
-    """combine
+    """combine"""
+
+
+@overload
+def combine(
+    a: Union[Tensor, _RawTensorTypes],
+    kind: str,
+    b: Union[Tensor, _RawTensorTypes],
+    *,
+    allow_broadcast_all_sources: Optional[bool] = None,
+    dim_order: Optional[Sequence[Dim]] = None,
+) -> Union[Tensor, _RawTensorTypes]:
+    """combine"""
 
 
 _CombineMap = {
@@ -332,7 +355,12 @@ def logical_not(a: Tensor) -> Tensor:
 
 @overload
 def opt_logical_or(a: bool, b: bool) -> bool:
-    """logical or"""
+    """opt logical or"""
+
+
+@overload
+def opt_logical_or(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
+    """opt logical or"""
 
 
 def opt_logical_or(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
@@ -350,7 +378,12 @@ def opt_logical_or(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tens
 
 @overload
 def opt_logical_and(a: bool, b: bool) -> bool:
-    """logical and"""
+    """opt logical and"""
+
+
+@overload
+def opt_logical_and(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
+    """opt logical and"""
 
 
 def opt_logical_and(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
@@ -416,16 +449,23 @@ def minimum(a: Tensor, b: Union[Tensor, _RawTensorTypes], *other_tensors) -> Ten
 
 def clip_by_value(
     x: Tensor,
-    clip_value_min: Union[Tensor, _RawTensorTypes],
-    clip_value_max: Union[Tensor, _RawTensorTypes],
+    clip_value_min: Union[None, Tensor, _RawTensorTypes] = None,
+    clip_value_max: Union[None, Tensor, _RawTensorTypes] = None,
     *,
     allow_broadcast_all_sources: bool = False,
 ) -> Tensor:
     """clip by value"""
-
-
-    x
-
+    if clip_value_min is not None and clip_value_max is not None:
+        # noinspection PyProtectedMember
+        return x._raw_backend.clip_by_value(
+            x, clip_value_min, clip_value_max, allow_broadcast_all_sources=allow_broadcast_all_sources
+        )
+    elif clip_value_min is not None and clip_value_max is None:
+        return maximum(x, clip_value_min)
+    elif clip_value_min is None and clip_value_max is not None:
+        return minimum(x, clip_value_max)
+    else:
+        return x
 
 
 def identity(x: Tensor) -> Tensor:
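A small sketch of the relaxed clip_by_value signature (illustrative; x is an assumed Tensor): either bound can now be omitted, falling back to maximum/minimum, or to a no-op when both are None.

import returnn.frontend as rf

y = rf.clip_by_value(x, -1.0, 1.0)            # both bounds, as before
y = rf.clip_by_value(x, clip_value_min=0.0)   # only a lower bound, same as rf.maximum(x, 0.0)
y = rf.clip_by_value(x, clip_value_max=1.0)   # only an upper bound, same as rf.minimum(x, 1.0)
y = rf.clip_by_value(x)                       # no bounds: x is returned unchanged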
@@ -541,7 +581,7 @@ def floor(a: Tensor) -> Tensor:
 
 # noinspection PyShadowingBuiltins
 def round(a: Tensor) -> Tensor:
-    """round"""
+    """round. the result dtype is same as input dtype, still float"""
     # noinspection PyProtectedMember
     return a._raw_backend.activation(a, "round")
 
returnn/frontend/module.py
CHANGED
@@ -274,10 +274,17 @@ class Functional(Module):
     (This is often not necessary, but sometimes useful.)
     """
 
-    def __init__(self, func):
+    def __init__(self, func, *, attribs: Optional[Dict[str, Any]] = None):
+        """
+        :param func: callable. you might want to use functools.partial if you want to fix some arguments.
+        :param attribs: optional dict of attributes to set on this module. e.g. ``out_dim``.
+        """
         super().__init__()
         assert callable(func)
         self.func = func
+        if attribs:
+            for k, v in attribs.items():
+                setattr(self, k, v)
 
     def __repr__(self):
         return f"{self.__class__.__name__}({self.func.__qualname__})"
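A usage sketch for the new attribs argument of Functional (illustrative; the dim is made up, and it assumes Functional is reachable as rf.Functional): it lets a plain function, wrapped as a module, carry attributes such as out_dim that surrounding code may expect.

import returnn.frontend as rf
from returnn.tensor import Dim

feat_dim = Dim(512, name="feat")  # assumed
# wrap a plain function as a module and expose out_dim on it
act = rf.Functional(rf.relu, attribs={"out_dim": feat_dim})
assert act.out_dim == feat_dim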