returnn 1.20251027.224345__py3-none-any.whl → 1.20260113.134416__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of returnn has been flagged as potentially problematic.
Files changed (43):
  1. returnn/PKG-INFO +2 -2
  2. returnn/__old_mod_loader__.py +26 -2
  3. returnn/_setup_info_generated.py +2 -2
  4. returnn/config.py +1 -1
  5. returnn/datasets/lm.py +130 -42
  6. returnn/datasets/meta.py +93 -43
  7. returnn/datasets/postprocessing.py +597 -108
  8. returnn/datasets/util/vocabulary.py +90 -0
  9. returnn/frontend/_native/__init__.py +22 -0
  10. returnn/frontend/_utils.py +1 -1
  11. returnn/frontend/array_.py +48 -2
  12. returnn/frontend/attention.py +54 -20
  13. returnn/frontend/conv.py +273 -54
  14. returnn/frontend/device.py +14 -1
  15. returnn/frontend/encoder/conformer.py +20 -0
  16. returnn/frontend/encoder/transformer.py +2 -0
  17. returnn/frontend/loss.py +40 -1
  18. returnn/frontend/math_.py +54 -14
  19. returnn/native_op.cpp +80 -0
  20. returnn/sprint/cache.py +12 -13
  21. returnn/tensor/_dim_extra.py +7 -7
  22. returnn/tensor/_tensor_extra.py +10 -10
  23. returnn/tensor/utils.py +7 -4
  24. returnn/tf/frontend_layers/_backend.py +4 -3
  25. returnn/tf/layers/basic.py +15 -39
  26. returnn/tf/native_op.py +11 -58
  27. returnn/tf/network.py +1 -1
  28. returnn/tf/util/basic.py +19 -0
  29. returnn/torch/engine.py +157 -6
  30. returnn/torch/frontend/_backend.py +137 -15
  31. returnn/torch/frontend/bridge.py +61 -0
  32. returnn/torch/frontend/compile_helper.py +106 -0
  33. returnn/torch/util/exception_helper.py +7 -1
  34. returnn/util/basic.py +5 -6
  35. returnn/util/better_exchook.py +4 -0
  36. returnn/util/debug.py +12 -2
  37. returnn/util/file_cache.py +15 -1
  38. returnn/util/task_system.py +1 -1
  39. {returnn-1.20251027.224345.dist-info → returnn-1.20260113.134416.dist-info}/METADATA +2 -2
  40. {returnn-1.20251027.224345.dist-info → returnn-1.20260113.134416.dist-info}/RECORD +43 -42
  41. {returnn-1.20251027.224345.dist-info → returnn-1.20260113.134416.dist-info}/LICENSE +0 -0
  42. {returnn-1.20251027.224345.dist-info → returnn-1.20260113.134416.dist-info}/WHEEL +0 -0
  43. {returnn-1.20251027.224345.dist-info → returnn-1.20260113.134416.dist-info}/top_level.txt +0 -0
returnn/torch/engine.py CHANGED
@@ -3,9 +3,11 @@ Main engine for PyTorch
 """
 
 from __future__ import annotations
+
 from typing import Optional, Any, Union, Callable, Dict, Set
 from contextlib import nullcontext, ExitStack, contextmanager
 
+import sys
 import gc
 import os
 import time
@@ -20,6 +22,7 @@ from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader
 from torch import autocast
 from torch.cuda import amp
+from torch.profiler import record_function
 import numpy as np
 
 import returnn
@@ -404,10 +407,14 @@ class Engine(EngineBase):
         total_data_size_packed = NumbersDict()
         total_data_size_padded = NumbersDict()
 
+        prof = _opt_torch_profiler_from_opts(self.config.opt_typed_value("torch_profile"))
+        if prof:
+            prof.__enter__()
+
         report_prefix = f"ep {self.epoch} train"
         try:
             while True:
-                with torch.no_grad():
+                with torch.no_grad(), record_function("data_loading"):
                     extern_data_raw = next(data_iter, None)
 
                 step_begin_time = time.monotonic()
@@ -485,7 +492,8 @@
                 with (
                     self._ddp_pt_model.no_sync()
                     if (self._ddp_pt_model is not None and not perform_update_step)
-                    else nullcontext()
+                    else nullcontext(),
+                    record_function("backward"),
                 ):
                     if self._grad_scaler is not None:
                         self._grad_scaler.scale(total_loss.raw_tensor).backward()
@@ -500,7 +508,8 @@
 
                 # only update the weights when every gradient accumulation loop ends
                 if perform_update_step:
-                    self._updater.step(grad_scaler=self._grad_scaler)
+                    with record_function("optimizer_step"):
+                        self._updater.step(grad_scaler=self._grad_scaler)
                 zero_grad_next_step = perform_update_step
 
                 if self._torch_distributed_ctx:
@@ -532,7 +541,7 @@
                     for key, val in eval_info.items():
                         self._tensorboard_writer.add_scalar(f"train/{key}", val, global_step=self.global_train_step)
                     self._tensorboard_writer.add_scalar(
-                        f"train/learning_rate",
+                        "train/learning_rate",
                         self._updater.get_effective_learning_rate(),
                         global_step=self.global_train_step,
                     )
@@ -582,10 +591,19 @@
                 self._updater.set_current_train_step(
                     global_train_step=self.global_train_step, epoch=self.epoch, epoch_continuous=epoch_continuous
                 )
+
+                if prof:
+                    prof.step()
+
         except Exception as exc:
+            if prof:
+                prof.__exit__(type(exc), exc, exc.__traceback__)
             help_on_torch_exception(exc, step_idx=step_idx, model=self._orig_model, extern_data=extern_data)
             raise
 
+        if prof:
+            prof.__exit__(None, None, None)
+
         elapsed = time.monotonic() - epoch_start_time
         elapsed_computation_percentage = elapsed_computation_time / elapsed
         total_padding_ratio = NumbersDict.constant_like(1.0, total_data_size_packed) - (
@@ -885,6 +903,7 @@ class Engine(EngineBase):
             if self._default_float_dtype:
                 stack.enter_context(rf.set_default_float_dtype_ctx(str(self._default_float_dtype).split(".")[-1]))
                 stack.enter_context(_set_torch_default_dtype_ctx_mgr(self._default_float_dtype))
+            stack.enter_context(record_function("model_step"))
             yield
 
     def _run_step(
@@ -930,7 +949,7 @@ class Engine(EngineBase):
        if not os.path.exists(filename) and os.path.exists(model_epoch_filename):
            filename = model_epoch_filename
        print("Load model %s" % (filename,), file=log.v4)
-        checkpoint_state = torch.load(filename, map_location=self._device)
+        checkpoint_state = _torch_load(filename, device=self._device)
        if epoch is None:
            epoch = checkpoint_state.get("epoch", self._start_epoch or 1)
        step = checkpoint_state.get("step", 1)
@@ -1030,7 +1049,7 @@ class Engine(EngineBase):
                print("(No relevant parameters matching.)", file=log.v3)
                continue
            print(f"Pre-load weights for key '{preload_key}' from {opts['filename']}", file=log.v3)
-            preload_model_state = torch.load(opts["filename"], map_location=self._device)
+            preload_model_state = _torch_load(opts["filename"], device=self._device)
            if opts.get("checkpoint_key", "model") is not None:
                # This can be used if an external checkpoint saves a checkpoint a different structure that just the
                # model state dict. E.g., if a checkpoint is created using
@@ -1063,6 +1082,28 @@ class Engine(EngineBase):
            preload_model_state_keys = set(preload_model_state.keys())
            loaded_state_keys.update(preload_model_state.keys())
            missing_keys.difference_update(preload_model_state.keys())
+
+            custom_missing_load_func = opts.get("custom_missing_load_func")
+            if custom_missing_load_func:
+                custom_missing_vars_map = {}
+                for var_name in missing_keys_preload:
+                    var_shape = self._pt_model.state_dict()[var_name].shape
+                    var_val = custom_missing_load_func(
+                        name=var_name,
+                        shape=var_shape,
+                        preload_model_state=preload_model_state,
+                        **util.get_fwd_compat_kwargs(),
+                    )
+                    if var_val is not None:
+                        assert var_val.shape == var_shape
+                        custom_missing_vars_map[var_name] = var_val
+                preload_model_state.update(custom_missing_vars_map)
+                missing_keys_preload, unexpected_keys_preload = self._pt_model.load_state_dict(
+                    preload_model_state, strict=False
+                )
+                loaded_state_keys.update(preload_model_state.keys())
+                missing_keys.difference_update(preload_model_state.keys())
+
            del preload_model_state
            gc.collect()
 
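Note: the new `custom_missing_load_func` hook above is called for every parameter that is still missing after a preload checkpoint was applied; it receives the parameter name, its shape and the loaded state dict, and must return a tensor of exactly that shape or None. A minimal sketch of such a hook as it could be passed inside a preload entry (the `preload_from_files` option name and the zero-init fallback are illustrative assumptions, not taken from this diff):

import torch

def my_missing_load_func(*, name, shape, preload_model_state, **_fwd_compat):
    # Illustrative: zero-init missing biases, leave everything else to regular init.
    if name.endswith(".bias"):
        return torch.zeros(shape)
    return None  # returning None keeps the parameter missing / default-initialized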
@@ -1700,3 +1741,113 @@ def _get_total_grad_norm(model: torch.nn.Module, p: float) -> float:
            p=p,
        ).item()
    )
+
+
+def _torch_load(filename: Union[str, os.PathLike], *, device: str) -> Dict[str, Any]:
+    # Might resolve PtCheckpoint or Sisyphus Path objects or so.
+    filename = os.fspath(filename)
+
+    if filename.endswith(".safetensors"):
+        from safetensors.torch import load_file as safetensors_load
+
+        return safetensors_load(filename, device=device)
+
+    return torch.load(filename, map_location=device)
+
+
+class _TorchProfiler:
+    def __init__(self, profiler: torch.profiler.profile, max_step: Optional[int]):
+        self.profiler = profiler
+        self.max_step = max_step
+        self.entered = False
+
+    def __enter__(self):
+        self.profiler.__enter__()
+        self.entered = True
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not self.entered:
+            return
+        self.entered = False
+        self.profiler.__exit__(exc_type, exc_val, exc_tb)
+
+        if exc_type is None:
+            print(
+                "Torch profiling finished, exporting Chrome trace to torch_profile.json,"
+                " memory timeline to torch_memory_profile.html...",
+                file=log.v2,
+            )
+            self.profiler.export_chrome_trace("torch_profile.json")
+            self.profiler.export_memory_timeline("torch_memory_profile.html")
+
+            print("Exiting program after Torch profiling.", file=log.v2)
+            sys.exit(0)
+
+    def step(self):
+        """step"""
+        self.profiler.step()
+        if self.max_step is not None and self.profiler.step_num > self.max_step:
+            print(f"Reached max profiling step {self.max_step}, stopping Torch profiler.", file=log.v2)
+            self.profiler.stop()
+            self.__exit__(None, None, None)
+
+
+def _opt_torch_profiler_from_opts(
+    opts: Union[None, int, bool, str, Dict[str, Any]],
+) -> Optional[_TorchProfiler]:
+    if isinstance(opts, str):
+        from returnn.util.basic import to_bool
+
+        opts = to_bool(opts)
+
+    if opts is None:
+        return None
+    elif isinstance(opts, (bool, int)):
+        if not opts:
+            return None
+        opts = {}
+    elif isinstance(opts, dict):
+        opts = opts.copy()
+    else:
+        raise TypeError(f"Invalid type for torch_profile {opts!r}: {type(opts)}")
+
+    from torch.profiler import profile, ProfilerActivity, schedule
+
+    print("Using Torch profiler...", file=log.v2)
+
+    prof_max_step = None
+
+    if "activities" not in opts:
+        activities = [ProfilerActivity.CPU]
+        if torch.cuda.is_available():
+            activities += [ProfilerActivity.CUDA]
+        elif torch.xpu.is_available():
+            activities += [ProfilerActivity.XPU]
+        opts["activities"] = activities
+
+    opts.setdefault("profile_memory", True)
+    opts.setdefault("record_shapes", True)
+    opts.setdefault("with_stack", True)
+    opts.setdefault("with_flops", True)
+    # Note: active*repeat are the steps we actually profile.
+    opts.setdefault("schedule", dict(skip_first=10, wait=5, warmup=3, active=3, repeat=1))
+
+    if isinstance(opts["schedule"], dict):
+        schedule_opts: Dict[str, Any] = opts["schedule"]
+        schedule_opts = schedule_opts.copy()
+        schedule_opts.setdefault("repeat", 0)
+        schedule_opts.setdefault("skip_first", 0)
+        schedule_opts.setdefault("skip_first_wait", 0)
+        opts["schedule"] = schedule(**schedule_opts)
+
+        if schedule_opts["repeat"] > 0:
+            prof_max_step = (schedule_opts["wait"] + schedule_opts["warmup"] + schedule_opts["active"]) * schedule_opts[
+                "repeat"
+            ]
+            prof_max_step += schedule_opts["skip_first"]
+            if schedule_opts["skip_first_wait"] != 0:
+                prof_max_step -= schedule_opts["wait"]
+            print(f"Profiling will stop automatically after {prof_max_step} steps.", file=log.v3)
+
+    prof = profile(**opts)
+    return _TorchProfiler(prof, prof_max_step)
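Note: based on `_opt_torch_profiler_from_opts` above, the new `torch_profile` config option accepts a bool/int, a string (parsed via `to_bool`), or a dict of `torch.profiler.profile` kwargs; a dict `schedule` is expanded through `torch.profiler.schedule`, and with `repeat > 0` profiling stops after `(wait + warmup + active) * repeat + skip_first` steps, exports torch_profile.json / torch_memory_profile.html and exits. A hedged sketch of how this might look in a RETURNN config (the concrete values are illustrative):

# enable with the defaults set above (CPU + CUDA/XPU activities, memory/shapes/stack/FLOPs recording)
torch_profile = True

# or pass explicit torch.profiler.profile kwargs; this profiles 3 active steps once, then exits
torch_profile = {"schedule": dict(skip_first=10, wait=5, warmup=3, active=3, repeat=1)}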
returnn/torch/frontend/_backend.py CHANGED
@@ -275,7 +275,7 @@ class TorchBackend(Backend[torch.Tensor]):
        :return: tensor
        """
        assert len(dims) >= 2
-        first_axis = min(source.dims.index(d) for d in dims)
+        first_axis = min([source.dims.index(d) for d in dims])
        pre_dims = source.dims[:first_axis]
        post_dims = [d for d in source.dims if d not in dims and d not in pre_dims]
        source = source.copy_transpose(tuple(pre_dims) + tuple(dims) + tuple(post_dims), allow_int=False)
@@ -884,7 +884,7 @@ class TorchBackend(Backend[torch.Tensor]):
        :param perm: e.g. [0, 2, 1]
        :return: permuted (transposed) raw tensor; wraps torch.permute
        """
-        if all(p == i for i, p in enumerate(perm)):
+        if all([p == i for i, p in enumerate(perm)]):
            return raw_tensor
        return torch.permute(raw_tensor, tuple(perm))
 
@@ -1166,20 +1166,29 @@ class TorchBackend(Backend[torch.Tensor]):
        if start is None:
            start = 0
        if isinstance(size, Dim):
+            assert end is None
            size = size.get_dim_value()
        elif isinstance(size, Tensor):
+            assert end is None
            assert size.dims == ()  # scalar
            size = size.raw_tensor
-        if size is not None:
-            assert end is None
-            out.raw_tensor = torch.narrow(source.raw_tensor, dim=axis_int, start=start, length=size)
-        else:
+        elif isinstance(size, int):
+            pass
+        elif size is None:
            if isinstance(end, Tensor):
                assert end.dims == ()
                end = end.raw_tensor
-            if end is None:
+            elif isinstance(end, int):
+                if end < 0:
+                    end += axis.get_dim_value()
+            elif end is None:
                end = axis.get_dim_value()
-            out.raw_tensor = torch.narrow(source.raw_tensor, dim=axis_int, start=start, length=end - start)
+            else:
+                raise TypeError(f"slice: unsupported type for end: {type(end)}")
+            size = end - start
+        else:
+            raise TypeError(f"slice: unsupported type for size: {type(size)}")
+        out.raw_tensor = torch.narrow(source.raw_tensor, dim=axis_int, start=start, length=size)
        return out
 
    @staticmethod
@@ -1352,12 +1361,24 @@ class TorchBackend(Backend[torch.Tensor]):
        a_dims = a.dims
        b_dims = b.dims
 
-        assert all(dim in a_dims for dim in reduce), (
-            f"'a' does not have the specified reduce dim(s) {reduce} (a dims: {a_dims})"
-        )
-        assert all(dim in b_dims for dim in reduce), (
-            f"'b' does not have the specified reduce dim(s) {reduce} (b dims: {b_dims})"
-        )
+        if not all(dim in a_dims for dim in reduce) or not all(dim in b_dims for dim in reduce):
+            # revert to the generic einsum implementation
+            assert all(dim in a_dims + b_dims for dim in reduce), "Some reduce Dims not in a or b."
+            result_dims = [dim for dim in a_dims if dim not in reduce] + [
+                dim for dim in b_dims if dim not in reduce and dim not in a_dims
+            ]
+            map_to_letter = {}
+            for dim in a_dims + b_dims:
+                if dim not in map_to_letter:
+                    map_to_letter[dim] = chr(97 + len(map_to_letter))  # 'a', 'b', 'c', ...
+            a_subscript = "".join(map_to_letter[dim] for dim in a_dims)
+            b_subscript = "".join(map_to_letter[dim] for dim in b_dims)
+            out_subscript = "".join(map_to_letter[dim] for dim in result_dims)
+            raw_result = torch.einsum(f"{a_subscript},{b_subscript}->{out_subscript}", a.raw_tensor, b.raw_tensor)
+            result_tensor = Tensor(
+                "einsum", dims=result_dims, raw_tensor=raw_result, dtype=TorchBackend.get_dtype_name_raw(raw_result)
+            )
+            return result_tensor
 
        if len(reduce) > 1:
            reduce = list(reduce)
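Note: the einsum fallback above assigns one lowercase letter per Dim in order of first appearance over `a_dims + b_dims` and builds a plain einsum equation from that. A small illustration with made-up dims and raw shapes (not from this diff): with `a.dims = (batch, time_a)`, `b.dims = (batch, time_b)` and `reduce = (time_a, time_b)`, the letters become batch→'a', time_a→'b', time_b→'c', so the call is equivalent to:

import torch

a_raw = torch.randn(8, 11)  # (batch, time_a)
b_raw = torch.randn(8, 7)   # (batch, time_b)
out = torch.einsum("ab,ac->a", a_raw, b_raw)  # only the batch dim remains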
@@ -1767,6 +1788,9 @@ class TorchBackend(Backend[torch.Tensor]):
        remaining_dims = [d for d in tensor.dims if d not in mask.dims]
        tensor_templ_dims = tuple(dims) + tuple(remaining_dims)
        in_raw = tensor.copy_compatible_to_dims_raw(tensor_templ_dims)
+        if any([in_raw.shape[i] == 1 < d.get_dim_value() for i, d in enumerate(dims)]):
+            # unbroadcast
+            in_raw = in_raw.expand([d.get_dim_value() for d in tensor_templ_dims])
        if mask.raw_tensor.device.type == "meta":
            # This is not supported, but also, we would anyway not know the out shape.
            # However, instead of erroring, just assume some dummy mask.
@@ -1920,7 +1944,7 @@ class TorchBackend(Backend[torch.Tensor]):
        if not out_spatial_dims:
            out_spatial_dims = rf.make_conv_out_spatial_dims(
                in_spatial_dims=in_spatial_dims,
-                filter_size=[d.dimension for d in filter_size],
+                filter_size=filter_size,
                strides=strides or 1,
                dilation_rate=dilation_rate or 1,
                padding=padding,
@@ -2033,6 +2057,104 @@ class TorchBackend(Backend[torch.Tensor]):
        out.feature_dim = out_dim
        return out, out_spatial_dims
 
+    # noinspection PyShadowingBuiltins
+    @staticmethod
+    def transposed_conv(
+        source: Tensor,
+        *,
+        in_dim: Dim,
+        out_dim: Dim,
+        in_spatial_dims: Sequence[Dim],
+        out_spatial_dims: Optional[Sequence[Dim]] = None,
+        filter: Tensor,
+        filter_size: Sequence[Dim],
+        padding: str,
+        remove_padding: Union[Sequence[int], int] = 0,
+        output_padding: Optional[Union[Sequence[Optional[int]], int]] = None,
+        strides: Optional[Sequence[int]] = None,
+        bias: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Sequence[Dim]]:
+        """transposed convolution"""
+        if not out_spatial_dims:
+            out_spatial_dims = rf.make_transposed_conv_out_spatial_dims(
+                in_spatial_dims=in_spatial_dims,
+                filter_size=filter_size,
+                strides=strides,
+                padding=padding,
+                output_padding=output_padding,
+            )
+        assert remove_padding == 0  # not implemented yet otherwise...
+        if strides is None:
+            strides = [fs.dimension for fs in filter_size]
+        filter_dims = (in_dim, out_dim) + tuple(filter_size)
+        filter = filter.copy_transpose(filter_dims)
+        batch_dims = [d for d in source.dims if d not in (in_dim,) + tuple(in_spatial_dims)]
+        # Torch conv expects (N,C,<spatial dims>) as shape.
+        source = source.copy_transpose(batch_dims + [in_dim] + list(in_spatial_dims))
+        if len(batch_dims) == 1:
+            src_raw = source.raw_tensor
+        else:
+            src_raw = torch.reshape(
+                source.raw_tensor,
+                # potentially merge batch dims all together
+                [-1, in_dim.get_dim_value()] + [d.get_dim_value() for d in in_spatial_dims],
+            )
+        if padding == "same":
+            raise NotImplementedError("transposed_conv with padding='same' not implemented")
+        if padding == "valid":
+            padding_val = 0
+        else:
+            raise ValueError(f"invalid padding {padding!r}, expected 'same' or 'valid'")
+        if len(filter_size) == 1:
+            out_raw = torch.nn.functional.conv_transpose1d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        elif len(filter_size) == 2:
+            out_raw = torch.nn.functional.conv_transpose2d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        elif len(filter_size) == 3:
+            out_raw = torch.nn.functional.conv_transpose3d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        else:
+            raise ValueError(f"invalid number of filter dims {filter_size}, expected 1, 2, or 3")
+        if remove_padding:
+            if isinstance(remove_padding, int):
+                remove_padding = [remove_padding] * len(out_spatial_dims)
+            assert len(remove_padding) == len(out_spatial_dims)
+            slices = [slice(None)] * out_raw.ndim
+            for i, pad in enumerate(remove_padding):
+                if pad > 0:
+                    slices[2 + i] = slice(0, -pad)
+            out_raw = out_raw[tuple(slices)]
+        out = Tensor(
+            "transposed_conv",
+            dims=batch_dims + [out_dim] + list(out_spatial_dims),
+            dtype=TorchBackend.get_dtype_name_raw(out_raw),
+        )
+        if len(batch_dims) == 1:
+            out.raw_tensor = out_raw
+        else:
+            out.raw_tensor = torch.reshape(out_raw, [d.get_dim_value() for d in out.dims])
+        out.feature_dim = out_dim
+        return out, out_spatial_dims
+
    @staticmethod
    def pool(
        source: Tensor,
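Note: since only padding="valid" is implemented above (with `padding_val=0`) and `output_padding` defaults to 0, each output spatial size follows the standard transposed-conv formula out = (in - 1) * stride + filter_size. A quick sanity check in plain torch with made-up sizes (not using the RF wrapper):

import torch

x = torch.randn(1, 4, 10)  # (batch, in_dim, time), time=10
w = torch.randn(4, 2, 3)   # (in_dim, out_dim, filter_size=3)
y = torch.nn.functional.conv_transpose1d(x, w, stride=2, padding=0)
assert y.shape == (1, 2, (10 - 1) * 2 + 3)  # -> (1, 2, 21)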
returnn/torch/frontend/bridge.py CHANGED
@@ -136,6 +136,15 @@ class RFModuleAsPTModule(torch.nn.Module):
    def _get_name(self):
        return self._rf_module.__class__.__name__ + "[RF→PT]"
 
+    def __repr__(self) -> str:
+        """
+        Return a custom repr for Sequential/ModuleList that compresses repeated module representations if possible,
+        otherwise fallback to default behavior.
+        """
+        if _can_use_compact_repr(self):
+            return _repr_compact(self)
+        return super().__repr__()
+
    @property
    def rf_module(self) -> rf.Module:
        """RF module"""
@@ -193,3 +202,55 @@ class RFModuleAsPTModule(torch.nn.Module):
        # See similar logic in torch.nn.Module._apply.
        pt_param = torch.nn.Parameter(tensor, tensor.requires_grad)
        rf_param.raw_tensor = pt_param
+
+
+def _can_use_compact_repr(self: RFModuleAsPTModule) -> bool:
+    return list(self._modules.keys()) == [str(i) for i in range(len(self._modules))]
+
+
+def _repr_compact(self: RFModuleAsPTModule) -> str:
+    """
+    Return a custom repr for Sequential/ModuleList that compresses repeated module representations.
+    Code copied and adapted from torch.nn.ModuleList.__repr__.
+    """
+    list_of_reprs = [repr(item) for item in self._modules.values()]
+    if len(list_of_reprs) == 0:
+        return self._get_name() + "()"
+
+    start_end_indices = [[0, 0]]
+    repeated_blocks = [list_of_reprs[0]]
+    for i, r in enumerate(list_of_reprs[1:], 1):
+        if r == repeated_blocks[-1]:
+            start_end_indices[-1][1] += 1
+            continue
+
+        start_end_indices.append([i, i])
+        repeated_blocks.append(r)
+
+    lines = []
+    main_str = self._get_name() + "("
+    for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
+        local_repr = f"({start_id}): {b}"  # default repr
+
+        if start_id != end_id:
+            n = end_id - start_id + 1
+            local_repr = f"({start_id}-{end_id}): {n} x {b}"
+
+        local_repr = _add_indent(local_repr, 2)
+        lines.append(local_repr)
+
+    main_str += "\n  " + "\n  ".join(lines) + "\n"
+    main_str += ")"
+    return main_str
+
+
+def _add_indent(s_: str, num_spaces: int) -> str:
+    s = s_.split("\n")
+    # don't do anything for single-line stuff
+    if len(s) == 1:
+        return s_
+    first = s.pop(0)
+    s = [(num_spaces * " ") + line for line in s]
+    s = "\n".join(s)
+    s = first + "\n" + s
+    return s
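Note: with the helpers above, a wrapped RF module whose children are stored under consecutive integer keys (Sequential/ModuleList style) gets the same compressed repr that torch.nn.ModuleList produces; a hypothetical example of the resulting output (module names are illustrative):

Sequential[RF→PT](
  (0-11): 12 x ConformerEncoderLayer[RF→PT](...)
)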
returnn/torch/frontend/compile_helper.py ADDED
@@ -0,0 +1,106 @@
+"""
+Helpers to improve torch.compile on RF code.
+"""
+
+from __future__ import annotations
+from typing import Any, Iterable, List, Tuple
+
+import os
+from returnn.tensor import Tensor, Dim
+
+# noinspection PyProtectedMember
+from returnn.frontend import _native
+
+_is_set_up = False
+
+
+def setup():
+    """
+    Set up the torch.compile helpers for RF code, also including :class:`Tensor` and :class:`Dim`.
+    """
+
+    global _is_set_up
+    if _is_set_up:
+        return
+    _is_set_up = True  # only try once
+
+    assert not _native.is_set_up(), "Call this setup() as early as possible."
+    _native.set_enabled(False)
+
+    # We have lots of dynamic shapes.
+    os.environ["TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS"] = "1"
+
+    # noinspection PyProtectedMember
+    from torch.utils._pytree import register_pytree_node
+
+    register_pytree_node(Tensor, _tensor_flatten, _tensor_unflatten)
+    register_pytree_node(Dim, _dim_flatten, _dim_unflatten)
+
+    Dim.get_dim_value = _dim_get_dim_value
+
+
+def _tensor_flatten(t: Tensor) -> Tuple[List[Any], Any]:
+    """
+    Flatten the tensor for PyTree.
+    """
+    return [t.raw_tensor, t.dims, t.sparse_dim], [
+        t.name,
+        t.dtype,
+        t.version,
+        t.feature_dim_axis_or_unspecified,
+        t.time_dim_axis_or_unspecified,
+    ]
+
+
+def _tensor_unflatten(values: Iterable[Any], metadata: Any) -> Tensor:
+    """
+    Unflatten the tensor from PyTree.
+    """
+    raw_tensor, dims, sparse_dim = values
+    name, dtype, version, feature_dim_axis, time_dim_axis = metadata
+    return Tensor(
+        name=name,
+        dims=dims,
+        dtype=dtype,
+        sparse_dim=sparse_dim,
+        feature_dim_axis=feature_dim_axis,
+        time_dim_axis=time_dim_axis,
+        raw_tensor=raw_tensor,
+        version=version,
+    )
+
+
+def _dim_flatten(d: Dim) -> Tuple[List[Any], Any]:
+    """
+    Flatten the dim for PyTree.
+    """
+    return [d.dyn_size_ext], [d.name, d.dimension, d.size]
+
+
+def _dim_unflatten(values: Iterable[Any], metadata: Any) -> Dim:
+    """
+    Unflatten the dim from PyTree.
+    """
+    (dyn_size_ext,) = values
+    name, dimension, size = metadata
+    # TODO this creates a new instance... this is maybe wrong?
+    return Dim(name=name, dimension=dimension, size=size, dyn_size_ext=dyn_size_ext)
+
+
+def _dim_get_dim_value(self: Dim) -> int:
+    """
+    Infers the dim this axis should have if unbroadcasted.
+    If `self.src_data` has a placeholder, will use the shape from there.
+    Otherwise, uses `self.dimension` (if static) or `self.dyn_size` (if dynamic).
+
+    :return: max(size or dyn_size)
+    """
+    res = self.get_dim_value_tensor()
+    if isinstance(res, Tensor):
+        assert res.dims == ()
+        assert res.raw_tensor is not None
+        # Specifically PyTorch would then treat it as a SymInt in torch.compile,
+        # which is important to have for some torch functions (e.g. torch.tile and others).
+        return int(res.raw_tensor)
+    assert isinstance(res, int)
+    return res
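Note: `setup()` above disables the RF native extension, registers Tensor/Dim as PyTree nodes, and patches `Dim.get_dim_value`, so it has to run before any other RF/torch frontend code touches those classes. A hedged usage sketch (the torch.compile call at the end is illustrative, not part of this diff):

from returnn.torch.frontend import compile_helper

compile_helper.setup()  # call as early as possible, before the model is built

# ... later, e.g.:
# compiled_train_step = torch.compile(train_step)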
returnn/torch/util/exception_helper.py CHANGED
@@ -71,7 +71,13 @@ def help_on_torch_exception(
    if not count_frames:
        exc_ext.append("(No module call frames.)")
 
-    if len(exc.args) == 1 and isinstance(exc.args[0], str) and not always_direct_print:
+    if (
+        # KeyError formatting would be wrong, showing `KeyError: "enc_spatial_dim\n\nStep idx: 0\..."`
+        not isinstance(exc, KeyError)
+        and len(exc.args) == 1
+        and isinstance(exc.args[0], str)
+        and not always_direct_print
+    ):
        exc.args = ("\n".join([exc.args[0], ""] + exc_ext),)
    else:
        for msg in exc_ext:
returnn/util/basic.py CHANGED
@@ -365,12 +365,9 @@ def get_checkpoint_filepattern(filepath):
    :return: CheckpointLoader compatible filepattern
    :rtype: str
    """
-    if filepath.endswith(".meta"):
-        return filepath[: -len(".meta")]
-    elif filepath.endswith(".index"):
-        return filepath[: -len(".index")]
-    elif filepath.endswith(".pt"):
-        return filepath[: -len(".pt")]
+    for ext in [".meta", ".index", ".pt"]:
+        if filepath.endswith(ext):
+            return filepath[: -len(ext)]
    return filepath
 
 
@@ -3819,6 +3816,8 @@ def should_write_to_disk(config):
        return False
    if config.is_true("dry_run"):
        return False
+    if config.is_true("torch_profile"):
+        return False
    return True
 
 