returnn 1.20251027.232712__py3-none-any.whl → 1.20260119.15400__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +2 -2
- returnn/__old_mod_loader__.py +26 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/lm.py +130 -42
- returnn/datasets/meta.py +93 -43
- returnn/datasets/postprocessing.py +597 -108
- returnn/datasets/util/vocabulary.py +90 -0
- returnn/frontend/__init__.py +1 -0
- returnn/frontend/_backend.py +41 -0
- returnn/frontend/_native/__init__.py +22 -0
- returnn/frontend/_numpy_backend.py +7 -0
- returnn/frontend/_utils.py +1 -1
- returnn/frontend/array_.py +48 -2
- returnn/frontend/assert_.py +35 -0
- returnn/frontend/attention.py +54 -20
- returnn/frontend/conv.py +273 -54
- returnn/frontend/device.py +14 -1
- returnn/frontend/encoder/conformer.py +20 -0
- returnn/frontend/encoder/transformer.py +2 -0
- returnn/frontend/loss.py +222 -3
- returnn/frontend/math_.py +54 -14
- returnn/native_op.cpp +182 -172
- returnn/native_op.py +36 -31
- returnn/sprint/cache.py +12 -13
- returnn/tensor/_dim_extra.py +7 -7
- returnn/tensor/_tensor_extra.py +10 -10
- returnn/tensor/utils.py +8 -5
- returnn/tf/frontend_layers/_backend.py +7 -3
- returnn/tf/layers/basic.py +27 -40
- returnn/tf/native_op.py +27 -63
- returnn/tf/network.py +1 -1
- returnn/tf/util/basic.py +22 -197
- returnn/torch/engine.py +157 -6
- returnn/torch/frontend/_backend.py +280 -29
- returnn/torch/frontend/bridge.py +61 -0
- returnn/torch/frontend/compile_helper.py +106 -0
- returnn/torch/util/array_.py +30 -0
- returnn/torch/util/assert_.py +122 -0
- returnn/torch/util/exception_helper.py +7 -1
- returnn/torch/util/native_op.py +885 -0
- returnn/torch/util/native_op_code_compiler.py +308 -0
- returnn/util/basic.py +6 -7
- returnn/util/better_exchook.py +4 -0
- returnn/util/cuda_env.py +332 -0
- returnn/util/debug.py +12 -2
- returnn/util/file_cache.py +15 -1
- returnn/util/fsa.py +17 -13
- returnn/util/native_code_compiler.py +104 -47
- returnn/util/task_system.py +1 -1
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +2 -2
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +54 -48
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
returnn/frontend/conv.py
CHANGED
@@ -3,7 +3,7 @@ Convolution, transposed convolution, pooling
 """

 from __future__ import annotations
-from typing import Optional, Union, Sequence, Tuple
+from typing import Optional, Union, TypeVar, Sequence, Tuple, List
 from returnn.util.basic import next_type_attrib_in_mro_chain
 from returnn.tensor import Tensor, Dim
 import returnn.frontend as rf
@@ -25,6 +25,9 @@ __all__ = [
     "pool2d",
     "pool3d",
     "make_conv_out_spatial_dims",
+    "calc_conv_out_length",
+    "make_transposed_conv_out_spatial_dims",
+    "calc_transposed_conv_out_length",
 ]


@@ -396,7 +399,11 @@ def transposed_conv(
     )
     if use_mask:
         source = source.copy_masked(0, dims=in_spatial_dims)
-    if padding == "same" and _should_use_consistent_same_padding():
+    if (
+        padding == "same"
+        and any(s != 1 for s in (strides or [fs.dimension for fs in filter_size]))
+        and _should_use_consistent_same_padding()
+    ):
         # I don't really know what this should mean here... Investigate this further...
         raise NotImplementedError("consistent same padding not implemented for transposed conv")
     # noinspection PyProtectedMember
@@ -424,6 +431,39 @@ class TransposedConv1d(_TransposedConv):

     nd = 1

+    def __init__(
+        self,
+        in_dim: Dim,
+        out_dim: Dim,
+        filter_size: Union[int, Dim],
+        *,
+        padding: str,
+        remove_padding: int = 0,
+        output_padding: Optional[int] = None,
+        strides: Optional[int] = None,
+        with_bias: bool = True,
+    ):
+        """
+        :param in_dim:
+        :param out_dim:
+        :param filter_size:
+        :param strides: specifies the upscaling. by default, same as filter_size
+        :param padding: "same" or "valid"
+        :param remove_padding:
+        :param output_padding:
+        :param with_bias: whether to add a bias. enabled by default
+        """
+        super().__init__(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            filter_size=[filter_size],
+            padding=padding,
+            remove_padding=remove_padding,
+            output_padding=output_padding,
+            strides=[strides] if strides is not None else None,
+            with_bias=with_bias,
+        )
+
     __call__ = _ConvOrTransposedConv._call_nd1


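A minimal usage sketch of the new convenience constructor (the dim sizes and names are made up for illustration, and it is assumed here that TransposedConv1d is re-exported on the rf namespace like the other conv modules):

import returnn.frontend as rf
from returnn.tensor import Dim

in_dim = Dim(128, name="in")    # hypothetical feature dims, for illustration only
out_dim = Dim(128, name="out")

# Scalar filter_size/strides are wrapped into single-element lists and forwarded
# to the generic _TransposedConv base; strides defaults to filter_size if omitted.
upsample = rf.TransposedConv1d(in_dim, out_dim, filter_size=2, strides=2, padding="valid")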
@@ -704,7 +744,7 @@ def make_conv_out_spatial_dims(
     strides: Union[Sequence[int], int] = 1,
     dilation_rate: Union[Sequence[int], int] = 1,
     description_prefix: Optional[str] = None,
-) -> Sequence[Dim]:
+) -> List[Dim]:
     """create out spatial dims from in spatial dims"""
     nd = len(in_spatial_dims)
     if isinstance(filter_size, (int, Dim)):
@@ -715,84 +755,263 @@ def make_conv_out_spatial_dims(
         strides = [strides] * nd
     if isinstance(dilation_rate, int):
         dilation_rate = [dilation_rate] * nd
-    assert nd == len(in_spatial_dims) == len(filter_size) == len(strides) == len(dilation_rate)
     if isinstance(padding, (int, str)):
         padding = [padding] * nd
+    assert nd == len(in_spatial_dims) == len(filter_size) == len(strides) == len(dilation_rate) == len(padding)
     padding = [p.lower() if isinstance(p, str) else p for p in padding]
     out_spatial_dims = []
     for i in range(nd):
-            out_spatial_dims.append(in_spatial_dim)
-        else:
-            out_spatial_dim = _calc_out_dim(
-                in_dim=in_spatial_dim,
+        out_spatial_dims.append(
+            calc_conv_out_length(
+                in_spatial_dims[i],
                 filter_size=filter_size[i],
+                padding=padding[i],
                 stride=strides[i],
                 dilation_rate=dilation_rate[i],
-                padding=padding[i],
+                name=f"{description_prefix}:spatial{i}" if description_prefix else None,
             )
-            if description_prefix and out_spatial_dim != in_spatial_dim:
-                out_spatial_dim.name = f"{description_prefix}:spatial{i}"
-            if in_spatial_dim.dyn_size_ext is not None and out_spatial_dim.dyn_size_ext is None:
-                out_spatial_dim.dyn_size_ext = _calc_out_dim(
-                    in_dim=in_spatial_dim.dyn_size_ext,
-                    filter_size=filter_size[i],
-                    stride=strides[i],
-                    dilation_rate=dilation_rate[i],
-                    padding=padding[i],
-                )
-            out_spatial_dims.append(out_spatial_dim)
+        )
     return out_spatial_dims


+T = TypeVar("T", int, Dim, Tensor)
+
+
+def calc_conv_out_length(
+    in_length: Union[T, int, Dim, Tensor],
+    *,
+    filter_size: Union[T, int, Dim, Tensor],
+    stride: int,
+    padding: Union[str, int],
+    dilation_rate: int = 1,
+    name: Optional[str] = None,
+) -> T:
     """
     Copied and adapted from TF ConvLayer.calc_out_dim.

+    :param T in_length: dimension in some axis
+    :param filter_size: e.g. 2, for the corresponding axis
+    :param stride: e.g. 1, for the corresponding axis
+    :param dilation_rate: e.g. 1
+    :param padding: "valid" or "same" or int
+    :param name:
     :return: the output dimension
-    :rtype: T
     """
+    padding = padding.lower() if isinstance(padding, str) else padding
+    if isinstance(filter_size, int):
+        filter_size_int = filter_size
+    elif isinstance(filter_size, Dim):
+        filter_size_int = filter_size.dimension
+    else:
+        filter_size_int = None
+    filter_size_ = filter_size_int if isinstance(filter_size_int, int) else filter_size

-    def ceildiv(a, b):
-        """
-        :param T|int|Tensor|torch.Tensor|tensorflow.Tensor a:
-        :param T|int|Tensor|torch.Tensor|tensorflow.Tensor b:
-        :rtype: T
-        """
-        if isinstance(b, int) and b == 1:
-            return a
-        if isinstance(a, Tensor):
-            return rf.ceil_divide(a, b)
-        return -(-a // b)
+    if (filter_size_int == stride == 1 and padding in ("valid", "same", 0)) or (stride == 1 and padding == "same"):
+        return in_length

-    padding = padding.lower() if isinstance(padding, str) else padding
     # See tf.compat.v1.nn.convolution() documentation for more.
     if padding == "same":
+        if isinstance(in_length, Dim):
+            out_length = in_length.ceildiv_right(stride)
+        else:
+            out_length = _ceildiv(in_length, stride)
     elif padding == "valid" or isinstance(padding, int):
         if isinstance(padding, int) and padding != 0:
             assert padding > 0
+            in_length = padding + in_length + padding
+
+        if filter_size_int == 1:
+            valid_part = in_length
+        elif isinstance(in_length, Dim):
+            filter_left_dilated = (filter_size_ - 1) * dilation_rate // 2
+            filter_right_dilated = (filter_size_ - 1) * dilation_rate - filter_left_dilated
+            valid_part = in_length.sub_left(filter_left_dilated).sub_right(filter_right_dilated)
+        else:
+            valid_part = in_length - (filter_size_ - 1) * dilation_rate
+
+        if isinstance(valid_part, Dim):
+            out_length = valid_part.ceildiv_right(stride)
+        else:
+            out_length = _ceildiv(valid_part, stride)
+
     else:
         raise ValueError(f"invalid padding {padding!r} (type {type(padding).__name__})")

+    if isinstance(in_length, Dim):
+        assert isinstance(out_length, Dim)
+        if name and out_length != in_length:
+            out_length.name = name
+        if in_length.dyn_size_ext is not None and out_length.dyn_size_ext is None:
+            out_dyn_size_ext = calc_conv_out_length(
+                in_length=in_length.dyn_size_ext,
+                filter_size=filter_size,
+                stride=stride,
+                dilation_rate=dilation_rate,
+                padding=padding,
+            )
+            assert isinstance(out_dyn_size_ext, Tensor)
+            out_length.dyn_size_ext = out_dyn_size_ext
+
+    return out_length
+
+
+def make_transposed_conv_out_spatial_dims(
+    in_spatial_dims: Sequence[Dim],
+    *,
+    filter_size: Union[Sequence[Union[int, Dim]], int, Dim],
+    padding: Union[str, int, Sequence[int]],
+    output_padding: Optional[Union[Sequence[Optional[int]], int]] = None,
+    strides: Union[Sequence[Optional[int]], None, int] = None,
+    dilation_rate: Union[Sequence[int], int] = 1,
+    description_prefix: Optional[str] = None,
+) -> List[Dim]:
+    """create out spatial dims from in spatial dims"""
+    nd = len(in_spatial_dims)
+    if isinstance(filter_size, (int, Dim)):
+        filter_size = [filter_size] * nd
+    filter_size = [d.dimension if isinstance(d, Dim) else d for d in filter_size]
+    assert all(isinstance(s, int) for s in filter_size)
+    if isinstance(strides, int) or strides is None:
+        strides = [strides] * nd
+    if isinstance(dilation_rate, int):
+        dilation_rate = [dilation_rate] * nd
+    if isinstance(padding, (int, str)):
+        padding = [padding] * nd
+    if isinstance(output_padding, int) or output_padding is None:
+        output_padding = [output_padding] * nd
+    assert (
+        nd
+        == len(in_spatial_dims)
+        == len(filter_size)
+        == len(strides)
+        == len(dilation_rate)
+        == len(padding)
+        == len(output_padding)
+    )
+    padding = [p.lower() if isinstance(p, str) else p for p in padding]
+    out_spatial_dims = []
+    for i in range(nd):
+        out_spatial_dims.append(
+            calc_transposed_conv_out_length(
+                in_spatial_dims[i],
+                filter_size=filter_size[i],
+                padding=padding[i],
+                stride=strides[i],
+                dilation_rate=dilation_rate[i],
+                name=f"{description_prefix}:spatial{i}" if description_prefix else None,
+            )
+        )
+    return out_spatial_dims
+
+
+def calc_transposed_conv_out_length(
+    in_length: Union[T, int, Dim, Tensor],
+    *,
+    filter_size: Union[int, Dim],
+    padding: Union[int, str],
+    output_padding: Optional[int] = None,
+    stride: Optional[int] = None,
+    dilation_rate: int = 1,
+    name: Optional[str] = None,
+) -> T:
+    """
+    Determines output length of a transposed convolution given input length.
+
+    Copied from TF/Keras conv_utils.deconv_output_length
+    (https://github.com/tensorflow/tensorflow/blob/5912f51d580551e5cee2cfde4cb882594b4d3e60/tensorflow/python/keras/utils/conv_utils.py#L140),
+    adapted with simplification.
+
+    Also see :func:`calc_conv_out_length`.
+
+    :param in_length:
+    :param filter_size:
+    :param padding: one of `"same"`, `"valid"`, `"full"`.
+    :param output_padding: amount of padding along the output dimension.
+        Can be set to `None` in which case the output length is inferred.
+    :param stride:
+    :param dilation_rate:
+    :param name:
+    :returns: The output length (integer)
+    """
+    assert padding in {"same", "valid", "full"} or isinstance(padding, int)
+
+    if isinstance(filter_size, int):
+        filter_size_int = filter_size
+    elif isinstance(filter_size, Dim):
+        filter_size_int = filter_size.dimension
+    else:
+        filter_size_int = None
+    filter_size_ = filter_size_int if isinstance(filter_size_int, int) else filter_size
+
+    # Get the dilated kernel size
+    if dilation_rate != 1 and filter_size_int != 1:
+        filter_size = filter_size + (filter_size_ - 1) * (dilation_rate - 1)
+
+    if stride is None:
+        assert filter_size_int is not None
+        stride = filter_size_int
+    if stride != 1:
+        in_length = in_length * stride
+
+    # Infer length if output padding is None, else compute the exact length
+    if output_padding is None:
+        if padding == "valid" or padding == 0:
+            if filter_size_int == stride:
+                out_length = in_length
+            elif filter_size_int is not None:
+                out_length = in_length + max(filter_size_int - stride, 0)
+            elif isinstance(filter_size, Tensor):
+                out_length = in_length + rf.relu(filter_size - stride)
+            elif isinstance(filter_size, Dim):
+                out_length = in_length + (filter_size - stride)
+            else:
+                raise ValueError(f"invalid filter_size {filter_size!r} type {type(filter_size)}")
+        elif padding == "full":
+            out_length = in_length - (stride + filter_size_ - 2)
+        elif padding == "same":
+            out_length = in_length
+        else:
+            raise ValueError(f"invalid padding {padding!r}")
+
+    else:  # output_padding
+        if padding == "same":
+            pad = filter_size // 2
+        elif padding == "valid":
+            pad = 0
+        elif padding == "full":
+            pad = filter_size - 1
+        elif isinstance(padding, int):
+            pad = padding
+        else:
+            raise ValueError(f"invalid padding {padding!r}")
+        out_length = in_length + (filter_size - stride - 2 * pad + output_padding)
+
+    if isinstance(in_length, Dim):
+        assert isinstance(out_length, Dim)
+        if name and out_length != in_length:
+            out_length.name = name
+        if in_length.dyn_size_ext is not None and out_length.dyn_size_ext is None:
+            out_dyn_size_ext = calc_transposed_conv_out_length(
+                in_length=in_length.dyn_size_ext,
+                filter_size=filter_size,
+                padding=padding,
+                output_padding=output_padding,
+                stride=stride,
+                dilation_rate=dilation_rate,
+            )
+            assert isinstance(out_dyn_size_ext, Tensor)
+            out_length.dyn_size_ext = out_dyn_size_ext
+
+    return out_length
+
+
+def _ceildiv(a: T, b: Union[T, int, Tensor]) -> T:
+    if isinstance(b, int) and b == 1:
+        return a
+    if isinstance(a, Tensor):
+        return rf.ceil_divide(a, b)
+    return -(-a // b)
+

 def _should_use_consistent_same_padding() -> bool:
     """
returnn/frontend/device.py
CHANGED
@@ -8,7 +8,13 @@ from contextlib import contextmanager
 from returnn.tensor import Tensor


-__all__ = ["copy_to_device", "get_default_device", "set_default_device", "set_default_device_ctx"]
+__all__ = [
+    "copy_to_device",
+    "get_default_device",
+    "set_default_device",
+    "set_default_device_ctx",
+    "get_default_dim_size_device",
+]


 _default_device: Optional[str] = None
@@ -61,3 +67,10 @@ def set_default_device_ctx(device: Optional[str]):
         yield
     finally:
         _default_device = old_device
+
+
+def get_default_dim_size_device() -> Optional[str]:
+    """
+    :return: default device, where to put new tensors for dim sizes (Dim.dyn_size_ext)
+    """
+    return "cpu"
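A small sketch of how the new helper relates to the existing default-device context (assuming, as the __all__ above suggests, that these functions are re-exported on the rf namespace):

import returnn.frontend as rf

with rf.set_default_device_ctx("cuda"):
    assert rf.get_default_device() == "cuda"
    # Dim-size tensors (Dim.dyn_size_ext) are nevertheless kept on CPU.
    assert rf.get_default_dim_size_device() == "cpu"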
returnn/frontend/encoder/conformer.py
CHANGED
@@ -167,6 +167,25 @@ class ConformerConvSubsample(ISeqDownsamplingEncoder):
         out, _ = rf.merge_dims(x, dims=[self._final_second_spatial_dim, in_dim])
         return out, in_spatial_dims[0]

+    def get_out_spatial_dim(self, in_spatial_dim: Dim) -> Dim:
+        """Get output spatial dimension given input spatial dimension."""
+        out_spatial_dim = in_spatial_dim
+        for i, conv_layer in enumerate(self.conv_layers):
+            (out_spatial_dim,) = rf.make_conv_out_spatial_dims(
+                [out_spatial_dim],
+                filter_size=conv_layer.filter_size[0],
+                strides=conv_layer.strides[0],
+                padding=conv_layer.padding,
+            )
+            if self.pool_sizes and i < len(self.pool_sizes):
+                (out_spatial_dim,) = rf.make_conv_out_spatial_dims(
+                    [out_spatial_dim],
+                    filter_size=self.pool_sizes[i][0],
+                    strides=self.pool_sizes[i][0],
+                    padding="same",
+                )
+        return out_spatial_dim
+

 class ConformerEncoderLayer(rf.Module):
     """
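get_out_spatial_dim replays the conv/pool stack purely on the Dim level. An unrolled sketch for a single conv layer with stride 2 followed by pooling of size 2 (the concrete sizes are made up for illustration):

import returnn.frontend as rf
from returnn.tensor import Dim

time_dim = Dim(100, name="time")
(after_conv,) = rf.make_conv_out_spatial_dims([time_dim], filter_size=3, strides=2, padding="same")
(after_pool,) = rf.make_conv_out_spatial_dims([after_conv], filter_size=2, strides=2, padding="same")
assert after_conv.dimension == 50 and after_pool.dimension == 25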
@@ -273,6 +292,7 @@ class ConformerEncoderLayer(rf.Module):
         x_mhsa = self.self_att(x_mhsa_ln, axis=spatial_dim)
         x_mhsa = rf.dropout(x_mhsa, self.dropout, axis=self.dropout_broadcast and self.out_dim)
         x_mhsa_out = x_mhsa + x_ffn1_out
+        del x_mhsa

         # Conv
         x_conv_ln = self.conv_layer_norm(x_mhsa_out)
returnn/frontend/encoder/transformer.py
CHANGED
@@ -79,6 +79,8 @@ class TransformerEncoder(rf.Module):
         self.model_dim = model_dim
         self.embed_dim = embed_dim

+        self.out_dim = self.model_dim  # alias. consistency, compatibility
+
         if input_embedding is None or isinstance(input_embedding, rf.Module):
             pass
         elif isinstance(input_embedding, type):