returnn-1.20250901.123052-py3-none-any.whl → returnn-1.20260105.192646-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. returnn/PKG-INFO +2 -2
  2. returnn/_setup_info_generated.py +2 -2
  3. returnn/config.py +1 -1
  4. returnn/datasets/basic.py +29 -13
  5. returnn/datasets/distrib_files.py +61 -3
  6. returnn/datasets/generating.py +12 -21
  7. returnn/datasets/huggingface.py +434 -0
  8. returnn/datasets/lm.py +20 -0
  9. returnn/datasets/meta.py +179 -60
  10. returnn/datasets/multi_proc.py +1 -1
  11. returnn/datasets/postprocessing.py +597 -108
  12. returnn/datasets/text_dict.py +1 -1
  13. returnn/datasets/util/vocabulary.py +90 -0
  14. returnn/frontend/_backend.py +7 -0
  15. returnn/frontend/array_.py +54 -1
  16. returnn/frontend/attention.py +54 -20
  17. returnn/frontend/conv.py +273 -54
  18. returnn/frontend/decoder/transformer.py +36 -17
  19. returnn/frontend/encoder/conformer.py +1 -0
  20. returnn/frontend/encoder/transformer.py +2 -0
  21. returnn/frontend/loss.py +40 -1
  22. returnn/frontend/module.py +8 -1
  23. returnn/frontend/nested.py +9 -0
  24. returnn/native_op.cpp +80 -0
  25. returnn/sprint/cache.py +12 -13
  26. returnn/tensor/_dim_extra.py +51 -29
  27. returnn/tensor/_tensor_extra.py +6 -1
  28. returnn/tensor/utils.py +7 -4
  29. returnn/tf/frontend_layers/_backend.py +11 -2
  30. returnn/tf/frontend_low_level/_backend.py +15 -0
  31. returnn/tf/layers/basic.py +16 -38
  32. returnn/tf/native_op.py +11 -58
  33. returnn/tf/network.py +1 -1
  34. returnn/tf/util/basic.py +19 -0
  35. returnn/torch/data/returnn_dataset_wrapper.py +9 -3
  36. returnn/torch/engine.py +67 -2
  37. returnn/torch/frontend/_backend.py +119 -7
  38. returnn/torch/util/diagnose_gpu.py +65 -31
  39. returnn/torch/util/exception_helper.py +7 -1
  40. returnn/util/basic.py +6 -7
  41. returnn/util/better_exchook.py +4 -0
  42. returnn/util/collect_outputs_dict.py +79 -0
  43. returnn/util/debug.py +11 -2
  44. returnn/util/file_cache.py +42 -4
  45. returnn/util/task_system.py +1 -1
  46. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/METADATA +2 -2
  47. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/RECORD +50 -48
  48. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/LICENSE +0 -0
  49. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/WHEEL +0 -0
  50. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/top_level.txt +0 -0
returnn/tf/util/basic.py CHANGED
@@ -2784,6 +2784,10 @@ class CudaEnv:
             self.cuda_path = None
             if self.verbose_find_cuda:
                 print("CUDA disabled via env DISABLE_CUDA.")
+        elif os.environ.get("CUDA_VISIBLE_DEVICES", None) in ["", "-1"]:
+            self.cuda_path = None
+            if self.verbose_find_cuda:
+                print(f"CUDA disabled via env CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']!r}.")
         else:
             self.cuda_path = self._find_cuda_path()
             if self.verbose_find_cuda:
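
Note: with this change, CUDA_VISIBLE_DEVICES="" or CUDA_VISIBLE_DEVICES=-1 now short-circuits the CUDA path search, same as DISABLE_CUDA. A minimal standalone sketch of the combined check (the helper name is hypothetical, and the exact DISABLE_CUDA test is assumed from the context lines above):

    import os

    def cuda_disabled_via_env() -> bool:  # hypothetical helper, not in the package
        if os.environ.get("DISABLE_CUDA"):  # assumed pre-existing switch
            return True
        # the new check from this diff:
        return os.environ.get("CUDA_VISIBLE_DEVICES", None) in ["", "-1"]

    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    assert cuda_disabled_via_env()
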
@@ -3020,6 +3024,21 @@ class OpCodeCompiler(NativeCodeCompiler):
             ld_flags += tf.sysconfig.get_link_flags()
         elif have_min_tf_version((1, 4)):
             ld_flags += ["-L%s" % tf.sysconfig.get_lib(), "-ltensorflow_framework"]
+        if have_min_tf_version((2, 20)):
+            # TF 2.20 removed TF_MAJOR_VERSION and co from version.h,
+            # and one is supposed to define these macros externally.
+            # Also, release_version.h was added to define TF_VERSION_STRING based on this (if needed).
+            # https://github.com/tensorflow/tensorflow/commit/c8f0e0620e5678d0f165a07e64114024a966ab7f
+            major, minor, patch = tf.__version__.split(".", 2)
+            patch, suffix = patch.split("-", 1) if "-" in patch else (patch, "")
+            c_macro_defines.update(
+                {
+                    "TF_MAJOR_VERSION": major,
+                    "TF_MINOR_VERSION": minor,
+                    "TF_PATCH_VERSION": patch,
+                    "TF_VERSION_SUFFIX": suffix,
+                }
+            )
         use_cxx11_abi = getattr(getattr(tf, "sysconfig", tf), "CXX11_ABI_FLAG", getattr(tf, "CXX11_ABI_FLAG", False))
         super(OpCodeCompiler, self).__init__(
             include_paths=include_paths,
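
Note: for reference, this is what the added parsing produces for a few representative TF version strings (the sample versions are illustrative, not from the diff):

    for version in ["2.20.0", "2.20.0-rc1", "2.21.0-dev20250101"]:
        major, minor, patch = version.split(".", 2)
        patch, suffix = patch.split("-", 1) if "-" in patch else (patch, "")
        print(version, "->", {"TF_MAJOR_VERSION": major, "TF_MINOR_VERSION": minor,
                              "TF_PATCH_VERSION": patch, "TF_VERSION_SUFFIX": suffix})
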
returnn/torch/data/returnn_dataset_wrapper.py CHANGED
@@ -20,12 +20,18 @@ ResetCallbackT = Callable[[], None]
 class ReturnnDatasetResetDefaultEpochCounterCallback:
     """
     Default for reset_callback.
-    Has an internal counter for the epoch, starting at epoch 1 (RETURNN convention).
+    Has an internal counter for the epoch, starting by default at epoch 1 (RETURNN convention).
     """

-    def __init__(self, dataset: ReturnnDataset):
+    def __init__(self, dataset: ReturnnDataset, *, epoch0: int = 0):
+        """
+        :param dataset: RETURNN dataset.
+        :param epoch0: Epoch from which the dataset sequence ordering should start.
+            It will actually be epoch0+1 for the first epoch, since :func:`__call__` will increment it.
+            By default 0 since next :func:`__call__` will increment, thus we start at epoch 1.
+        """
         self.dataset = dataset
-        self.epoch = 0  # next __call__ will increment, thus we start at epoch 1
+        self.epoch = epoch0

     def __call__(self):
         # dataset is likely a copy of the original dataset, either in the main process or in a worker process
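
Note: epoch0 lets the counter continue from a given epoch instead of restarting at 1 (useful when resuming). A tiny stand-in illustrating just the counter behavior (not the real class, which also re-initializes the dataset's sequence order on each call):

    class _EpochCounterSketch:  # hypothetical stand-in
        def __init__(self, *, epoch0: int = 0):
            self.epoch = epoch0

        def __call__(self):
            self.epoch += 1  # first call after construction yields epoch0 + 1
            return self.epoch

    assert _EpochCounterSketch()() == 1  # default: RETURNN starts at epoch 1
    assert _EpochCounterSketch(epoch0=10)() == 11  # resume-style: continue at epoch 11
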
returnn/torch/engine.py CHANGED
@@ -134,6 +134,14 @@ class Engine(EngineBase):
         self._forward_auto_split_batch_on_oom = config.bool("forward_auto_split_batch_on_oom", False)
         self._stop_on_nonfinite_train_score = config.bool("stop_on_nonfinite_train_score", True)

+        if config.bool("use_tensorboard", False):
+            from torch.utils.tensorboard import SummaryWriter
+
+            self._tensorboard_writer = SummaryWriter()
+            self._tensorboard_opts = config.typed_value("tensorboard_opts", {})
+        else:
+            self._tensorboard_writer = None
+
         default_float_dtype = config.value("default_float_dtype", None)
         if default_float_dtype is not None:
             assert isinstance(default_float_dtype, str)
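
Note: enabling this from a RETURNN config should look roughly like the following; the key names come from this hunk, and as of this version only "log_every_n_train_steps" is read from tensorboard_opts:

    # in the RETURNN config file:
    use_tensorboard = True
    tensorboard_opts = {"log_every_n_train_steps": 50}
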
@@ -257,6 +265,9 @@
             self.init_train_epoch()
             self.train_epoch()

+        if self._tensorboard_writer:
+            self._tensorboard_writer.close()
+
         print(f"Finished training at epoch {self.epoch}, global train step {self.global_train_step}", file=log.v3)

     def init_train_epoch(self):
@@ -513,6 +524,18 @@
                 batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
                 log_memory_usage_device=self._device if self._log_memory_usage else None,
             )
+            if (
+                self._tensorboard_writer
+                and self.global_train_step % self._tensorboard_opts.get("log_every_n_train_steps", 100) == 0
+            ):
+                # write losses/errors to tensorboard
+                for key, val in eval_info.items():
+                    self._tensorboard_writer.add_scalar(f"train/{key}", val, global_step=self.global_train_step)
+                self._tensorboard_writer.add_scalar(
+                    "train/learning_rate",
+                    self._updater.get_effective_learning_rate(),
+                    global_step=self.global_train_step,
+                )

             if self._stop_on_nonfinite_train_score:
                 if any(np.isinf(v) or np.isnan(v) for v in accumulated_losses_dict.values()):
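
Note: the writer used here is the standard torch.utils.tensorboard API. A self-contained illustration of what the engine now emits per logged step (tag names mimic the ones above; the values are fake):

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter()  # writes event files to ./runs/<timestamp> by default
    for step in range(0, 300, 100):
        writer.add_scalar("train/ce", 1.0 / (step + 1), global_step=step)
        writer.add_scalar("train/learning_rate", 1e-3, global_step=step)
    writer.close()
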
@@ -702,12 +725,20 @@
                     start_elapsed=step_end_time - eval_start_time,
                     log_memory_usage_device=self._device if self._log_memory_usage else None,
                 )
+
                 step_idx += 1

             assert step_idx > 0, f"No data in dataset {dataset_name!r}."
             accumulated_losses_dict = accumulated_losses_dict / accumulated_inv_norm_factors_dict
             accumulated_losses_dict = self._maybe_extend_losses_info(accumulated_losses_dict)

+            if self._tensorboard_writer:
+                # write losses/errors to tensorboard
+                for key, val in accumulated_losses_dict.items():
+                    self._tensorboard_writer.add_scalar(
+                        f"{dataset_name}/{key}", val, global_step=self.global_train_step
+                    )
+
             self.learning_rate_control.set_epoch_error(
                 self.epoch, {f"{dataset_name}_loss_{k}": v for k, v in accumulated_losses_dict.items()}
             )
@@ -899,7 +930,7 @@
         if not os.path.exists(filename) and os.path.exists(model_epoch_filename):
             filename = model_epoch_filename
         print("Load model %s" % (filename,), file=log.v4)
-        checkpoint_state = torch.load(filename, map_location=self._device)
+        checkpoint_state = _torch_load(filename, device=self._device)
         if epoch is None:
             epoch = checkpoint_state.get("epoch", self._start_epoch or 1)
         step = checkpoint_state.get("step", 1)
@@ -999,7 +1030,7 @@
                 print("(No relevant parameters matching.)", file=log.v3)
                 continue
             print(f"Pre-load weights for key '{preload_key}' from {opts['filename']}", file=log.v3)
-            preload_model_state = torch.load(opts["filename"], map_location=self._device)
+            preload_model_state = _torch_load(opts["filename"], device=self._device)
             if opts.get("checkpoint_key", "model") is not None:
                 # This can be used if an external checkpoint saves a checkpoint a different structure that just the
                 # model state dict. E.g., if a checkpoint is created using
@@ -1032,6 +1063,28 @@
             preload_model_state_keys = set(preload_model_state.keys())
             loaded_state_keys.update(preload_model_state.keys())
             missing_keys.difference_update(preload_model_state.keys())
+
+            custom_missing_load_func = opts.get("custom_missing_load_func")
+            if custom_missing_load_func:
+                custom_missing_vars_map = {}
+                for var_name in missing_keys_preload:
+                    var_shape = self._pt_model.state_dict()[var_name].shape
+                    var_val = custom_missing_load_func(
+                        name=var_name,
+                        shape=var_shape,
+                        preload_model_state=preload_model_state,
+                        **util.get_fwd_compat_kwargs(),
+                    )
+                    if var_val is not None:
+                        assert var_val.shape == var_shape
+                        custom_missing_vars_map[var_name] = var_val
+                preload_model_state.update(custom_missing_vars_map)
+                missing_keys_preload, unexpected_keys_preload = self._pt_model.load_state_dict(
+                    preload_model_state, strict=False
+                )
+                loaded_state_keys.update(preload_model_state.keys())
+                missing_keys.difference_update(preload_model_state.keys())
+
             del preload_model_state
             gc.collect()

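Note: the callback receives the missing parameter's name and shape plus the loaded state dict, and returns either a tensor of that shape or None to skip. A hypothetical example (the opts structure is sketched from this hunk; `preload_from_files` is the pre-existing config mechanism these opts belong to):

    import torch

    def zero_init_missing(*, name, shape, preload_model_state, **_fwd_compat):
        # Zero-initialize any parameter the preloaded checkpoint does not cover.
        print(f"pre-load: {name!r} missing, zero-initializing shape {tuple(shape)}")
        return torch.zeros(shape)

    # roughly, in the config:
    # preload_from_files = {
    #     "base": {"filename": "base.pt", "custom_missing_load_func": zero_init_missing},
    # }
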
@@ -1669,3 +1722,15 @@ def _get_total_grad_norm(model: torch.nn.Module, p: float) -> float:
             p=p,
         ).item()
     )
+
+
+def _torch_load(filename: Union[str, os.PathLike], *, device: str) -> Dict[str, Any]:
+    # Might resolve PtCheckpoint or Sisyphus Path objects or so.
+    filename = os.fspath(filename)
+
+    if filename.endswith(".safetensors"):
+        from safetensors.torch import load_file as safetensors_load
+
+        return safetensors_load(filename, device=device)
+
+    return torch.load(filename, map_location=device)
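
Note: with this helper, both .pt and .safetensors checkpoints load transparently. A small round trip using the safetensors API the helper delegates to (the file path is illustrative):

    import torch
    from safetensors.torch import save_file, load_file

    model = torch.nn.Linear(4, 2)
    save_file(model.state_dict(), "/tmp/model.safetensors")
    state = load_file("/tmp/model.safetensors", device="cpu")
    model.load_state_dict(state)
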
returnn/torch/frontend/_backend.py CHANGED
@@ -1166,20 +1166,29 @@
         if start is None:
             start = 0
         if isinstance(size, Dim):
+            assert end is None
             size = size.get_dim_value()
         elif isinstance(size, Tensor):
+            assert end is None
             assert size.dims == ()  # scalar
             size = size.raw_tensor
-        if size is not None:
-            assert end is None
-            out.raw_tensor = torch.narrow(source.raw_tensor, dim=axis_int, start=start, length=size)
-        else:
+        elif isinstance(size, int):
+            pass
+        elif size is None:
             if isinstance(end, Tensor):
                 assert end.dims == ()
                 end = end.raw_tensor
-            if end is None:
+            elif isinstance(end, int):
+                if end < 0:
+                    end += axis.get_dim_value()
+            elif end is None:
                 end = axis.get_dim_value()
-            out.raw_tensor = torch.narrow(source.raw_tensor, dim=axis_int, start=start, length=end - start)
+            else:
+                raise TypeError(f"slice: unsupported type for end: {type(end)}")
+            size = end - start
+        else:
+            raise TypeError(f"slice: unsupported type for size: {type(size)}")
+        out.raw_tensor = torch.narrow(source.raw_tensor, dim=axis_int, start=start, length=size)
         return out

     @staticmethod
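
Note: every combination of start/end/size is now reduced to a single torch.narrow call, and a plain int end may be negative. The negative-end arithmetic in isolation:

    import torch

    x = torch.arange(10)
    start, end = 2, -3
    if end < 0:
        end += x.shape[0]  # -3 -> 7
    out = torch.narrow(x, dim=0, start=start, length=end - start)
    assert out.tolist() == [2, 3, 4, 5, 6]
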
@@ -1572,6 +1581,7 @@
             indices_out_raw = indices_raw % a.dimension
             indices_raw = indices_raw // a.dimension
             indices = values.copy_template(name=f"top_k_indices_{a.name or i}")
+            indices.feature_dim = None
             indices.dtype = TorchBackend.get_dtype_name_raw(indices_out_raw)
             indices.sparse_dim = a
             indices.raw_tensor = indices_out_raw
@@ -1588,6 +1598,7 @@
         values = source.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=k_dim, name="top_k_values")
         values.raw_tensor = values_raw
         indices = source.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=k_dim, name="top_k_indices")
+        indices.feature_dim = None
         indices.dtype = TorchBackend.get_dtype_name_raw(indices_raw)
         indices.sparse_dim = axis
         indices.raw_tensor = indices_raw
@@ -1639,6 +1650,8 @@
             name=f"random_{distribution}", dims=dims, dtype=dtype, sparse_dim=sparse_dim, feature_dim=feature_dim
         )
         out.raw_tensor = torch.empty(shape, dtype=dtype_, device=device or rf.get_default_device())
+        if out.raw_tensor.device.type == "meta":
+            return out  # nothing more to do
         assert explicit_state is None  # not implemented otherwise
         generator = None  # using the global default from PT
         assert isinstance(static, bool)
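
Note: the early return exists because "meta" tensors carry shape and dtype but no storage, so there is nothing to fill with random values. Quick illustration in plain PyTorch:

    import torch

    t = torch.empty(3, 4, device="meta")  # shape/dtype only, no backing storage
    assert t.device.type == "meta"
    print(t.shape, t.dtype)  # metadata is available, values are not
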
@@ -1787,6 +1800,7 @@
             dims=(out_dim,) + tuple(remaining_dims),
             dtype=tensor.dtype,
             sparse_dim=tensor.sparse_dim,
+            feature_dim=tensor.feature_dim,
             raw_tensor=out_raw,
         )
         return out, out_dim
@@ -1915,7 +1929,7 @@
         if not out_spatial_dims:
             out_spatial_dims = rf.make_conv_out_spatial_dims(
                 in_spatial_dims=in_spatial_dims,
-                filter_size=[d.dimension for d in filter_size],
+                filter_size=filter_size,
                 strides=strides or 1,
                 dilation_rate=dilation_rate or 1,
                 padding=padding,
@@ -2028,6 +2042,104 @@
         out.feature_dim = out_dim
         return out, out_spatial_dims

+    # noinspection PyShadowingBuiltins
+    @staticmethod
+    def transposed_conv(
+        source: Tensor,
+        *,
+        in_dim: Dim,
+        out_dim: Dim,
+        in_spatial_dims: Sequence[Dim],
+        out_spatial_dims: Optional[Sequence[Dim]] = None,
+        filter: Tensor,
+        filter_size: Sequence[Dim],
+        padding: str,
+        remove_padding: Union[Sequence[int], int] = 0,
+        output_padding: Optional[Union[Sequence[Optional[int]], int]] = None,
+        strides: Optional[Sequence[int]] = None,
+        bias: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Sequence[Dim]]:
+        """transposed convolution"""
+        if not out_spatial_dims:
+            out_spatial_dims = rf.make_transposed_conv_out_spatial_dims(
+                in_spatial_dims=in_spatial_dims,
+                filter_size=filter_size,
+                strides=strides,
+                padding=padding,
+                output_padding=output_padding,
+            )
+        assert remove_padding == 0  # not implemented yet otherwise...
+        if strides is None:
+            strides = [fs.dimension for fs in filter_size]
+        filter_dims = (in_dim, out_dim) + tuple(filter_size)
+        filter = filter.copy_transpose(filter_dims)
+        batch_dims = [d for d in source.dims if d not in (in_dim,) + tuple(in_spatial_dims)]
+        # Torch conv expects (N,C,<spatial dims>) as shape.
+        source = source.copy_transpose(batch_dims + [in_dim] + list(in_spatial_dims))
+        if len(batch_dims) == 1:
+            src_raw = source.raw_tensor
+        else:
+            src_raw = torch.reshape(
+                source.raw_tensor,
+                # potentially merge batch dims all together
+                [-1, in_dim.get_dim_value()] + [d.get_dim_value() for d in in_spatial_dims],
+            )
+        if padding == "same":
+            raise NotImplementedError("transposed_conv with padding='same' not implemented")
+        if padding == "valid":
+            padding_val = 0
+        else:
+            raise ValueError(f"invalid padding {padding!r}, expected 'same' or 'valid'")
+        if len(filter_size) == 1:
+            out_raw = torch.nn.functional.conv_transpose1d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        elif len(filter_size) == 2:
+            out_raw = torch.nn.functional.conv_transpose2d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        elif len(filter_size) == 3:
+            out_raw = torch.nn.functional.conv_transpose3d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        else:
+            raise ValueError(f"invalid number of filter dims {filter_size}, expected 1, 2, or 3")
+        if remove_padding:
+            if isinstance(remove_padding, int):
+                remove_padding = [remove_padding] * len(out_spatial_dims)
+            assert len(remove_padding) == len(out_spatial_dims)
+            slices = [slice(None)] * out_raw.ndim
+            for i, pad in enumerate(remove_padding):
+                if pad > 0:
+                    slices[2 + i] = slice(0, -pad)
+            out_raw = out_raw[tuple(slices)]
+        out = Tensor(
+            "transposed_conv",
+            dims=batch_dims + [out_dim] + list(out_spatial_dims),
+            dtype=TorchBackend.get_dtype_name_raw(out_raw),
+        )
+        if len(batch_dims) == 1:
+            out.raw_tensor = out_raw
+        else:
+            out.raw_tensor = torch.reshape(out_raw, [d.get_dim_value() for d in out.dims])
+        out.feature_dim = out_dim
+        return out, out_spatial_dims
+
     @staticmethod
     def pool(
         source: Tensor,
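
Note: the implementation maps the RF dims onto torch.nn.functional.conv_transpose{1,2,3}d. The 'valid' output-length arithmetic it relies on, in raw PyTorch (shapes are illustrative):

    import torch

    x = torch.randn(1, 4, 10)  # (batch, in_channels, time)
    w = torch.randn(4, 8, 3)   # (in_channels, out_channels, filter_size)
    y = torch.nn.functional.conv_transpose1d(x, w, stride=2)
    # out_len = (in_len - 1) * stride + filter_size  (padding=0, output_padding=0)
    assert y.shape == (1, 8, (10 - 1) * 2 + 3)
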
returnn/torch/util/diagnose_gpu.py CHANGED
@@ -8,6 +8,10 @@ import os
 import sys
 import gc
 import subprocess
+import signal
+import time
+import contextlib
+import multiprocessing
 import torch
 from returnn.util.better_exchook import better_exchook
 from returnn.util.basic import human_bytes_size
@@ -26,36 +30,39 @@ def print_available_devices(*, file: Optional[TextIO] = None):
         print("CUDA_VISIBLE_DEVICES is set to %r." % os.environ["CUDA_VISIBLE_DEVICES"], file=file)
         cuda_visible_devs = dict(enumerate([int(d) for d in os.environ["CUDA_VISIBLE_DEVICES"].split(",") if d]))
     else:
-        if torch.cuda.is_available():
-            print("CUDA_VISIBLE_DEVICES is not set.", file=file)
-
-    if torch.cuda.is_available():
-        print("Available CUDA devices:")
-        count = torch.cuda.device_count()
-        if cuda_visible_devs is not None and len(cuda_visible_devs) != count:
-            print(
-                f"(Mismatch between CUDA device count {count}"
-                f" and CUDA_VISIBLE_DEVICES {cuda_visible_devs} count {len(cuda_visible_devs)}?)",
-                file=file,
-            )
-        for i in range(count):
-            print(f" {i + 1}/{count}: cuda:{i}", file=file)
-            props = torch.cuda.get_device_properties(i)
-            print(f" name: {props.name}", file=file)
-            print(f" total_memory: {human_bytes_size(props.total_memory)}", file=file)
-            print(f" capability: {props.major}.{props.minor}", file=file)
-            if cuda_visible_devs is not None:
-                if len(cuda_visible_devs) == count:
-                    dev_idx_s = cuda_visible_devs[i]
-                else:
-                    dev_idx_s = "?"
+        with timeout("torch.cuda.is_available()"):
+            if torch.cuda.is_available():
+                print("CUDA_VISIBLE_DEVICES is not set.", file=file)
+
+    with timeout("torch.cuda.is_available()"):
+        if not torch.cuda.is_available():
+            print("(CUDA not available)", file=file)
+            return
+
+    print("Available CUDA devices:", file=file)
+    count = torch.cuda.device_count()
+    if cuda_visible_devs is not None and len(cuda_visible_devs) != count:
+        print(
+            f"(Mismatch between CUDA device count {count}"
+            f" and CUDA_VISIBLE_DEVICES {cuda_visible_devs} count {len(cuda_visible_devs)}?)",
+            file=file,
+        )
+    for i in range(count):
+        print(f" {i + 1}/{count}: cuda:{i}", file=file)
+        props = torch.cuda.get_device_properties(i)
+        print(f" name: {props.name}", file=file)
+        print(f" total_memory: {human_bytes_size(props.total_memory)}", file=file)
+        print(f" capability: {props.major}.{props.minor}", file=file)
+        if cuda_visible_devs is not None:
+            if len(cuda_visible_devs) == count:
+                dev_idx_s = cuda_visible_devs[i]
             else:
-                dev_idx_s = i
-            print(f" device_index: {dev_idx_s}", file=file)
-        if not count:
-            print(" (None)")
-    else:
-        print("(CUDA not available)")
+                dev_idx_s = "?"
+        else:
+            dev_idx_s = i
+        print(f" device_index: {dev_idx_s}", file=file)
+    if not count:
+        print(" (None)", file=file)


 def print_using_cuda_device_report(dev: Union[str, torch.device], *, file: Optional[TextIO] = None):
@@ -108,7 +115,7 @@ def diagnose_no_gpu() -> List[str]:
     except Exception as exc:
         print("nvidia-smi failed:", exc)
         better_exchook(*sys.exc_info(), debugshell=False)
-        res.append(f"nvidia-smi failed")
+        res.append("nvidia-smi failed")

     return res

@@ -152,4 +159,31 @@ def garbage_collect():
             f"alloc {human_bytes_size(torch.cuda.memory_allocated())}",
             f"reserved {human_bytes_size(torch.cuda.memory_reserved())}",
         ]
-        print(f"CUDA memory usage after triggered GC:", " ".join(stats))
+        print("CUDA memory usage after triggered GC:", " ".join(stats))
+
+
+@contextlib.contextmanager
+def timeout(info: str, *, seconds: int = 30):
+    """
+    Note: don't use signal handlers (e.g. signal.alarm) because unfortunately
+    potential hanging funcs will block the main thread and thus block the signal handler from executing.
+    Thus, we use a subprocess.
+
+    :param seconds:
+    :param info:
+    """
+    proc = multiprocessing.Process(
+        target=_timeout_handler, kwargs={"seconds": seconds, "proc_id": os.getpid(), "info": info}
+    )
+    proc.start()
+    try:
+        yield
+    finally:
+        proc.terminate()
+        proc.join()
+
+
+def _timeout_handler(*, seconds: Union[float, int], proc_id: int, info: str):
+    time.sleep(seconds)
+    print(f"ERROR: {info}: Timeout handler after {seconds} seconds, killing proc {proc_id}.", file=sys.stderr)
+    os.kill(proc_id, signal.SIGABRT)
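
Note: usage sketch for the new watchdog. The wrapped block runs in the main process while a helper process sleeps; if the block does not finish (and cancel the helper) within `seconds`, the helper aborts the hung process:

    from returnn.torch.util.diagnose_gpu import timeout

    with timeout("torch.cuda.is_available()", seconds=10):
        import torch

        torch.cuda.is_available()  # if this hangs, the watchdog sends SIGABRT after 10s
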
returnn/torch/util/exception_helper.py CHANGED
@@ -71,7 +71,13 @@ def help_on_torch_exception(
     if not count_frames:
         exc_ext.append("(No module call frames.)")

-    if len(exc.args) == 1 and isinstance(exc.args[0], str) and not always_direct_print:
+    if (
+        # KeyError formatting would be wrong, showing `KeyError: "enc_spatial_dim\n\nStep idx: 0\..."`
+        not isinstance(exc, KeyError)
+        and len(exc.args) == 1
+        and isinstance(exc.args[0], str)
+        and not always_direct_print
+    ):
         exc.args = ("\n".join([exc.args[0], ""] + exc_ext),)
     else:
         for msg in exc_ext:
returnn/util/basic.py CHANGED
@@ -365,12 +365,9 @@ def get_checkpoint_filepattern(filepath):
     :return: CheckpointLoader compatible filepattern
     :rtype: str
     """
-    if filepath.endswith(".meta"):
-        return filepath[: -len(".meta")]
-    elif filepath.endswith(".index"):
-        return filepath[: -len(".index")]
-    elif filepath.endswith(".pt"):
-        return filepath[: -len(".pt")]
+    for ext in [".meta", ".index", ".pt"]:
+        if filepath.endswith(ext):
+            return filepath[: -len(ext)]
     return filepath


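Note: a behavioral check for this refactor (pure cleanup, no semantic change; the checkpoint names are illustrative):

    from returnn.util.basic import get_checkpoint_filepattern

    assert get_checkpoint_filepattern("model.020.pt") == "model.020"
    assert get_checkpoint_filepattern("model.020.index") == "model.020"
    assert get_checkpoint_filepattern("model.020") == "model.020"
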
@@ -557,7 +554,9 @@ def get_tensorflow_version_tuple() -> Tuple[int, ...]:
     import tensorflow as tf  # noqa
     import re

-    return tuple([int(re.sub("(-rc[0-9]|-dev[0-9]*)", "", s)) for s in tf.__version__.split(".")])
+    # Remove unwanted suffixes from the TF version string (e.g. "2.20.0-dev0+selfbuilt")
+    filtered_version = [re.sub("(-rc[0-9]|-dev[0-9]*)(\\+selfbuilt)?", "", s) for s in tf.__version__.split(".")]
+    return tuple(int(v) for v in filtered_version)


 class ReportImportedDevModules:
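
Note: the widened regex applied to the example version from the comment above (self-contained, no TF import needed):

    import re

    parts = "2.20.0-dev0+selfbuilt".split(".")
    cleaned = [re.sub("(-rc[0-9]|-dev[0-9]*)(\\+selfbuilt)?", "", s) for s in parts]
    assert tuple(int(v) for v in cleaned) == (2, 20, 0)
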
@@ -1093,6 +1093,7 @@ def format_tb(
     with_color=None,
     with_vars=None,
     clear_frames=True,
+    colorize=None,
 ):
     """
     Formats a traceback into a list of strings, each corresponding to one frame.
@@ -1110,11 +1111,14 @@
         That will potentially fix some mem leaks regarding locals, so it can be important.
         Also see https://github.com/python/cpython/issues/113939.
         However, any further access to frame locals will not work (e.g., if you want to use a debugger afterward).
+    :param colorize: for compat with Python >=3.13, currently ignored
     :return: list of strings, each corresponding to one frame in the traceback.
         Each string contains the file name, line number, function name, source code line, maybe relevant variables,
         etc., and a final newline.
    :rtype: list[str]
     """
+    if colorize is not None and with_color is None:
+        with_color = colorize
     color = Color(enable=with_color)
     output = _OutputLinesCollector(color=color)

returnn/util/collect_outputs_dict.py ADDED
@@ -0,0 +1,79 @@
+"""
+Customized (derived) dict to pass as ``collected_outputs`` to some of the RF modules,
+or potential other use cases.
+
+You can predefine (by pattern) what kind of outputs you want to collect and store in this dict.
+"""
+
+from typing import Optional, Union, Sequence
+import fnmatch
+
+
+class CollectOutputsDict(dict):
+    """
+    Customized (derived) dict, where you can predefine (by key pattern)
+    what kind of keys you want to collect and store in this dict.
+    Other keys will be ignored.
+    """
+
+    def __init__(self, *args, allowed_key_patterns: Optional[Sequence[str]] = None, **kwargs):
+        """
+        Initialize the CollectOutputsDict.
+
+        :param allowed_key_patterns:
+            List of key patterns (with wildcards) that are allowed to be stored in the dict.
+            If None, all keys are allowed.
+        """
+        super().__init__(*args, **kwargs)
+        self.allowed_key_patterns = allowed_key_patterns
+
+    def __setitem__(self, key, value):
+        """
+        Set an item in the dict if the key matches allowed patterns.
+        """
+        if self.is_key_allowed(key):
+            super().__setitem__(key, value)
+
+    def setdefault(self, key, default=None):
+        """
+        Set default value for a key if it matches allowed patterns.
+        """
+        if self.is_key_allowed(key):
+            return super().setdefault(key, default)
+        return None
+
+    def update(self, mapping, **kwargs):
+        """
+        Update the dict with another mapping, only adding allowed keys.
+        """
+        assert not kwargs
+        for key, value in mapping.items():
+            if self.is_key_allowed(key):
+                super().__setitem__(key, value)
+
+    def is_key_allowed(self, key: str) -> bool:
+        """
+        Check if the key matches any of the allowed patterns.
+
+        :param key:
+        :return: True if the key is allowed, False otherwise.
+        """
+        if self.allowed_key_patterns is None:
+            return True  # If no patterns defined, allow all keys
+        for pattern in self.allowed_key_patterns:
+            if fnmatch.fnmatch(key, pattern):
+                return True
+        return False
+
+
+def is_key_allowed_in_collect_outputs_dict(collect_outputs: Union[CollectOutputsDict, dict], key: str) -> bool:
+    """
+    Check if a key is allowed in the given CollectOutputsDict.
+
+    :param collect_outputs:
+    :param key:
+    :return: True if the key is allowed, False otherwise.
+    """
+    if isinstance(collect_outputs, CollectOutputsDict):
+        return collect_outputs.is_key_allowed(key)
+    return True  # If it's a regular dict, all keys are allowed
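
Note: a usage sketch of the new helper (patterns use fnmatch-style wildcards; key names are illustrative):

    from returnn.util.collect_outputs_dict import CollectOutputsDict

    outs = CollectOutputsDict(allowed_key_patterns=["enc_layer_*"])
    outs["enc_layer_0"] = "kept"
    outs["dec_state"] = "silently dropped"
    assert "enc_layer_0" in outs and "dec_state" not in outs
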
returnn/util/debug.py CHANGED
@@ -704,7 +704,7 @@ def check_py_traces_rf_to_pt_equal(
     """
     import random
     import torch
-    from returnn.tensor import Tensor, Dim
+    from returnn.tensor import Dim
     import returnn.frontend as rf

     # noinspection PyProtectedMember
@@ -715,9 +715,18 @@
     def _get_entry(trace, func, i, name, j):
         return trace[func][i][name][j]

+    def _get_entry_attr(trace, func, i, name, j):
+        name, attr = name.split(".", 1)
+        obj = trace[func][i][name][j]
+        return eval(f"{name}.{attr}", {name: obj})
+
     def _resolve_dim(dim: Union[Dim, str]) -> Dim:
         if isinstance(dim, Dim):
             return dim
+        elif isinstance(dim, str) and "." in dim:
+            dim = _get_entry_attr(trace_rf, *check_rf[:2], dim, -1)
+            assert isinstance(dim, Dim)
+            return dim
         elif isinstance(dim, str):
             dim = _get_entry(trace_rf, *check_rf[:2], dim, -1)
             assert isinstance(dim, Dim)
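
Note: _get_entry_attr resolves dotted names like "x.feature_dim" by looking up the traced variable "x" and then evaluating the attribute path on it. The eval trick in isolation, with a hypothetical stand-in object:

    class _Traced:
        feature_dim = 42  # stand-in attribute

    name, attr = "x.feature_dim".split(".", 1)
    obj = _Traced()
    assert eval(f"{name}.{attr}", {name: obj}) == 42
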
@@ -763,7 +772,7 @@
             if len(indices) > 5:
                 msgs.append(" non-matching ...")
             non_matching.append("\n".join(msgs_prefix + msgs))
-            print(f" mismatch!")
+            print(" mismatch!")
             for msg in msgs:
                 print(msg)