returnn 1.20251027.232712-py3-none-any.whl → 1.20260119.15400-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. returnn/PKG-INFO +2 -2
  2. returnn/__old_mod_loader__.py +26 -2
  3. returnn/_setup_info_generated.py +2 -2
  4. returnn/datasets/lm.py +130 -42
  5. returnn/datasets/meta.py +93 -43
  6. returnn/datasets/postprocessing.py +597 -108
  7. returnn/datasets/util/vocabulary.py +90 -0
  8. returnn/frontend/__init__.py +1 -0
  9. returnn/frontend/_backend.py +41 -0
  10. returnn/frontend/_native/__init__.py +22 -0
  11. returnn/frontend/_numpy_backend.py +7 -0
  12. returnn/frontend/_utils.py +1 -1
  13. returnn/frontend/array_.py +48 -2
  14. returnn/frontend/assert_.py +35 -0
  15. returnn/frontend/attention.py +54 -20
  16. returnn/frontend/conv.py +273 -54
  17. returnn/frontend/device.py +14 -1
  18. returnn/frontend/encoder/conformer.py +20 -0
  19. returnn/frontend/encoder/transformer.py +2 -0
  20. returnn/frontend/loss.py +222 -3
  21. returnn/frontend/math_.py +54 -14
  22. returnn/native_op.cpp +182 -172
  23. returnn/native_op.py +36 -31
  24. returnn/sprint/cache.py +12 -13
  25. returnn/tensor/_dim_extra.py +7 -7
  26. returnn/tensor/_tensor_extra.py +10 -10
  27. returnn/tensor/utils.py +8 -5
  28. returnn/tf/frontend_layers/_backend.py +7 -3
  29. returnn/tf/layers/basic.py +27 -40
  30. returnn/tf/native_op.py +27 -63
  31. returnn/tf/network.py +1 -1
  32. returnn/tf/util/basic.py +22 -197
  33. returnn/torch/engine.py +157 -6
  34. returnn/torch/frontend/_backend.py +280 -29
  35. returnn/torch/frontend/bridge.py +61 -0
  36. returnn/torch/frontend/compile_helper.py +106 -0
  37. returnn/torch/util/array_.py +30 -0
  38. returnn/torch/util/assert_.py +122 -0
  39. returnn/torch/util/exception_helper.py +7 -1
  40. returnn/torch/util/native_op.py +885 -0
  41. returnn/torch/util/native_op_code_compiler.py +308 -0
  42. returnn/util/basic.py +6 -7
  43. returnn/util/better_exchook.py +4 -0
  44. returnn/util/cuda_env.py +332 -0
  45. returnn/util/debug.py +12 -2
  46. returnn/util/file_cache.py +15 -1
  47. returnn/util/fsa.py +17 -13
  48. returnn/util/native_code_compiler.py +104 -47
  49. returnn/util/task_system.py +1 -1
  50. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +2 -2
  51. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +54 -48
  52. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
  53. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
  54. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
@@ -23,6 +23,8 @@ from returnn.frontend import _random_journal
 from returnn.frontend import _utils
 
 from . import raw_ops
+from ..util import native_op
+from ..util.assert_ import assert_
 
 _TT = Tensor[torch.Tensor]
 
@@ -44,6 +46,12 @@ class TorchBackend(Backend[torch.Tensor]):
         """
         return True
 
+    @staticmethod
+    def assert_(condition: Tensor, message: str):
+        """assert"""
+        assert condition.dims == (), "condition for assert must be a scalar"
+        assert_(condition.raw_tensor, message)
+
     @staticmethod
     def set_random_seed(seed: int):
         """
@@ -275,7 +283,7 @@ class TorchBackend(Backend[torch.Tensor]):
         :return: tensor
         """
         assert len(dims) >= 2
-        first_axis = min(source.dims.index(d) for d in dims)
+        first_axis = min([source.dims.index(d) for d in dims])
         pre_dims = source.dims[:first_axis]
         post_dims = [d for d in source.dims if d not in dims and d not in pre_dims]
         source = source.copy_transpose(tuple(pre_dims) + tuple(dims) + tuple(post_dims), allow_int=False)
@@ -666,10 +674,10 @@ class TorchBackend(Backend[torch.Tensor]):
         targets_spatial_dim: Dim,
         blank_index: int,
         max_approx: bool = False,
+        use_native_op: Optional[bool] = None,
+        label_loop: bool = True,
     ) -> Tensor:
         """CTC"""
-        if max_approx:
-            raise NotImplementedError("ctc_loss: max_approx not implemented for PyTorch")
         assert targets.sparse_dim and targets.sparse_dim.dimension <= logits.feature_dim.dimension
         # PyTorch expects the logits to be of shape (T, B, C) where T is the input spatial dim.
         batch_dims = logits.remaining_dims((input_spatial_dim, logits.feature_dim))
@@ -707,18 +715,42 @@ class TorchBackend(Backend[torch.Tensor]):
         if len(batch_dims) != 1:
             targets_raw = torch.reshape(targets_raw, (batch_n_elems, targets_raw.shape[-1]))  # [B', S]
             targets_lengths = torch.reshape(targets_lengths, (batch_n_elems,))  # [B']
-        if log_probs.dtype == torch.bfloat16:
-            # Currently (PyTorch 2.5), ctc_loss does not support bfloat16.
-            log_probs = log_probs.to(torch.float32)
-        loss_raw = torch.nn.functional.ctc_loss(
-            log_probs=log_probs,
-            targets=targets_raw,
-            input_lengths=input_lengths,
-            target_lengths=targets_lengths,
-            blank=blank_index,
-            zero_infinity=True,
-            reduction="none",
-        )
+        if use_native_op is None:
+            if max_approx or not label_loop:
+                use_native_op = True
+            else:
+                # This was the current default.
+                # We might change the default in the future, maybe via new behavior version.
+                use_native_op = False
+        if use_native_op:
+            loss_raw = native_op.ctc_loss(
+                logits=log_probs,
+                logits_normalize=True,
+                logits_seq_lens=input_lengths,
+                logits_time_major=True,
+                targets=targets_raw,
+                targets_seq_lens=targets_lengths,
+                blank_index=blank_index,
+                max_approx=max_approx,
+                label_loop=label_loop,
+            )
+        else:  # not native_op
+            if max_approx:
+                raise NotImplementedError("ctc_loss: max_approx not implemented for PyTorch")
+            if not label_loop:
+                raise NotImplementedError("ctc_loss: label_loop=False not implemented for PyTorch")
+            if log_probs.dtype == torch.bfloat16:
+                # Currently (PyTorch 2.5), ctc_loss does not support bfloat16.
+                log_probs = log_probs.to(torch.float32)
+            loss_raw = torch.nn.functional.ctc_loss(
+                log_probs=log_probs,
+                targets=targets_raw,
+                input_lengths=input_lengths,
+                target_lengths=targets_lengths,
+                blank=blank_index,
+                zero_infinity=True,
+                reduction="none",
+            )
         if len(batch_dims) != 1:
             loss_raw = torch.reshape(loss_raw, logits_raw_shape[1:-1])
         loss = Tensor(
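
A minimal call sketch for the new options, assuming rf.ctc_loss (returnn/frontend/loss.py, also changed in this release) forwards use_native_op and label_loop to this backend method:

    import returnn.frontend as rf
    from returnn.tensor import Tensor, Dim

    def ctc(logits: Tensor, targets: Tensor, time_dim: Dim, targets_time_dim: Dim, blank_index: int) -> Tensor:
        return rf.ctc_loss(
            logits=logits,                       # [batch, time, classes]
            targets=targets,                     # [batch, targets-time], sparse over classes
            input_spatial_dim=time_dim,
            targets_spatial_dim=targets_time_dim,
            blank_index=blank_index,
            use_native_op=True,                  # assumed to select the native CTC kernel above
            label_loop=True,                     # standard CTC topology; False would require the native op
        )
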
@@ -729,6 +761,103 @@ class TorchBackend(Backend[torch.Tensor]):
         )
         return loss
 
+    @staticmethod
+    def ctc_best_path(
+        *,
+        logits: Tensor,
+        logits_normalized: bool = False,
+        targets: Tensor,
+        input_spatial_dim: Dim,
+        targets_spatial_dim: Dim,
+        blank_index: int,
+        label_loop: bool = True,
+    ) -> Tensor:
+        """CTC best path"""
+        assert targets.sparse_dim and targets.sparse_dim.dimension <= logits.feature_dim.dimension
+        # PyTorch expects the logits to be of shape (T, B, C) where T is the input spatial dim.
+        batch_dims = logits.remaining_dims((input_spatial_dim, logits.feature_dim))
+        batch_dims_targets = targets.remaining_dims(targets_spatial_dim)
+        if set(batch_dims) != set(batch_dims_targets):
+            # Need to broadcast.
+            logits = rf.expand_dims(logits, [d for d in batch_dims_targets if d not in batch_dims])
+            targets = rf.expand_dims(targets, [d for d in batch_dims if d not in batch_dims_targets])
+            batch_dims = logits.remaining_dims((input_spatial_dim, logits.feature_dim))
+        batch_shape = [d.get_dim_value() for d in batch_dims]
+        batch_n_elems = prod(batch_shape)
+        logits = logits.copy_transpose([input_spatial_dim] + batch_dims + [logits.feature_dim])
+        logits_raw: torch.Tensor = logits.raw_tensor
+        input_lengths: torch.Tensor = input_spatial_dim.dyn_size_ext.copy_compatible_to_dims_raw(batch_dims)
+        if input_lengths.numel() != batch_n_elems:
+            input_lengths = input_lengths.expand(batch_shape)
+        if len(batch_dims) != 1:
+            logits_raw = torch.reshape(
+                logits_raw, logits_raw.shape[:1] + (batch_n_elems,) + logits_raw.shape[-1:]
+            )  # [T, B', C]
+            input_lengths = torch.reshape(input_lengths, (batch_n_elems,))  # [B']
+        if logits_normalized:
+            log_probs = logits_raw
+        else:
+            log_probs = torch.nn.functional.log_softmax(logits_raw, dim=-1)
+        # PyTorch expects the targets to be of shape (B, S) where S is the targets spatial dim.
+        targets_raw = targets.copy_compatible_to_dims_raw(batch_dims + [targets_spatial_dim])  # [B..., S]
+        targets_raw_shape = batch_shape + [targets_spatial_dim.get_dim_value()]
+        if targets_raw.numel() != prod(targets_raw_shape):
+            targets_raw = targets_raw.expand(targets_raw_shape)
+        targets_lengths = targets_spatial_dim.dyn_size_ext.copy_compatible_to_dims_raw(batch_dims)
+        if targets_lengths.numel() != batch_n_elems:
+            targets_lengths = targets_lengths.expand(batch_shape)
+        if len(batch_dims) != 1:
+            targets_raw = torch.reshape(targets_raw, (batch_n_elems, targets_raw.shape[-1]))  # [B', S]
+            targets_lengths = torch.reshape(targets_lengths, (batch_n_elems,))  # [B']
+        alignment_raw = native_op.ctc_best_path(
+            logits=log_probs,
+            logits_normalize=True,
+            logits_seq_lens=input_lengths,
+            logits_time_major=True,
+            targets=targets_raw,
+            targets_seq_lens=targets_lengths,
+            blank_index=blank_index,
+            label_loop=label_loop,
+        )  # (time,batch)
+        if len(batch_dims) != 1:
+            alignment_raw = torch.reshape(alignment_raw, log_probs.shape[:-1])
+        alignment = Tensor(
+            name="ctc_best_path",
+            dims=[input_spatial_dim] + batch_dims,
+            sparse_dim=logits.feature_dim,
+            raw_tensor=alignment_raw,
+            dtype=TorchBackend.get_dtype_name_raw(alignment_raw),
+        )
+        return alignment
+
+    @staticmethod
+    def have_edit_distance() -> bool:
+        """whether edit distance is available"""
+        return True
+
+    @staticmethod
+    def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim) -> Tensor:
+        """edit distance"""
+        a_batch_dims = a.remaining_dims(a_spatial_dim)
+        b_batch_dims = b.remaining_dims(b_spatial_dim)
+        assert set(a_batch_dims) == set(b_batch_dims), "edit_distance: batch dims must match"
+        a_raw = a.copy_compatible_to_dims_raw(a_batch_dims + [a_spatial_dim])
+        b_raw = b.copy_compatible_to_dims_raw(a_batch_dims + [b_spatial_dim])
+        a_seq_len = a_spatial_dim.dyn_size_ext.copy_compatible_to_dims_raw(a_batch_dims)
+        b_seq_len = b_spatial_dim.dyn_size_ext.copy_compatible_to_dims_raw(a_batch_dims)
+        batch_shape = None
+        if len(a_batch_dims) != 1:
+            batch_shape = [d.get_dim_value() for d in a_batch_dims]
+            batch_n_elems = prod(batch_shape)
+            a_raw = torch.reshape(a_raw.raw_tensor, (batch_n_elems, a_spatial_dim.get_dim_value()))
+            b_raw = torch.reshape(b_raw.raw_tensor, (batch_n_elems, b_spatial_dim.get_dim_value()))
+            a_seq_len = torch.reshape(a_seq_len.raw_tensor, (batch_n_elems,))
+            b_seq_len = torch.reshape(b_seq_len.raw_tensor, (batch_n_elems,))
+        dist_raw = native_op.edit_distance(a_raw, a_seq_len, b_raw, b_seq_len)
+        if len(a_batch_dims) != 1:
+            dist_raw = torch.reshape(dist_raw, batch_shape)
+        return rf.convert_to_tensor(dist_raw, name="edit_distance", dims=a_batch_dims)
+
     @staticmethod
     def create_parameter_raw(tensor: rf.Parameter, *, device: Optional[str] = None) -> torch.nn.Parameter:
         """
@@ -884,7 +1013,7 @@ class TorchBackend(Backend[torch.Tensor]):
         :param perm: e.g. [0, 2, 1]
         :return: permuted (transposed) raw tensor; wraps torch.permute
         """
-        if all(p == i for i, p in enumerate(perm)):
+        if all([p == i for i, p in enumerate(perm)]):
            return raw_tensor
         return torch.permute(raw_tensor, tuple(perm))
 
@@ -1166,20 +1295,29 @@ class TorchBackend(Backend[torch.Tensor]):
         if start is None:
             start = 0
         if isinstance(size, Dim):
+            assert end is None
             size = size.get_dim_value()
         elif isinstance(size, Tensor):
+            assert end is None
             assert size.dims == ()  # scalar
             size = size.raw_tensor
-        if size is not None:
-            assert end is None
-            out.raw_tensor = torch.narrow(source.raw_tensor, dim=axis_int, start=start, length=size)
-        else:
+        elif isinstance(size, int):
+            pass
+        elif size is None:
             if isinstance(end, Tensor):
                 assert end.dims == ()
                 end = end.raw_tensor
-            if end is None:
+            elif isinstance(end, int):
+                if end < 0:
+                    end += axis.get_dim_value()
+            elif end is None:
                 end = axis.get_dim_value()
-            out.raw_tensor = torch.narrow(source.raw_tensor, dim=axis_int, start=start, length=end - start)
+            else:
+                raise TypeError(f"slice: unsupported type for end: {type(end)}")
+            size = end - start
+        else:
+            raise TypeError(f"slice: unsupported type for size: {type(size)}")
+        out.raw_tensor = torch.narrow(source.raw_tensor, dim=axis_int, start=start, length=size)
         return out
 
     @staticmethod
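
The rewritten branch funnels every case into a single torch.narrow call; the same size/end normalization in isolation on a plain tensor (the helper name is illustrative only):

    from typing import Optional
    import torch

    def narrow_like_rf_slice(x: torch.Tensor, *, dim: int, start: int = 0,
                             size: Optional[int] = None, end: Optional[int] = None) -> torch.Tensor:
        # size takes precedence; otherwise end is resolved (negative end counts from the axis length)
        if size is None:
            if end is None:
                end = x.shape[dim]
            elif end < 0:
                end += x.shape[dim]
            size = end - start
        return torch.narrow(x, dim=dim, start=start, length=size)

    x = torch.arange(10)
    assert narrow_like_rf_slice(x, dim=0, start=2, size=3).tolist() == [2, 3, 4]
    assert narrow_like_rf_slice(x, dim=0, start=0, end=-1).shape[0] == 9
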
@@ -1352,12 +1490,24 @@ class TorchBackend(Backend[torch.Tensor]):
         a_dims = a.dims
         b_dims = b.dims
 
-        assert all(dim in a_dims for dim in reduce), (
-            f"'a' does not have the specified reduce dim(s) {reduce} (a dims: {a_dims})"
-        )
-        assert all(dim in b_dims for dim in reduce), (
-            f"'b' does not have the specified reduce dim(s) {reduce} (b dims: {b_dims})"
-        )
+        if not all(dim in a_dims for dim in reduce) or not all(dim in b_dims for dim in reduce):
+            # revert to the generic einsum implementation
+            assert all(dim in a_dims + b_dims for dim in reduce), "Some reduce Dims not in a or b."
+            result_dims = [dim for dim in a_dims if dim not in reduce] + [
+                dim for dim in b_dims if dim not in reduce and dim not in a_dims
+            ]
+            map_to_letter = {}
+            for dim in a_dims + b_dims:
+                if dim not in map_to_letter:
+                    map_to_letter[dim] = chr(97 + len(map_to_letter))  # 'a', 'b', 'c', ...
+            a_subscript = "".join(map_to_letter[dim] for dim in a_dims)
+            b_subscript = "".join(map_to_letter[dim] for dim in b_dims)
+            out_subscript = "".join(map_to_letter[dim] for dim in result_dims)
+            raw_result = torch.einsum(f"{a_subscript},{b_subscript}->{out_subscript}", a.raw_tensor, b.raw_tensor)
+            result_tensor = Tensor(
+                "einsum", dims=result_dims, raw_tensor=raw_result, dtype=TorchBackend.get_dtype_name_raw(raw_result)
+            )
+            return result_tensor
 
         if len(reduce) > 1:
             reduce = list(reduce)
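
The fallback assigns each Dim a subscript letter and defers to torch.einsum; the same construction on plain tensors with string dim names (standalone illustration, not RETURNN API):

    import torch

    def generic_einsum(a, a_dims, b, b_dims, reduce):
        # Keep a's non-reduced dims, then b's non-reduced dims not already in a (as in the diff above).
        result_dims = [d for d in a_dims if d not in reduce] + [
            d for d in b_dims if d not in reduce and d not in a_dims
        ]
        letters = {}
        for d in list(a_dims) + list(b_dims):
            if d not in letters:
                letters[d] = chr(97 + len(letters))  # 'a', 'b', 'c', ...
        eq = "{},{}->{}".format(
            "".join(letters[d] for d in a_dims),
            "".join(letters[d] for d in b_dims),
            "".join(letters[d] for d in result_dims),
        )
        return torch.einsum(eq, a, b)

    a = torch.randn(3, 5)  # [batch, time]
    b = torch.randn(5, 7)  # [time, feat]
    assert generic_einsum(a, ["batch", "time"], b, ["time", "feat"], reduce=["time"]).shape == (3, 7)
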
@@ -1767,6 +1917,9 @@ class TorchBackend(Backend[torch.Tensor]):
         remaining_dims = [d for d in tensor.dims if d not in mask.dims]
         tensor_templ_dims = tuple(dims) + tuple(remaining_dims)
         in_raw = tensor.copy_compatible_to_dims_raw(tensor_templ_dims)
+        if any([in_raw.shape[i] == 1 < d.get_dim_value() for i, d in enumerate(dims)]):
+            # unbroadcast
+            in_raw = in_raw.expand([d.get_dim_value() for d in tensor_templ_dims])
         if mask.raw_tensor.device.type == "meta":
             # This is not supported, but also, we would anyway not know the out shape.
             # However, instead of erroring, just assume some dummy mask.
@@ -1920,7 +2073,7 @@ class TorchBackend(Backend[torch.Tensor]):
         if not out_spatial_dims:
             out_spatial_dims = rf.make_conv_out_spatial_dims(
                 in_spatial_dims=in_spatial_dims,
-                filter_size=[d.dimension for d in filter_size],
+                filter_size=filter_size,
                 strides=strides or 1,
                 dilation_rate=dilation_rate or 1,
                 padding=padding,
@@ -2033,6 +2186,104 @@ class TorchBackend(Backend[torch.Tensor]):
         out.feature_dim = out_dim
         return out, out_spatial_dims
 
+    # noinspection PyShadowingBuiltins
+    @staticmethod
+    def transposed_conv(
+        source: Tensor,
+        *,
+        in_dim: Dim,
+        out_dim: Dim,
+        in_spatial_dims: Sequence[Dim],
+        out_spatial_dims: Optional[Sequence[Dim]] = None,
+        filter: Tensor,
+        filter_size: Sequence[Dim],
+        padding: str,
+        remove_padding: Union[Sequence[int], int] = 0,
+        output_padding: Optional[Union[Sequence[Optional[int]], int]] = None,
+        strides: Optional[Sequence[int]] = None,
+        bias: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Sequence[Dim]]:
+        """transposed convolution"""
+        if not out_spatial_dims:
+            out_spatial_dims = rf.make_transposed_conv_out_spatial_dims(
+                in_spatial_dims=in_spatial_dims,
+                filter_size=filter_size,
+                strides=strides,
+                padding=padding,
+                output_padding=output_padding,
+            )
+        assert remove_padding == 0  # not implemented yet otherwise...
+        if strides is None:
+            strides = [fs.dimension for fs in filter_size]
+        filter_dims = (in_dim, out_dim) + tuple(filter_size)
+        filter = filter.copy_transpose(filter_dims)
+        batch_dims = [d for d in source.dims if d not in (in_dim,) + tuple(in_spatial_dims)]
+        # Torch conv expects (N,C,<spatial dims>) as shape.
+        source = source.copy_transpose(batch_dims + [in_dim] + list(in_spatial_dims))
+        if len(batch_dims) == 1:
+            src_raw = source.raw_tensor
+        else:
+            src_raw = torch.reshape(
+                source.raw_tensor,
+                # potentially merge batch dims all together
+                [-1, in_dim.get_dim_value()] + [d.get_dim_value() for d in in_spatial_dims],
+            )
+        if padding == "same":
+            raise NotImplementedError("transposed_conv with padding='same' not implemented")
+        if padding == "valid":
+            padding_val = 0
+        else:
+            raise ValueError(f"invalid padding {padding!r}, expected 'same' or 'valid'")
+        if len(filter_size) == 1:
+            out_raw = torch.nn.functional.conv_transpose1d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        elif len(filter_size) == 2:
+            out_raw = torch.nn.functional.conv_transpose2d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        elif len(filter_size) == 3:
+            out_raw = torch.nn.functional.conv_transpose3d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        else:
+            raise ValueError(f"invalid number of filter dims {filter_size}, expected 1, 2, or 3")
+        if remove_padding:
+            if isinstance(remove_padding, int):
+                remove_padding = [remove_padding] * len(out_spatial_dims)
+            assert len(remove_padding) == len(out_spatial_dims)
+            slices = [slice(None)] * out_raw.ndim
+            for i, pad in enumerate(remove_padding):
+                if pad > 0:
+                    slices[2 + i] = slice(0, -pad)
+            out_raw = out_raw[tuple(slices)]
+        out = Tensor(
+            "transposed_conv",
+            dims=batch_dims + [out_dim] + list(out_spatial_dims),
+            dtype=TorchBackend.get_dtype_name_raw(out_raw),
+        )
+        if len(batch_dims) == 1:
+            out.raw_tensor = out_raw
+        else:
+            out.raw_tensor = torch.reshape(out_raw, [d.get_dim_value() for d in out.dims])
+        out.feature_dim = out_dim
+        return out, out_spatial_dims
+
     @staticmethod
     def pool(
         source: Tensor,
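
With the defaults in this method (stride falling back to the filter size, padding="valid" → padding 0), the raw torch call upsamples each spatial axis by the filter size; a plain-torch shape check of the 1D case with the (in_dim, out_dim, filter) weight layout used above:

    import torch

    x = torch.randn(2, 4, 10)   # (batch, in_dim, time)
    w = torch.randn(4, 8, 3)    # (in_dim, out_dim, filter) -- the layout filter.copy_transpose produces
    y = torch.nn.functional.conv_transpose1d(x, weight=w, stride=3, padding=0)
    # out_len = (in_len - 1) * stride + kernel = 10 * 3 when stride == kernel
    assert y.shape == (2, 8, 30)
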
@@ -136,6 +136,15 @@ class RFModuleAsPTModule(torch.nn.Module):
     def _get_name(self):
         return self._rf_module.__class__.__name__ + "[RF→PT]"
 
+    def __repr__(self) -> str:
+        """
+        Return a custom repr for Sequential/ModuleList that compresses repeated module representations if possible,
+        otherwise fallback to default behavior.
+        """
+        if _can_use_compact_repr(self):
+            return _repr_compact(self)
+        return super().__repr__()
+
     @property
     def rf_module(self) -> rf.Module:
         """RF module"""
@@ -193,3 +202,55 @@ class RFModuleAsPTModule(torch.nn.Module):
         # See similar logic in torch.nn.Module._apply.
         pt_param = torch.nn.Parameter(tensor, tensor.requires_grad)
         rf_param.raw_tensor = pt_param
+
+
+def _can_use_compact_repr(self: RFModuleAsPTModule) -> bool:
+    return list(self._modules.keys()) == [str(i) for i in range(len(self._modules))]
+
+
+def _repr_compact(self: RFModuleAsPTModule) -> str:
+    """
+    Return a custom repr for Sequential/ModuleList that compresses repeated module representations.
+    Code copied and adapted from torch.nn.ModuleList.__repr__.
+    """
+    list_of_reprs = [repr(item) for item in self._modules.values()]
+    if len(list_of_reprs) == 0:
+        return self._get_name() + "()"
+
+    start_end_indices = [[0, 0]]
+    repeated_blocks = [list_of_reprs[0]]
+    for i, r in enumerate(list_of_reprs[1:], 1):
+        if r == repeated_blocks[-1]:
+            start_end_indices[-1][1] += 1
+            continue
+
+        start_end_indices.append([i, i])
+        repeated_blocks.append(r)
+
+    lines = []
+    main_str = self._get_name() + "("
+    for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
+        local_repr = f"({start_id}): {b}"  # default repr
+
+        if start_id != end_id:
+            n = end_id - start_id + 1
+            local_repr = f"({start_id}-{end_id}): {n} x {b}"
+
+        local_repr = _add_indent(local_repr, 2)
+        lines.append(local_repr)
+
+    main_str += "\n  " + "\n  ".join(lines) + "\n"
+    main_str += ")"
+    return main_str
+
+
+def _add_indent(s_: str, num_spaces: int) -> str:
+    s = s_.split("\n")
+    # don't do anything for single-line stuff
+    if len(s) == 1:
+        return s_
+    first = s.pop(0)
+    s = [(num_spaces * " ") + line for line in s]
+    s = "\n".join(s)
+    s = first + "\n" + s
+    return s
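
For a wrapped container of identical submodules, the compact repr collapses the repeats roughly the way torch.nn.ModuleList does; a sketch (the module construction details are assumptions, the printed output is approximate):

    import returnn.frontend as rf
    from returnn.tensor import Dim
    from returnn.torch.frontend.bridge import rf_module_to_pt_module

    dim = Dim(8, name="feat")
    layers = rf.Sequential(*[rf.Linear(dim, dim) for _ in range(12)])
    print(rf_module_to_pt_module(layers))
    # Expected to print something like:
    # Sequential[RF→PT](
    #   (0-11): 12 x Linear[RF→PT](...)
    # )
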
@@ -0,0 +1,106 @@
+"""
+Helpers to improve torch.compile on RF code.
+"""
+
+from __future__ import annotations
+from typing import Any, Iterable, List, Tuple
+
+import os
+from returnn.tensor import Tensor, Dim
+
+# noinspection PyProtectedMember
+from returnn.frontend import _native
+
+_is_set_up = False
+
+
+def setup():
+    """
+    Set up the torch.compile helpers for RF code, also including :class:`Tensor` and :class:`Dim`.
+    """
+
+    global _is_set_up
+    if _is_set_up:
+        return
+    _is_set_up = True  # only try once
+
+    assert not _native.is_set_up(), "Call this setup() as early as possible."
+    _native.set_enabled(False)
+
+    # We have lots of dynamic shapes.
+    os.environ["TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS"] = "1"
+
+    # noinspection PyProtectedMember
+    from torch.utils._pytree import register_pytree_node
+
+    register_pytree_node(Tensor, _tensor_flatten, _tensor_unflatten)
+    register_pytree_node(Dim, _dim_flatten, _dim_unflatten)
+
+    Dim.get_dim_value = _dim_get_dim_value
+
+
+def _tensor_flatten(t: Tensor) -> Tuple[List[Any], Any]:
+    """
+    Flatten the tensor for PyTree.
+    """
+    return [t.raw_tensor, t.dims, t.sparse_dim], [
+        t.name,
+        t.dtype,
+        t.version,
+        t.feature_dim_axis_or_unspecified,
+        t.time_dim_axis_or_unspecified,
+    ]
+
+
+def _tensor_unflatten(values: Iterable[Any], metadata: Any) -> Tensor:
+    """
+    Unflatten the tensor from PyTree.
+    """
+    raw_tensor, dims, sparse_dim = values
+    name, dtype, version, feature_dim_axis, time_dim_axis = metadata
+    return Tensor(
+        name=name,
+        dims=dims,
+        dtype=dtype,
+        sparse_dim=sparse_dim,
+        feature_dim_axis=feature_dim_axis,
+        time_dim_axis=time_dim_axis,
+        raw_tensor=raw_tensor,
+        version=version,
+    )
+
+
+def _dim_flatten(d: Dim) -> Tuple[List[Any], Any]:
+    """
+    Flatten the dim for PyTree.
+    """
+    return [d.dyn_size_ext], [d.name, d.dimension, d.size]
+
+
+def _dim_unflatten(values: Iterable[Any], metadata: Any) -> Dim:
+    """
+    Unflatten the dim from PyTree.
+    """
+    (dyn_size_ext,) = values
+    name, dimension, size = metadata
+    # TODO this creates a new instance... this is maybe wrong?
+    return Dim(name=name, dimension=dimension, size=size, dyn_size_ext=dyn_size_ext)
+
+
+def _dim_get_dim_value(self: Dim) -> int:
+    """
+    Infers the dim this axis should have if unbroadcasted.
+    If `self.src_data` has a placeholder, will use the shape from there.
+    Otherwise, uses `self.dimension` (if static) or `self.dyn_size` (if dynamic).
+
+    :return: max(size or dyn_size)
+    """
+    res = self.get_dim_value_tensor()
+    if isinstance(res, Tensor):
+        assert res.dims == ()
+        assert res.raw_tensor is not None
+        # Specifically PyTorch would then treat it as a SymInt in torch.compile,
+        # which is important to have for some torch functions (e.g. torch.tile and others).
+        return int(res.raw_tensor)
+    assert isinstance(res, int)
+    return res
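
A sketch of the presumed usage: call setup() before RF tensors are created, then torch.compile code that operates on RF Tensors (only setup() itself comes from the diff; the compiled-function example is an assumption):

    import torch
    from returnn.torch.frontend import compile_helper

    compile_helper.setup()  # early: disables the RF native extension, registers pytree handlers

    import returnn.frontend as rf
    from returnn.tensor import Tensor

    @torch.compile
    def step(x: Tensor) -> Tensor:
        # Tensor/Dim are pytree-registered now, so they can cross the compile boundary.
        return rf.relu(x) * 2.0
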
@@ -60,3 +60,33 @@ def nonzero(mask: torch.Tensor, *, out_len: Union[int, torch.Tensor]) -> torch.T
     idx = torch.argsort(mask.to(torch.int8), stable=True, descending=True)  # [in_len]
     idx = idx[:out_len]  # [out_len]
     return idx
+
+
+def sequence_mask(lengths: torch.Tensor, *, maxlen: Optional[int] = None) -> torch.Tensor:
+    """
+    Creates a boolean mask from sequence lengths.
+
+    :param lengths: Tensor of shape [batch_size...] containing sequence lengths
+    :param maxlen: Maximum length of the sequences. If None, uses the maximum value in lengths.
+    :return: A boolean mask tensor of shape [batch_size..., maxlen]
+    """
+    if maxlen is None:
+        maxlen = lengths.max()
+    indices = torch.arange(0, maxlen, dtype=lengths.dtype, device=lengths.device)
+    mask = indices < lengths[..., None]
+    return mask
+
+
+def sequence_mask_time_major(lengths: torch.Tensor, *, maxlen: Optional[int] = None) -> torch.Tensor:
+    """
+    Creates a boolean mask from sequence lengths.
+
+    :param lengths: Tensor of shape [batch_size...] containing sequence lengths
+    :param maxlen: Maximum length of the sequences. If None, uses the maximum value in lengths.
+    :return: A boolean mask tensor of shape [maxlen, batch_size...]
+    """
+    if maxlen is None:
+        maxlen = lengths.max()
+    indices = torch.arange(0, maxlen, dtype=lengths.dtype, device=lengths.device)
+    mask = indices[(slice(None),) + (None,) * lengths.ndim] < lengths[None]
+    return mask
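
A quick check of the two layouts:

    import torch
    from returnn.torch.util.array_ import sequence_mask, sequence_mask_time_major

    lens = torch.tensor([3, 1, 2])
    m = sequence_mask(lens)              # [batch, maxlen] -> [3, 3]
    mt = sequence_mask_time_major(lens)  # [maxlen, batch] -> [3, 3]
    assert m.tolist() == [[True, True, True], [True, False, False], [True, True, False]]
    assert torch.equal(mt, m.T)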