returnn 1.20250901.123052__py3-none-any.whl → 1.20260105.192646__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. returnn/PKG-INFO +2 -2
  2. returnn/_setup_info_generated.py +2 -2
  3. returnn/config.py +1 -1
  4. returnn/datasets/basic.py +29 -13
  5. returnn/datasets/distrib_files.py +61 -3
  6. returnn/datasets/generating.py +12 -21
  7. returnn/datasets/huggingface.py +434 -0
  8. returnn/datasets/lm.py +20 -0
  9. returnn/datasets/meta.py +179 -60
  10. returnn/datasets/multi_proc.py +1 -1
  11. returnn/datasets/postprocessing.py +597 -108
  12. returnn/datasets/text_dict.py +1 -1
  13. returnn/datasets/util/vocabulary.py +90 -0
  14. returnn/frontend/_backend.py +7 -0
  15. returnn/frontend/array_.py +54 -1
  16. returnn/frontend/attention.py +54 -20
  17. returnn/frontend/conv.py +273 -54
  18. returnn/frontend/decoder/transformer.py +36 -17
  19. returnn/frontend/encoder/conformer.py +1 -0
  20. returnn/frontend/encoder/transformer.py +2 -0
  21. returnn/frontend/loss.py +40 -1
  22. returnn/frontend/module.py +8 -1
  23. returnn/frontend/nested.py +9 -0
  24. returnn/native_op.cpp +80 -0
  25. returnn/sprint/cache.py +12 -13
  26. returnn/tensor/_dim_extra.py +51 -29
  27. returnn/tensor/_tensor_extra.py +6 -1
  28. returnn/tensor/utils.py +7 -4
  29. returnn/tf/frontend_layers/_backend.py +11 -2
  30. returnn/tf/frontend_low_level/_backend.py +15 -0
  31. returnn/tf/layers/basic.py +16 -38
  32. returnn/tf/native_op.py +11 -58
  33. returnn/tf/network.py +1 -1
  34. returnn/tf/util/basic.py +19 -0
  35. returnn/torch/data/returnn_dataset_wrapper.py +9 -3
  36. returnn/torch/engine.py +67 -2
  37. returnn/torch/frontend/_backend.py +119 -7
  38. returnn/torch/util/diagnose_gpu.py +65 -31
  39. returnn/torch/util/exception_helper.py +7 -1
  40. returnn/util/basic.py +6 -7
  41. returnn/util/better_exchook.py +4 -0
  42. returnn/util/collect_outputs_dict.py +79 -0
  43. returnn/util/debug.py +11 -2
  44. returnn/util/file_cache.py +42 -4
  45. returnn/util/task_system.py +1 -1
  46. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/METADATA +2 -2
  47. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/RECORD +50 -48
  48. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/LICENSE +0 -0
  49. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/WHEEL +0 -0
  50. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/top_level.txt +0 -0
returnn/frontend/conv.py CHANGED
@@ -3,7 +3,7 @@ Convolution, transposed convolution, pooling
3
3
  """
4
4
 
5
5
  from __future__ import annotations
6
- from typing import Optional, Sequence, Tuple, Union
6
+ from typing import Optional, Union, TypeVar, Sequence, Tuple, List
7
7
  from returnn.util.basic import next_type_attrib_in_mro_chain
8
8
  from returnn.tensor import Tensor, Dim
9
9
  import returnn.frontend as rf
@@ -25,6 +25,9 @@ __all__ = [
25
25
  "pool2d",
26
26
  "pool3d",
27
27
  "make_conv_out_spatial_dims",
28
+ "calc_conv_out_length",
29
+ "make_transposed_conv_out_spatial_dims",
30
+ "calc_transposed_conv_out_length",
28
31
  ]
29
32
 
30
33
 
@@ -396,7 +399,11 @@ def transposed_conv(
396
399
  )
397
400
  if use_mask:
398
401
  source = source.copy_masked(0, dims=in_spatial_dims)
399
- if padding == "same" and _any_is_non_default(strides, default=1) and _should_use_consistent_same_padding():
402
+ if (
403
+ padding == "same"
404
+ and any(s != 1 for s in (strides or [fs.dimension for fs in filter_size]))
405
+ and _should_use_consistent_same_padding()
406
+ ):
400
407
  # I don't really know what this should mean here... Investigate this further...
401
408
  raise NotImplementedError("consistent same padding not implemented for transposed conv")
402
409
  # noinspection PyProtectedMember
@@ -424,6 +431,39 @@ class TransposedConv1d(_TransposedConv):
424
431
 
425
432
  nd = 1
426
433
 
434
+ def __init__(
435
+ self,
436
+ in_dim: Dim,
437
+ out_dim: Dim,
438
+ filter_size: Union[int, Dim],
439
+ *,
440
+ padding: str,
441
+ remove_padding: int = 0,
442
+ output_padding: Optional[int] = None,
443
+ strides: Optional[int] = None,
444
+ with_bias: bool = True,
445
+ ):
446
+ """
447
+ :param in_dim:
448
+ :param out_dim:
449
+ :param filter_size:
450
+ :param strides: specifies the upscaling. by default, same as filter_size
451
+ :param padding: "same" or "valid"
452
+ :param remove_padding:
453
+ :param output_padding:
454
+ :param with_bias: whether to add a bias. enabled by default
455
+ """
456
+ super().__init__(
457
+ in_dim=in_dim,
458
+ out_dim=out_dim,
459
+ filter_size=[filter_size],
460
+ padding=padding,
461
+ remove_padding=remove_padding,
462
+ output_padding=output_padding,
463
+ strides=[strides] if strides is not None else None,
464
+ with_bias=with_bias,
465
+ )
466
+
427
467
  __call__ = _ConvOrTransposedConv._call_nd1
428
468
 
429
469
 
@@ -704,7 +744,7 @@ def make_conv_out_spatial_dims(
704
744
  strides: Union[Sequence[int], int] = 1,
705
745
  dilation_rate: Union[Sequence[int], int] = 1,
706
746
  description_prefix: Optional[str] = None,
707
- ) -> Sequence[Dim]:
747
+ ) -> List[Dim]:
708
748
  """create out spatial dims from in spatial dims"""
709
749
  nd = len(in_spatial_dims)
710
750
  if isinstance(filter_size, (int, Dim)):
@@ -715,84 +755,263 @@ def make_conv_out_spatial_dims(
715
755
  strides = [strides] * nd
716
756
  if isinstance(dilation_rate, int):
717
757
  dilation_rate = [dilation_rate] * nd
718
- assert nd == len(in_spatial_dims) == len(filter_size) == len(strides) == len(dilation_rate)
719
758
  if isinstance(padding, (int, str)):
720
759
  padding = [padding] * nd
760
+ assert nd == len(in_spatial_dims) == len(filter_size) == len(strides) == len(dilation_rate) == len(padding)
721
761
  padding = [p.lower() if isinstance(p, str) else p for p in padding]
722
762
  out_spatial_dims = []
723
763
  for i in range(nd):
724
- in_spatial_dim = in_spatial_dims[i]
725
- if (filter_size[i] == strides[i] == 1 and padding[i] in ("valid", "same", 0)) or (
726
- strides[i] == 1 and padding[i] == "same"
727
- ):
728
- out_spatial_dims.append(in_spatial_dim)
729
- else:
730
- out_spatial_dim = _calc_out_dim(
731
- in_dim=in_spatial_dim,
764
+ out_spatial_dims.append(
765
+ calc_conv_out_length(
766
+ in_spatial_dims[i],
732
767
  filter_size=filter_size[i],
768
+ padding=padding[i],
733
769
  stride=strides[i],
734
770
  dilation_rate=dilation_rate[i],
735
- padding=padding[i],
771
+ name=f"{description_prefix}:spatial{i}" if description_prefix else None,
736
772
  )
737
- assert isinstance(out_spatial_dim, Dim)
738
- if description_prefix and out_spatial_dim != in_spatial_dim:
739
- out_spatial_dim.name = f"{description_prefix}:spatial{i}"
740
- if in_spatial_dim.dyn_size_ext is not None and out_spatial_dim.dyn_size_ext is None:
741
- out_spatial_dim.dyn_size_ext = _calc_out_dim(
742
- in_dim=in_spatial_dim.dyn_size_ext,
743
- filter_size=filter_size[i],
744
- stride=strides[i],
745
- dilation_rate=dilation_rate[i],
746
- padding=padding[i],
747
- )
748
- out_spatial_dims.append(out_spatial_dim)
773
+ )
749
774
  return out_spatial_dims
750
775
 
751
776
 
752
- def _calc_out_dim(in_dim, filter_size, stride, padding, dilation_rate=1):
777
+ T = TypeVar("T", int, Dim, Tensor)
778
+
779
+
780
+ def calc_conv_out_length(
781
+ in_length: Union[T, int, Dim, Tensor],
782
+ *,
783
+ filter_size: Union[T, int, Dim, Tensor],
784
+ stride: int,
785
+ padding: Union[str, int],
786
+ dilation_rate: int = 1,
787
+ name: Optional[str] = None,
788
+ ) -> T:
753
789
  """
754
790
  Copied and adapted from TF ConvLayer.calc_out_dim.
755
791
 
756
- :param T|int|Tensor|torch.Tensor|tensorflow.Tensor|Dim in_dim: dimension in some axis
757
- :param int filter_size: e.g. 2, for the corresponding axis
758
- :param int stride: e.g. 1, for the corresponding axis
759
- :param int dilation_rate: e.g. 1
760
- :param str|int padding: "valid" or "same" or int
792
+ :param T in_length: dimension in some axis
793
+ :param filter_size: e.g. 2, for the corresponding axis
794
+ :param stride: e.g. 1, for the corresponding axis
795
+ :param dilation_rate: e.g. 1
796
+ :param padding: "valid" or "same" or int
797
+ :param name:
761
798
  :return: the output dimension
762
- :rtype: T
763
799
  """
800
+ padding = padding.lower() if isinstance(padding, str) else padding
801
+ if isinstance(filter_size, int):
802
+ filter_size_int = filter_size
803
+ elif isinstance(filter_size, Dim):
804
+ filter_size_int = filter_size.dimension
805
+ else:
806
+ filter_size_int = None
807
+ filter_size_ = filter_size_int if isinstance(filter_size_int, int) else filter_size
764
808
 
765
- def ceildiv(a, b):
766
- """
767
- :param T|int|Tensor|torch.Tensor|tensorflow.Tensor a:
768
- :param T|int|Tensor|torch.Tensor|tensorflow.Tensor b:
769
- :rtype: T
770
- """
771
- if isinstance(b, int) and b == 1:
772
- return a
773
- if isinstance(a, Tensor):
774
- return rf.ceil_divide(a, b)
775
- return -(-a // b)
809
+ if (filter_size_int == stride == 1 and padding in ("valid", "same", 0)) or (stride == 1 and padding == "same"):
810
+ return in_length
776
811
 
777
- padding = padding.lower() if isinstance(padding, str) else padding
778
812
  # See tf.compat.v1.nn.convolution() documentation for more.
779
813
  if padding == "same":
780
- if isinstance(in_dim, Dim):
781
- return in_dim.ceildiv_right(stride)
782
- return ceildiv(in_dim, stride)
814
+ if isinstance(in_length, Dim):
815
+ out_length = in_length.ceildiv_right(stride)
816
+ else:
817
+ out_length = _ceildiv(in_length, stride)
783
818
  elif padding == "valid" or isinstance(padding, int):
784
819
  if isinstance(padding, int) and padding != 0:
785
820
  assert padding > 0
786
- in_dim = padding + in_dim + padding
787
- if isinstance(in_dim, Dim):
788
- filter_left_dilated = (filter_size - 1) * dilation_rate // 2
789
- filter_right_dilated = (filter_size - 1) * dilation_rate - filter_left_dilated
790
- valid_part = in_dim.sub_left(filter_left_dilated).sub_right(filter_right_dilated)
791
- return valid_part.ceildiv_right(stride)
792
- return ceildiv(in_dim - (filter_size - 1) * dilation_rate, stride)
821
+ in_length = padding + in_length + padding
822
+
823
+ if filter_size_int == 1:
824
+ valid_part = in_length
825
+ elif isinstance(in_length, Dim):
826
+ filter_left_dilated = (filter_size_ - 1) * dilation_rate // 2
827
+ filter_right_dilated = (filter_size_ - 1) * dilation_rate - filter_left_dilated
828
+ valid_part = in_length.sub_left(filter_left_dilated).sub_right(filter_right_dilated)
829
+ else:
830
+ valid_part = in_length - (filter_size_ - 1) * dilation_rate
831
+
832
+ if isinstance(valid_part, Dim):
833
+ out_length = valid_part.ceildiv_right(stride)
834
+ else:
835
+ out_length = _ceildiv(valid_part, stride)
836
+
793
837
  else:
794
838
  raise ValueError(f"invalid padding {padding!r} (type {type(padding).__name__})")
795
839
 
840
+ if isinstance(in_length, Dim):
841
+ assert isinstance(out_length, Dim)
842
+ if name and out_length != in_length:
843
+ out_length.name = name
844
+ if in_length.dyn_size_ext is not None and out_length.dyn_size_ext is None:
845
+ out_dyn_size_ext = calc_conv_out_length(
846
+ in_length=in_length.dyn_size_ext,
847
+ filter_size=filter_size,
848
+ stride=stride,
849
+ dilation_rate=dilation_rate,
850
+ padding=padding,
851
+ )
852
+ assert isinstance(out_dyn_size_ext, Tensor)
853
+ out_length.dyn_size_ext = out_dyn_size_ext
854
+
855
+ return out_length
856
+
857
+
858
+ def make_transposed_conv_out_spatial_dims(
859
+ in_spatial_dims: Sequence[Dim],
860
+ *,
861
+ filter_size: Union[Sequence[Union[int, Dim]], int, Dim],
862
+ padding: Union[str, int, Sequence[int]],
863
+ output_padding: Optional[Union[Sequence[Optional[int]], int]] = None,
864
+ strides: Union[Sequence[Optional[int]], None, int] = None,
865
+ dilation_rate: Union[Sequence[int], int] = 1,
866
+ description_prefix: Optional[str] = None,
867
+ ) -> List[Dim]:
868
+ """create out spatial dims from in spatial dims"""
869
+ nd = len(in_spatial_dims)
870
+ if isinstance(filter_size, (int, Dim)):
871
+ filter_size = [filter_size] * nd
872
+ filter_size = [d.dimension if isinstance(d, Dim) else d for d in filter_size]
873
+ assert all(isinstance(s, int) for s in filter_size)
874
+ if isinstance(strides, int) or strides is None:
875
+ strides = [strides] * nd
876
+ if isinstance(dilation_rate, int):
877
+ dilation_rate = [dilation_rate] * nd
878
+ if isinstance(padding, (int, str)):
879
+ padding = [padding] * nd
880
+ if isinstance(output_padding, int) or output_padding is None:
881
+ output_padding = [output_padding] * nd
882
+ assert (
883
+ nd
884
+ == len(in_spatial_dims)
885
+ == len(filter_size)
886
+ == len(strides)
887
+ == len(dilation_rate)
888
+ == len(padding)
889
+ == len(output_padding)
890
+ )
891
+ padding = [p.lower() if isinstance(p, str) else p for p in padding]
892
+ out_spatial_dims = []
893
+ for i in range(nd):
894
+ out_spatial_dims.append(
895
+ calc_transposed_conv_out_length(
896
+ in_spatial_dims[i],
897
+ filter_size=filter_size[i],
898
+ padding=padding[i],
899
+ stride=strides[i],
900
+ dilation_rate=dilation_rate[i],
901
+ name=f"{description_prefix}:spatial{i}" if description_prefix else None,
902
+ )
903
+ )
904
+ return out_spatial_dims
905
+
906
+
907
+ def calc_transposed_conv_out_length(
908
+ in_length: Union[T, int, Dim, Tensor],
909
+ *,
910
+ filter_size: Union[int, Dim],
911
+ padding: Union[int, str],
912
+ output_padding: Optional[int] = None,
913
+ stride: Optional[int] = None,
914
+ dilation_rate: int = 1,
915
+ name: Optional[str] = None,
916
+ ) -> T:
917
+ """
918
+ Determines output length of a transposed convolution given input length.
919
+
920
+ Copied from TF/Keras conv_utils.deconv_output_length
921
+ (https://github.com/tensorflow/tensorflow/blob/5912f51d580551e5cee2cfde4cb882594b4d3e60/tensorflow/python/keras/utils/conv_utils.py#L140),
922
+ adapted with simplification.
923
+
924
+ Also see :func:`calc_conv_out_length`.
925
+
926
+ :param in_length:
927
+ :param filter_size:
928
+ :param padding: one of `"same"`, `"valid"`, `"full"`.
929
+ :param output_padding: amount of padding along the output dimension.
930
+ Can be set to `None` in which case the output length is inferred.
931
+ :param stride:
932
+ :param dilation_rate:
933
+ :param name:
934
+ :returns: The output length (integer)
935
+ """
936
+ assert padding in {"same", "valid", "full"} or isinstance(padding, int)
937
+
938
+ if isinstance(filter_size, int):
939
+ filter_size_int = filter_size
940
+ elif isinstance(filter_size, Dim):
941
+ filter_size_int = filter_size.dimension
942
+ else:
943
+ filter_size_int = None
944
+ filter_size_ = filter_size_int if isinstance(filter_size_int, int) else filter_size
945
+
946
+ # Get the dilated kernel size
947
+ if dilation_rate != 1 and filter_size_int != 1:
948
+ filter_size = filter_size + (filter_size_ - 1) * (dilation_rate - 1)
949
+
950
+ if stride is None:
951
+ assert filter_size_int is not None
952
+ stride = filter_size_int
953
+ if stride != 1:
954
+ in_length = in_length * stride
955
+
956
+ # Infer length if output padding is None, else compute the exact length
957
+ if output_padding is None:
958
+ if padding == "valid" or padding == 0:
959
+ if filter_size_int == stride:
960
+ out_length = in_length
961
+ elif filter_size_int is not None:
962
+ out_length = in_length + max(filter_size_int - stride, 0)
963
+ elif isinstance(filter_size, Tensor):
964
+ out_length = in_length + rf.relu(filter_size - stride)
965
+ elif isinstance(filter_size, Dim):
966
+ out_length = in_length + (filter_size - stride)
967
+ else:
968
+ raise ValueError(f"invalid filter_size {filter_size!r} type {type(filter_size)}")
969
+ elif padding == "full":
970
+ out_length = in_length - (stride + filter_size_ - 2)
971
+ elif padding == "same":
972
+ out_length = in_length
973
+ else:
974
+ raise ValueError(f"invalid padding {padding!r}")
975
+
976
+ else: # output_padding
977
+ if padding == "same":
978
+ pad = filter_size // 2
979
+ elif padding == "valid":
980
+ pad = 0
981
+ elif padding == "full":
982
+ pad = filter_size - 1
983
+ elif isinstance(padding, int):
984
+ pad = padding
985
+ else:
986
+ raise ValueError(f"invalid padding {padding!r}")
987
+ out_length = in_length + (filter_size - stride - 2 * pad + output_padding)
988
+
989
+ if isinstance(in_length, Dim):
990
+ assert isinstance(out_length, Dim)
991
+ if name and out_length != in_length:
992
+ out_length.name = name
993
+ if in_length.dyn_size_ext is not None and out_length.dyn_size_ext is None:
994
+ out_dyn_size_ext = calc_transposed_conv_out_length(
995
+ in_length=in_length.dyn_size_ext,
996
+ filter_size=filter_size,
997
+ padding=padding,
998
+ output_padding=output_padding,
999
+ stride=stride,
1000
+ dilation_rate=dilation_rate,
1001
+ )
1002
+ assert isinstance(out_dyn_size_ext, Tensor)
1003
+ out_length.dyn_size_ext = out_dyn_size_ext
1004
+
1005
+ return out_length
1006
+
1007
+
1008
+ def _ceildiv(a: T, b: Union[T, int, Tensor]) -> T:
1009
+ if isinstance(b, int) and b == 1:
1010
+ return a
1011
+ if isinstance(a, Tensor):
1012
+ return rf.ceil_divide(a, b)
1013
+ return -(-a // b)
1014
+
796
1015
 
797
1016
  def _should_use_consistent_same_padding() -> bool:
798
1017
  """
returnn/frontend/decoder/transformer.py CHANGED
@@ -49,6 +49,7 @@ class TransformerDecoder(rf.Module):
49
49
  layer_opts: Optional[Dict[str, Any]] = None,
50
50
  embed_dim: Optional[Dim] = None,
51
51
  share_embedding: bool = None,
52
+ input_embedding: bool = True,
52
53
  input_embedding_scale: float = None,
53
54
  input_dropout: float = None,
54
55
  logits_with_bias: bool = False,
@@ -72,6 +73,7 @@ class TransformerDecoder(rf.Module):
72
73
  :param layer_opts: options for the decoder layer
73
74
  :param embed_dim: if given, will first have an embedding [vocab,embed] and then a linear [embed,model].
74
75
  :param share_embedding:
76
+ :param input_embedding: whether to use input embedding. If False, you must provide input of dimension model_dim.
75
77
  :param input_embedding_scale:
76
78
  :param input_dropout:
77
79
  :param logits_with_bias:
@@ -103,7 +105,7 @@ class TransformerDecoder(rf.Module):
103
105
 
104
106
  # We could make this optional or configurable if we ever need to.
105
107
  # Or maybe you would just have another separate implementation of this module then...
106
- self.input_embedding = rf.Embedding(vocab_dim, embed_dim or model_dim)
108
+ self.input_embedding = rf.Embedding(vocab_dim, embed_dim or model_dim) if input_embedding else None
107
109
 
108
110
  self.input_embedding_proj = None
109
111
  if embed_dim:
@@ -121,21 +123,31 @@ class TransformerDecoder(rf.Module):
121
123
  raise TypeError(f"unexpected pos_enc type {pos_enc!r}")
122
124
  self.pos_enc = pos_enc
123
125
  if share_embedding is None:
124
- if BehaviorVersion.get() < 20:
125
- logging.getLogger("returnn.frontend").warning(
126
- "TransformerDecoder share_embedding default is False"
127
- f" with your behavior version {BehaviorVersion.get()}."
128
- " Explicitly set share_embedding or switch to a new behavior version >= 20."
129
- )
130
- share_embedding = True if BehaviorVersion.get() >= 20 else False
126
+ if embed_dim and embed_dim != model_dim:
127
+ share_embedding = False
128
+ elif input_embedding:
129
+ if BehaviorVersion.get() < 20:
130
+ logging.getLogger("returnn.frontend").warning(
131
+ "TransformerDecoder share_embedding default is False"
132
+ f" with your behavior version {BehaviorVersion.get()}."
133
+ " Explicitly set share_embedding or switch to a new behavior version >= 20."
134
+ )
135
+ share_embedding = True if BehaviorVersion.get() >= 20 else False
136
+ else: # not input_embedding
137
+ share_embedding = False
131
138
  if input_embedding_scale is None:
132
- if BehaviorVersion.get() < 20:
133
- logging.getLogger("returnn.frontend").warning(
134
- "TransformerDecoder input_embedding_scale default is suboptimal"
135
- f" with your behavior version {BehaviorVersion.get()}."
136
- " Explicitly set input_embedding_scale or switch to a new behavior version >= 20."
137
- )
138
- input_embedding_scale = model_dim.dimension**0.5 if BehaviorVersion.get() >= 20 else 1.0
139
+ if input_embedding:
140
+ if BehaviorVersion.get() < 20:
141
+ logging.getLogger("returnn.frontend").warning(
142
+ "TransformerDecoder input_embedding_scale default is suboptimal"
143
+ f" with your behavior version {BehaviorVersion.get()}."
144
+ " Explicitly set input_embedding_scale or switch to a new behavior version >= 20."
145
+ )
146
+ input_embedding_scale = model_dim.dimension**0.5 if BehaviorVersion.get() >= 20 else 1.0
147
+ elif pos_enc:
148
+ input_embedding_scale = model_dim.dimension**0.5
149
+ else:
150
+ input_embedding_scale = 1.0
139
151
  self.input_embedding_scale = input_embedding_scale
140
152
  if input_dropout is None:
141
153
  if dropout > 0 and BehaviorVersion.get() < 20:
@@ -179,7 +191,9 @@ class TransformerDecoder(rf.Module):
179
191
  self.logits = rf.Linear(model_dim, vocab_dim, with_bias=logits_with_bias)
180
192
 
181
193
  if share_embedding:
182
- assert not embed_dim and not logits_with_bias, "not supported together with share_embedding"
194
+ assert input_embedding, "input_embedding=True required for share_embedding"
195
+ assert not embed_dim or embed_dim == model_dim, f"{embed_dim=} not supported with share_embedding"
196
+ assert not logits_with_bias, "logits_with_bias=True expected with share_embedding"
183
197
  self.logits.weight = self.input_embedding.weight
184
198
 
185
199
  def default_initial_state(self, *, batch_dims: Sequence[Dim]) -> rf.State:
@@ -219,7 +233,12 @@ class TransformerDecoder(rf.Module):
219
233
  """
220
234
  new_state = rf.State()
221
235
 
222
- decoded = self.input_embedding(source) * self.input_embedding_scale
236
+ if self.input_embedding is not None:
237
+ decoded = self.input_embedding(source)
238
+ else:
239
+ decoded = source
240
+ if self.input_embedding_scale != 1:
241
+ decoded = decoded * self.input_embedding_scale
223
242
  if self.pos_enc is not None:
224
243
  decoded = decoded + self.pos_enc(spatial_dim=spatial_dim, offset=state.pos)
225
244
  decoded = rf.dropout(decoded, self.input_dropout)
returnn/frontend/encoder/conformer.py CHANGED
@@ -273,6 +273,7 @@ class ConformerEncoderLayer(rf.Module):
273
273
  x_mhsa = self.self_att(x_mhsa_ln, axis=spatial_dim)
274
274
  x_mhsa = rf.dropout(x_mhsa, self.dropout, axis=self.dropout_broadcast and self.out_dim)
275
275
  x_mhsa_out = x_mhsa + x_ffn1_out
276
+ del x_mhsa
276
277
 
277
278
  # Conv
278
279
  x_conv_ln = self.conv_layer_norm(x_mhsa_out)
returnn/frontend/encoder/transformer.py CHANGED
@@ -79,6 +79,8 @@ class TransformerEncoder(rf.Module):
79
79
  self.model_dim = model_dim
80
80
  self.embed_dim = embed_dim
81
81
 
82
+ self.out_dim = self.model_dim # alias. consistency, compatibility
83
+
82
84
  if input_embedding is None or isinstance(input_embedding, rf.Module):
83
85
  pass
84
86
  elif isinstance(input_embedding, type):
returnn/frontend/loss.py CHANGED
@@ -3,11 +3,12 @@ Loss functions
3
3
  """
4
4
 
5
5
  from __future__ import annotations
6
+ from typing import Optional, Tuple
6
7
  from returnn.tensor import Tensor, Dim
7
8
  import returnn.frontend as rf
8
9
 
9
10
 
10
- __all__ = ["cross_entropy", "ctc_loss", "edit_distance"]
11
+ __all__ = ["cross_entropy", "ctc_loss", "ctc_greedy_decode", "edit_distance"]
11
12
 
12
13
 
13
14
  def cross_entropy(
@@ -93,6 +94,44 @@ def ctc_loss(
93
94
  )
94
95
 
95
96
 
97
+ def ctc_greedy_decode(
98
+ logits: Tensor,
99
+ *,
100
+ in_spatial_dim: Dim,
101
+ blank_index: int,
102
+ out_spatial_dim: Optional[Dim] = None,
103
+ target_dim: Optional[Dim] = None,
104
+ wb_target_dim: Optional[Dim] = None,
105
+ ) -> Tuple[Tensor, Dim]:
106
+ """
107
+ Greedy CTC decode.
108
+
109
+ :return: (labels, out_spatial_dim)
110
+ """
111
+ if wb_target_dim is None:
112
+ assert logits.feature_dim
113
+ wb_target_dim = logits.feature_dim
114
+
115
+ labels = rf.reduce_argmax(logits, axis=wb_target_dim)
116
+ labels = rf.cast(labels, "int32")
117
+
118
+ labels_shifted = rf.shift_right(labels, axis=in_spatial_dim, pad_value=blank_index)
119
+ mask_repeat = labels != labels_shifted
120
+ labels, out_spatial_dim = rf.masked_select(
121
+ labels,
122
+ mask=(labels != blank_index) & mask_repeat,
123
+ dims=[in_spatial_dim],
124
+ out_dim=out_spatial_dim,
125
+ )
126
+
127
+ if target_dim:
128
+ # Set correct sparse_dim. Only currently implemented if blank comes after.
129
+ assert target_dim.dimension == blank_index
130
+ labels.sparse_dim = target_dim
131
+
132
+ return labels, out_spatial_dim
133
+
134
+
96
135
  def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim, *, dtype: str = "int32") -> Tensor:
97
136
  """
98
137
  :param a: [B,Ta]
returnn/frontend/module.py CHANGED
@@ -274,10 +274,17 @@ class Functional(Module):
274
274
  (This is often not necessary, but sometimes useful.)
275
275
  """
276
276
 
277
- def __init__(self, func):
277
+ def __init__(self, func, *, attribs: Optional[Dict[str, Any]] = None):
278
+ """
279
+ :param func: callable. you might want to use functools.partial if you want to fix some arguments.
280
+ :param attribs: optional dict of attributes to set on this module. e.g. ``out_dim``.
281
+ """
278
282
  super().__init__()
279
283
  assert callable(func)
280
284
  self.func = func
285
+ if attribs:
286
+ for k, v in attribs.items():
287
+ setattr(self, k, v)
281
288
 
282
289
  def __repr__(self):
283
290
  return f"{self.__class__.__name__}({self.func.__qualname__})"
returnn/frontend/nested.py CHANGED
@@ -275,6 +275,8 @@ def _masked_select(
275
275
  return s
276
276
  assert s in dim_map
277
277
  return dim_map[s]
278
+ if s is None:
279
+ return None
278
280
  raise TypeError(f"_masked_select: unexpected type ({type(s)})")
279
281
 
280
282
 
@@ -346,6 +348,7 @@ def _masked_scatter_merge_dims(
346
348
  merged_dim_map: Dict[Dim, Dim],
347
349
  ) -> T:
348
350
  if isinstance(s, Dim):
351
+ assert isinstance(backup, Dim)
349
352
  # This is slightly more complex than in the _masked_select case:
350
353
  # We need to merge the s and backup depending on the mask.
351
354
  if s in reverse_dim_map:
@@ -353,7 +356,10 @@ def _masked_scatter_merge_dims(
353
356
  if s == backup:
354
357
  return s
355
358
  if s in merged_dim_map:
359
+ # If this assert fails, see e.g. https://github.com/rwth-i6/returnn/pull/1759 for an example.
360
+ assert backup in merged_dim_map, f"nested masked_scatter: mismatch of s {s} vs backup {backup}"
356
361
  return merged_dim_map[s]
362
+ assert backup not in merged_dim_map, f"nested masked_scatter: mismatch of s {s} vs backup {backup}"
357
363
  # Note: s/backup might even be static dims.
358
364
  new_size = _masked_scatter(
359
365
  s.get_size_tensor(),
@@ -416,6 +422,9 @@ def _masked_scatter(
416
422
  if s in merged_dim_map:
417
423
  return merged_dim_map[s]
418
424
  return s
425
+ if s is None:
426
+ assert backup is None
427
+ return None
419
428
  raise TypeError(f"_masked_scatter: unexpected type ({type(s)})")
420
429
 
421
430