returnn 1.20251013.113026__py3-none-any.whl → 1.20260109.93428__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of returnn might be problematic.
Files changed (43)
  1. returnn/PKG-INFO +2 -2
  2. returnn/_setup_info_generated.py +2 -2
  3. returnn/config.py +1 -1
  4. returnn/datasets/distrib_files.py +53 -1
  5. returnn/datasets/generating.py +3 -5
  6. returnn/datasets/lm.py +20 -0
  7. returnn/datasets/meta.py +179 -60
  8. returnn/datasets/postprocessing.py +597 -108
  9. returnn/datasets/util/vocabulary.py +90 -0
  10. returnn/frontend/array_.py +46 -0
  11. returnn/frontend/attention.py +54 -20
  12. returnn/frontend/conv.py +273 -54
  13. returnn/frontend/device.py +14 -1
  14. returnn/frontend/encoder/conformer.py +20 -0
  15. returnn/frontend/encoder/transformer.py +2 -0
  16. returnn/frontend/loss.py +40 -1
  17. returnn/frontend/math_.py +54 -14
  18. returnn/frontend/module.py +8 -1
  19. returnn/frontend/nested.py +5 -0
  20. returnn/native_op.cpp +80 -0
  21. returnn/sprint/cache.py +12 -13
  22. returnn/tensor/_dim_extra.py +39 -24
  23. returnn/tensor/utils.py +7 -4
  24. returnn/tf/frontend_layers/_backend.py +4 -3
  25. returnn/tf/layers/basic.py +15 -39
  26. returnn/tf/native_op.py +11 -58
  27. returnn/tf/network.py +1 -1
  28. returnn/tf/util/basic.py +19 -0
  29. returnn/torch/engine.py +67 -2
  30. returnn/torch/frontend/_backend.py +135 -13
  31. returnn/torch/frontend/bridge.py +61 -0
  32. returnn/torch/util/exception_helper.py +7 -1
  33. returnn/util/basic.py +6 -7
  34. returnn/util/better_exchook.py +4 -0
  35. returnn/util/collect_outputs_dict.py +79 -0
  36. returnn/util/debug.py +11 -2
  37. returnn/util/file_cache.py +15 -1
  38. returnn/util/task_system.py +1 -1
  39. {returnn-1.20251013.113026.dist-info → returnn-1.20260109.93428.dist-info}/METADATA +2 -2
  40. {returnn-1.20251013.113026.dist-info → returnn-1.20260109.93428.dist-info}/RECORD +43 -42
  41. {returnn-1.20251013.113026.dist-info → returnn-1.20260109.93428.dist-info}/LICENSE +0 -0
  42. {returnn-1.20251013.113026.dist-info → returnn-1.20260109.93428.dist-info}/WHEEL +0 -0
  43. {returnn-1.20251013.113026.dist-info → returnn-1.20260109.93428.dist-info}/top_level.txt +0 -0
returnn/frontend/conv.py CHANGED
@@ -3,7 +3,7 @@ Convolution, transposed convolution, pooling
 """
 
 from __future__ import annotations
-from typing import Optional, Sequence, Tuple, Union
+from typing import Optional, Union, TypeVar, Sequence, Tuple, List
 from returnn.util.basic import next_type_attrib_in_mro_chain
 from returnn.tensor import Tensor, Dim
 import returnn.frontend as rf
@@ -25,6 +25,9 @@ __all__ = [
     "pool2d",
     "pool3d",
     "make_conv_out_spatial_dims",
+    "calc_conv_out_length",
+    "make_transposed_conv_out_spatial_dims",
+    "calc_transposed_conv_out_length",
 ]
 
 
@@ -396,7 +399,11 @@ def transposed_conv(
     )
     if use_mask:
         source = source.copy_masked(0, dims=in_spatial_dims)
-    if padding == "same" and _any_is_non_default(strides, default=1) and _should_use_consistent_same_padding():
+    if (
+        padding == "same"
+        and any(s != 1 for s in (strides or [fs.dimension for fs in filter_size]))
+        and _should_use_consistent_same_padding()
+    ):
         # I don't really know what this should mean here... Investigate this further...
         raise NotImplementedError("consistent same padding not implemented for transposed conv")
     # noinspection PyProtectedMember
@@ -424,6 +431,39 @@ class TransposedConv1d(_TransposedConv):
 
     nd = 1
 
+    def __init__(
+        self,
+        in_dim: Dim,
+        out_dim: Dim,
+        filter_size: Union[int, Dim],
+        *,
+        padding: str,
+        remove_padding: int = 0,
+        output_padding: Optional[int] = None,
+        strides: Optional[int] = None,
+        with_bias: bool = True,
+    ):
+        """
+        :param in_dim:
+        :param out_dim:
+        :param filter_size:
+        :param strides: specifies the upscaling. by default, same as filter_size
+        :param padding: "same" or "valid"
+        :param remove_padding:
+        :param output_padding:
+        :param with_bias: whether to add a bias. enabled by default
+        """
+        super().__init__(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            filter_size=[filter_size],
+            padding=padding,
+            remove_padding=remove_padding,
+            output_padding=output_padding,
+            strides=[strides] if strides is not None else None,
+            with_bias=with_bias,
+        )
+
     __call__ = _ConvOrTransposedConv._call_nd1
 
 
@@ -704,7 +744,7 @@ def make_conv_out_spatial_dims(
     strides: Union[Sequence[int], int] = 1,
     dilation_rate: Union[Sequence[int], int] = 1,
     description_prefix: Optional[str] = None,
-) -> Sequence[Dim]:
+) -> List[Dim]:
     """create out spatial dims from in spatial dims"""
     nd = len(in_spatial_dims)
     if isinstance(filter_size, (int, Dim)):
@@ -715,84 +755,263 @@ def make_conv_out_spatial_dims(
         strides = [strides] * nd
     if isinstance(dilation_rate, int):
         dilation_rate = [dilation_rate] * nd
-    assert nd == len(in_spatial_dims) == len(filter_size) == len(strides) == len(dilation_rate)
     if isinstance(padding, (int, str)):
         padding = [padding] * nd
+    assert nd == len(in_spatial_dims) == len(filter_size) == len(strides) == len(dilation_rate) == len(padding)
     padding = [p.lower() if isinstance(p, str) else p for p in padding]
     out_spatial_dims = []
     for i in range(nd):
-        in_spatial_dim = in_spatial_dims[i]
-        if (filter_size[i] == strides[i] == 1 and padding[i] in ("valid", "same", 0)) or (
-            strides[i] == 1 and padding[i] == "same"
-        ):
-            out_spatial_dims.append(in_spatial_dim)
-        else:
-            out_spatial_dim = _calc_out_dim(
-                in_dim=in_spatial_dim,
+        out_spatial_dims.append(
+            calc_conv_out_length(
+                in_spatial_dims[i],
                 filter_size=filter_size[i],
+                padding=padding[i],
                 stride=strides[i],
                 dilation_rate=dilation_rate[i],
-                padding=padding[i],
+                name=f"{description_prefix}:spatial{i}" if description_prefix else None,
             )
-            assert isinstance(out_spatial_dim, Dim)
-            if description_prefix and out_spatial_dim != in_spatial_dim:
-                out_spatial_dim.name = f"{description_prefix}:spatial{i}"
-            if in_spatial_dim.dyn_size_ext is not None and out_spatial_dim.dyn_size_ext is None:
-                out_spatial_dim.dyn_size_ext = _calc_out_dim(
-                    in_dim=in_spatial_dim.dyn_size_ext,
-                    filter_size=filter_size[i],
-                    stride=strides[i],
-                    dilation_rate=dilation_rate[i],
-                    padding=padding[i],
-                )
-            out_spatial_dims.append(out_spatial_dim)
+        )
     return out_spatial_dims
 
 
-def _calc_out_dim(in_dim, filter_size, stride, padding, dilation_rate=1):
+T = TypeVar("T", int, Dim, Tensor)
+
+
+def calc_conv_out_length(
+    in_length: Union[T, int, Dim, Tensor],
+    *,
+    filter_size: Union[T, int, Dim, Tensor],
+    stride: int,
+    padding: Union[str, int],
+    dilation_rate: int = 1,
+    name: Optional[str] = None,
+) -> T:
     """
     Copied and adapted from TF ConvLayer.calc_out_dim.
 
-    :param T|int|Tensor|torch.Tensor|tensorflow.Tensor|Dim in_dim: dimension in some axis
-    :param int filter_size: e.g. 2, for the corresponding axis
-    :param int stride: e.g. 1, for the corresponding axis
-    :param int dilation_rate: e.g. 1
-    :param str|int padding: "valid" or "same" or int
+    :param T in_length: dimension in some axis
+    :param filter_size: e.g. 2, for the corresponding axis
+    :param stride: e.g. 1, for the corresponding axis
+    :param dilation_rate: e.g. 1
+    :param padding: "valid" or "same" or int
+    :param name:
     :return: the output dimension
-    :rtype: T
     """
+    padding = padding.lower() if isinstance(padding, str) else padding
+    if isinstance(filter_size, int):
+        filter_size_int = filter_size
+    elif isinstance(filter_size, Dim):
+        filter_size_int = filter_size.dimension
+    else:
+        filter_size_int = None
+    filter_size_ = filter_size_int if isinstance(filter_size_int, int) else filter_size
 
-    def ceildiv(a, b):
-        """
-        :param T|int|Tensor|torch.Tensor|tensorflow.Tensor a:
-        :param T|int|Tensor|torch.Tensor|tensorflow.Tensor b:
-        :rtype: T
-        """
-        if isinstance(b, int) and b == 1:
-            return a
-        if isinstance(a, Tensor):
-            return rf.ceil_divide(a, b)
-        return -(-a // b)
+    if (filter_size_int == stride == 1 and padding in ("valid", "same", 0)) or (stride == 1 and padding == "same"):
+        return in_length
 
-    padding = padding.lower() if isinstance(padding, str) else padding
     # See tf.compat.v1.nn.convolution() documentation for more.
     if padding == "same":
-        if isinstance(in_dim, Dim):
-            return in_dim.ceildiv_right(stride)
-        return ceildiv(in_dim, stride)
+        if isinstance(in_length, Dim):
+            out_length = in_length.ceildiv_right(stride)
+        else:
+            out_length = _ceildiv(in_length, stride)
     elif padding == "valid" or isinstance(padding, int):
         if isinstance(padding, int) and padding != 0:
             assert padding > 0
-            in_dim = padding + in_dim + padding
-        if isinstance(in_dim, Dim):
-            filter_left_dilated = (filter_size - 1) * dilation_rate // 2
-            filter_right_dilated = (filter_size - 1) * dilation_rate - filter_left_dilated
-            valid_part = in_dim.sub_left(filter_left_dilated).sub_right(filter_right_dilated)
-            return valid_part.ceildiv_right(stride)
-        return ceildiv(in_dim - (filter_size - 1) * dilation_rate, stride)
+            in_length = padding + in_length + padding
+
+        if filter_size_int == 1:
+            valid_part = in_length
+        elif isinstance(in_length, Dim):
+            filter_left_dilated = (filter_size_ - 1) * dilation_rate // 2
+            filter_right_dilated = (filter_size_ - 1) * dilation_rate - filter_left_dilated
+            valid_part = in_length.sub_left(filter_left_dilated).sub_right(filter_right_dilated)
+        else:
+            valid_part = in_length - (filter_size_ - 1) * dilation_rate
+
+        if isinstance(valid_part, Dim):
+            out_length = valid_part.ceildiv_right(stride)
+        else:
+            out_length = _ceildiv(valid_part, stride)
+
     else:
         raise ValueError(f"invalid padding {padding!r} (type {type(padding).__name__})")
 
+    if isinstance(in_length, Dim):
+        assert isinstance(out_length, Dim)
+        if name and out_length != in_length:
+            out_length.name = name
+        if in_length.dyn_size_ext is not None and out_length.dyn_size_ext is None:
+            out_dyn_size_ext = calc_conv_out_length(
+                in_length=in_length.dyn_size_ext,
+                filter_size=filter_size,
+                stride=stride,
+                dilation_rate=dilation_rate,
+                padding=padding,
+            )
+            assert isinstance(out_dyn_size_ext, Tensor)
+            out_length.dyn_size_ext = out_dyn_size_ext
+
+    return out_length
+
+
+def make_transposed_conv_out_spatial_dims(
+    in_spatial_dims: Sequence[Dim],
+    *,
+    filter_size: Union[Sequence[Union[int, Dim]], int, Dim],
+    padding: Union[str, int, Sequence[int]],
+    output_padding: Optional[Union[Sequence[Optional[int]], int]] = None,
+    strides: Union[Sequence[Optional[int]], None, int] = None,
+    dilation_rate: Union[Sequence[int], int] = 1,
+    description_prefix: Optional[str] = None,
+) -> List[Dim]:
+    """create out spatial dims from in spatial dims"""
+    nd = len(in_spatial_dims)
+    if isinstance(filter_size, (int, Dim)):
+        filter_size = [filter_size] * nd
+    filter_size = [d.dimension if isinstance(d, Dim) else d for d in filter_size]
+    assert all(isinstance(s, int) for s in filter_size)
+    if isinstance(strides, int) or strides is None:
+        strides = [strides] * nd
+    if isinstance(dilation_rate, int):
+        dilation_rate = [dilation_rate] * nd
+    if isinstance(padding, (int, str)):
+        padding = [padding] * nd
+    if isinstance(output_padding, int) or output_padding is None:
+        output_padding = [output_padding] * nd
+    assert (
+        nd
+        == len(in_spatial_dims)
+        == len(filter_size)
+        == len(strides)
+        == len(dilation_rate)
+        == len(padding)
+        == len(output_padding)
+    )
+    padding = [p.lower() if isinstance(p, str) else p for p in padding]
+    out_spatial_dims = []
+    for i in range(nd):
+        out_spatial_dims.append(
+            calc_transposed_conv_out_length(
+                in_spatial_dims[i],
+                filter_size=filter_size[i],
+                padding=padding[i],
+                stride=strides[i],
+                dilation_rate=dilation_rate[i],
+                name=f"{description_prefix}:spatial{i}" if description_prefix else None,
+            )
+        )
+    return out_spatial_dims
+
+
+def calc_transposed_conv_out_length(
+    in_length: Union[T, int, Dim, Tensor],
+    *,
+    filter_size: Union[int, Dim],
+    padding: Union[int, str],
+    output_padding: Optional[int] = None,
+    stride: Optional[int] = None,
+    dilation_rate: int = 1,
+    name: Optional[str] = None,
+) -> T:
+    """
+    Determines output length of a transposed convolution given input length.
+
+    Copied from TF/Keras conv_utils.deconv_output_length
+    (https://github.com/tensorflow/tensorflow/blob/5912f51d580551e5cee2cfde4cb882594b4d3e60/tensorflow/python/keras/utils/conv_utils.py#L140),
+    adapted with simplification.
+
+    Also see :func:`calc_conv_out_length`.
+
+    :param in_length:
+    :param filter_size:
+    :param padding: one of `"same"`, `"valid"`, `"full"`.
+    :param output_padding: amount of padding along the output dimension.
+        Can be set to `None` in which case the output length is inferred.
+    :param stride:
+    :param dilation_rate:
+    :param name:
+    :returns: The output length (integer)
+    """
+    assert padding in {"same", "valid", "full"} or isinstance(padding, int)
+
+    if isinstance(filter_size, int):
+        filter_size_int = filter_size
+    elif isinstance(filter_size, Dim):
+        filter_size_int = filter_size.dimension
+    else:
+        filter_size_int = None
+    filter_size_ = filter_size_int if isinstance(filter_size_int, int) else filter_size
+
+    # Get the dilated kernel size
+    if dilation_rate != 1 and filter_size_int != 1:
+        filter_size = filter_size + (filter_size_ - 1) * (dilation_rate - 1)
+
+    if stride is None:
+        assert filter_size_int is not None
+        stride = filter_size_int
+    if stride != 1:
+        in_length = in_length * stride
+
+    # Infer length if output padding is None, else compute the exact length
+    if output_padding is None:
+        if padding == "valid" or padding == 0:
+            if filter_size_int == stride:
+                out_length = in_length
+            elif filter_size_int is not None:
+                out_length = in_length + max(filter_size_int - stride, 0)
+            elif isinstance(filter_size, Tensor):
+                out_length = in_length + rf.relu(filter_size - stride)
+            elif isinstance(filter_size, Dim):
+                out_length = in_length + (filter_size - stride)
+            else:
+                raise ValueError(f"invalid filter_size {filter_size!r} type {type(filter_size)}")
+        elif padding == "full":
+            out_length = in_length - (stride + filter_size_ - 2)
+        elif padding == "same":
+            out_length = in_length
+        else:
+            raise ValueError(f"invalid padding {padding!r}")
+
+    else:  # output_padding
+        if padding == "same":
+            pad = filter_size // 2
+        elif padding == "valid":
+            pad = 0
+        elif padding == "full":
+            pad = filter_size - 1
+        elif isinstance(padding, int):
+            pad = padding
+        else:
+            raise ValueError(f"invalid padding {padding!r}")
+        out_length = in_length + (filter_size - stride - 2 * pad + output_padding)
+
+    if isinstance(in_length, Dim):
+        assert isinstance(out_length, Dim)
+        if name and out_length != in_length:
+            out_length.name = name
+        if in_length.dyn_size_ext is not None and out_length.dyn_size_ext is None:
+            out_dyn_size_ext = calc_transposed_conv_out_length(
+                in_length=in_length.dyn_size_ext,
+                filter_size=filter_size,
+                padding=padding,
+                output_padding=output_padding,
+                stride=stride,
+                dilation_rate=dilation_rate,
+            )
+            assert isinstance(out_dyn_size_ext, Tensor)
+            out_length.dyn_size_ext = out_dyn_size_ext
+
+    return out_length
+
+
+def _ceildiv(a: T, b: Union[T, int, Tensor]) -> T:
+    if isinstance(b, int) and b == 1:
+        return a
+    if isinstance(a, Tensor):
+        return rf.ceil_divide(a, b)
+    return -(-a // b)
+
 
 def _should_use_consistent_same_padding() -> bool:
     """
returnn/frontend/device.py CHANGED
@@ -8,7 +8,13 @@ from contextlib import contextmanager
 from returnn.tensor import Tensor
 
 
-__all__ = ["copy_to_device", "get_default_device", "set_default_device", "set_default_device_ctx"]
+__all__ = [
+    "copy_to_device",
+    "get_default_device",
+    "set_default_device",
+    "set_default_device_ctx",
+    "get_default_dim_size_device",
+]
 
 
 _default_device: Optional[str] = None
@@ -61,3 +67,10 @@ def set_default_device_ctx(device: Optional[str]):
         yield
     finally:
         _default_device = old_device
+
+
+def get_default_dim_size_device() -> Optional[str]:
+    """
+    :return: default device, where to put new tensors for dim sizes (Dim.dyn_size_ext)
+    """
+    return "cpu"
returnn/frontend/encoder/conformer.py CHANGED
@@ -167,6 +167,25 @@ class ConformerConvSubsample(ISeqDownsamplingEncoder):
         out, _ = rf.merge_dims(x, dims=[self._final_second_spatial_dim, in_dim])
         return out, in_spatial_dims[0]
 
+    def get_out_spatial_dim(self, in_spatial_dim: Dim) -> Dim:
+        """Get output spatial dimension given input spatial dimension."""
+        out_spatial_dim = in_spatial_dim
+        for i, conv_layer in enumerate(self.conv_layers):
+            (out_spatial_dim,) = rf.make_conv_out_spatial_dims(
+                [out_spatial_dim],
+                filter_size=conv_layer.filter_size[0],
+                strides=conv_layer.strides[0],
+                padding=conv_layer.padding,
+            )
+            if self.pool_sizes and i < len(self.pool_sizes):
+                (out_spatial_dim,) = rf.make_conv_out_spatial_dims(
+                    [out_spatial_dim],
+                    filter_size=self.pool_sizes[i][0],
+                    strides=self.pool_sizes[i][0],
+                    padding="same",
+                )
+        return out_spatial_dim
+
 
 class ConformerEncoderLayer(rf.Module):
     """
@@ -273,6 +292,7 @@ class ConformerEncoderLayer(rf.Module):
         x_mhsa = self.self_att(x_mhsa_ln, axis=spatial_dim)
         x_mhsa = rf.dropout(x_mhsa, self.dropout, axis=self.dropout_broadcast and self.out_dim)
         x_mhsa_out = x_mhsa + x_ffn1_out
+        del x_mhsa
 
         # Conv
         x_conv_ln = self.conv_layer_norm(x_mhsa_out)
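
Note: get_out_spatial_dim above just chains the per-layer output-length calculation. As a rough illustration with "same" padding and hypothetical strides of 3 and 2 (the actual strides depend on the ConformerConvSubsample configuration), the sequence length shrinks by roughly the product of the strides, rounding up at each step:

# Hypothetical subsampling chain; strides are illustrative, not the library defaults.
def ceildiv(a: int, b: int) -> int:
    return -(-a // b)

out_len = 100
for stride in (3, 2):  # one entry per conv layer, padding="same"
    out_len = ceildiv(out_len, stride)
assert out_len == 17  # ceil(ceil(100 / 3) / 2)
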
returnn/frontend/encoder/transformer.py CHANGED
@@ -79,6 +79,8 @@ class TransformerEncoder(rf.Module):
         self.model_dim = model_dim
         self.embed_dim = embed_dim
 
+        self.out_dim = self.model_dim  # alias. consistency, compatibility
+
         if input_embedding is None or isinstance(input_embedding, rf.Module):
             pass
         elif isinstance(input_embedding, type):
returnn/frontend/loss.py CHANGED
@@ -3,11 +3,12 @@ Loss functions
 """
 
 from __future__ import annotations
+from typing import Optional, Tuple
 from returnn.tensor import Tensor, Dim
 import returnn.frontend as rf
 
 
-__all__ = ["cross_entropy", "ctc_loss", "edit_distance"]
+__all__ = ["cross_entropy", "ctc_loss", "ctc_greedy_decode", "edit_distance"]
 
 
 def cross_entropy(
@@ -93,6 +94,44 @@ def ctc_loss(
     )
 
 
+def ctc_greedy_decode(
+    logits: Tensor,
+    *,
+    in_spatial_dim: Dim,
+    blank_index: int,
+    out_spatial_dim: Optional[Dim] = None,
+    target_dim: Optional[Dim] = None,
+    wb_target_dim: Optional[Dim] = None,
+) -> Tuple[Tensor, Dim]:
+    """
+    Greedy CTC decode.
+
+    :return: (labels, out_spatial_dim)
+    """
+    if wb_target_dim is None:
+        assert logits.feature_dim
+        wb_target_dim = logits.feature_dim
+
+    labels = rf.reduce_argmax(logits, axis=wb_target_dim)
+    labels = rf.cast(labels, "int32")
+
+    labels_shifted = rf.shift_right(labels, axis=in_spatial_dim, pad_value=blank_index)
+    mask_repeat = labels != labels_shifted
+    labels, out_spatial_dim = rf.masked_select(
+        labels,
+        mask=(labels != blank_index) & mask_repeat,
+        dims=[in_spatial_dim],
+        out_dim=out_spatial_dim,
+    )
+
+    if target_dim:
+        # Set correct sparse_dim. Only currently implemented if blank comes after.
+        assert target_dim.dimension == blank_index
+        labels.sparse_dim = target_dim
+
+    return labels, out_spatial_dim
+
+
 def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim, *, dtype: str = "int32") -> Tensor:
     """
     :param a: [B,Ta]
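
Note: the new ctc_greedy_decode performs the standard CTC post-processing: argmax per frame, drop repeated labels, then drop blanks. A minimal list-based sketch of that collapse step (plain Python; the frame labels and blank index below are made up for illustration):

# Plain-Python sketch of the collapse done by ctc_greedy_decode after the argmax:
# remove repeats first, then remove blanks.
def ctc_collapse(frame_labels, blank_index):
    out = []
    prev = blank_index  # same effect as shift_right with pad_value=blank_index
    for label in frame_labels:
        if label != prev and label != blank_index:
            out.append(label)
        prev = label
    return out

assert ctc_collapse([1, 1, 0, 2, 2, 0, 2], blank_index=0) == [1, 2, 2]
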
returnn/frontend/math_.py CHANGED
@@ -3,7 +3,6 @@ Math ops
 """
 
 from __future__ import annotations
-import typing
 from typing import Optional, Sequence, Union, Tuple, overload
 import numpy
 from returnn.tensor import Tensor, Dim
@@ -77,7 +76,7 @@
 ]
 
 
-@typing.overload
+@overload
 def compare(
     a: Tensor,
     kind: str,
@@ -86,7 +85,19 @@ def compare(
     allow_broadcast_all_sources: Optional[bool] = None,
     dim_order: Optional[Sequence[Dim]] = None,
 ) -> Tensor:
-    """compare with two tensors"""
+    """compare"""
+
+
+@overload
+def compare(
+    a: Union[Tensor, _RawTensorTypes],
+    kind: str,
+    b: Union[Tensor, _RawTensorTypes],
+    *,
+    allow_broadcast_all_sources: Optional[bool] = None,
+    dim_order: Optional[Sequence[Dim]] = None,
+) -> Tensor:
+    """compare"""
 
 
 _CompareMap = {
@@ -138,7 +149,7 @@ def compare_bc(
     return compare(a, kind, b, allow_broadcast_all_sources=True, dim_order=dim_order)
 
 
-@typing.overload
+@overload
 def combine(
     a: Tensor,
     kind: str,
@@ -147,7 +158,19 @@ def combine(
     allow_broadcast_all_sources: Optional[bool] = None,
     dim_order: Optional[Sequence[Dim]] = None,
 ) -> Tensor:
-    """combine with two tensors"""
+    """combine"""
+
+
+@overload
+def combine(
+    a: Union[Tensor, _RawTensorTypes],
+    kind: str,
+    b: Union[Tensor, _RawTensorTypes],
+    *,
+    allow_broadcast_all_sources: Optional[bool] = None,
+    dim_order: Optional[Sequence[Dim]] = None,
+) -> Union[Tensor, _RawTensorTypes]:
+    """combine"""
 
 
 _CombineMap = {
@@ -332,7 +355,12 @@ def logical_not(a: Tensor) -> Tensor:
 
 @overload
 def opt_logical_or(a: bool, b: bool) -> bool:
-    """logical or"""
+    """opt logical or"""
+
+
+@overload
+def opt_logical_or(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
+    """opt logical or"""
 
 
 def opt_logical_or(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
@@ -350,7 +378,12 @@ def opt_logical_or(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tens
 
 @overload
 def opt_logical_and(a: bool, b: bool) -> bool:
-    """logical and"""
+    """opt logical and"""
+
+
+@overload
+def opt_logical_and(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
+    """opt logical and"""
 
 
 def opt_logical_and(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
@@ -416,16 +449,23 @@ def minimum(a: Tensor, b: Union[Tensor, _RawTensorTypes], *other_tensors) -> Ten
 
 def clip_by_value(
     x: Tensor,
-    clip_value_min: Union[Tensor, _RawTensorTypes],
-    clip_value_max: Union[Tensor, _RawTensorTypes],
+    clip_value_min: Union[None, Tensor, _RawTensorTypes] = None,
+    clip_value_max: Union[None, Tensor, _RawTensorTypes] = None,
     *,
     allow_broadcast_all_sources: bool = False,
 ) -> Tensor:
     """clip by value"""
-    # noinspection PyProtectedMember
-    return x._raw_backend.clip_by_value(
-        x, clip_value_min, clip_value_max, allow_broadcast_all_sources=allow_broadcast_all_sources
-    )
+    if clip_value_min is not None and clip_value_max is not None:
+        # noinspection PyProtectedMember
+        return x._raw_backend.clip_by_value(
+            x, clip_value_min, clip_value_max, allow_broadcast_all_sources=allow_broadcast_all_sources
+        )
+    elif clip_value_min is not None and clip_value_max is None:
+        return maximum(x, clip_value_min)
+    elif clip_value_min is None and clip_value_max is not None:
+        return minimum(x, clip_value_max)
+    else:
+        return x
 
 
 def identity(x: Tensor) -> Tensor:
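
Note: both clip bounds are now optional; with only one bound given, the call degenerates to an elementwise maximum or minimum. A scalar sketch of the same dispatch (plain Python, illustrative only):

# Scalar analog of the optional-bounds behavior of rf.clip_by_value above.
def clip(x, lo=None, hi=None):
    if lo is not None and hi is not None:
        return min(max(x, lo), hi)
    if lo is not None:
        return max(x, lo)
    if hi is not None:
        return min(x, hi)
    return x

assert clip(5, lo=0, hi=3) == 3
assert clip(-2, lo=0) == 0  # only a lower bound: acts like maximum
assert clip(7, hi=4) == 4   # only an upper bound: acts like minimum
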
@@ -541,7 +581,7 @@ def floor(a: Tensor) -> Tensor:
 
 # noinspection PyShadowingBuiltins
 def round(a: Tensor) -> Tensor:
-    """round"""
+    """round. the result dtype is same as input dtype, still float"""
     # noinspection PyProtectedMember
     return a._raw_backend.activation(a, "round")
 
returnn/frontend/module.py CHANGED
@@ -274,10 +274,17 @@ class Functional(Module):
     (This is often not necessary, but sometimes useful.)
     """
 
-    def __init__(self, func):
+    def __init__(self, func, *, attribs: Optional[Dict[str, Any]] = None):
+        """
+        :param func: callable. you might want to use functools.partial if you want to fix some arguments.
+        :param attribs: optional dict of attributes to set on this module. e.g. ``out_dim``.
+        """
         super().__init__()
         assert callable(func)
         self.func = func
+        if attribs:
+            for k, v in attribs.items():
+                setattr(self, k, v)
 
     def __repr__(self):
         return f"{self.__class__.__name__}({self.func.__qualname__})"