returnn-1.20251013.113026-py3-none-any.whl → returnn-1.20260109.93428-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (43)
  1. returnn/PKG-INFO +2 -2
  2. returnn/_setup_info_generated.py +2 -2
  3. returnn/config.py +1 -1
  4. returnn/datasets/distrib_files.py +53 -1
  5. returnn/datasets/generating.py +3 -5
  6. returnn/datasets/lm.py +20 -0
  7. returnn/datasets/meta.py +179 -60
  8. returnn/datasets/postprocessing.py +597 -108
  9. returnn/datasets/util/vocabulary.py +90 -0
  10. returnn/frontend/array_.py +46 -0
  11. returnn/frontend/attention.py +54 -20
  12. returnn/frontend/conv.py +273 -54
  13. returnn/frontend/device.py +14 -1
  14. returnn/frontend/encoder/conformer.py +20 -0
  15. returnn/frontend/encoder/transformer.py +2 -0
  16. returnn/frontend/loss.py +40 -1
  17. returnn/frontend/math_.py +54 -14
  18. returnn/frontend/module.py +8 -1
  19. returnn/frontend/nested.py +5 -0
  20. returnn/native_op.cpp +80 -0
  21. returnn/sprint/cache.py +12 -13
  22. returnn/tensor/_dim_extra.py +39 -24
  23. returnn/tensor/utils.py +7 -4
  24. returnn/tf/frontend_layers/_backend.py +4 -3
  25. returnn/tf/layers/basic.py +15 -39
  26. returnn/tf/native_op.py +11 -58
  27. returnn/tf/network.py +1 -1
  28. returnn/tf/util/basic.py +19 -0
  29. returnn/torch/engine.py +67 -2
  30. returnn/torch/frontend/_backend.py +135 -13
  31. returnn/torch/frontend/bridge.py +61 -0
  32. returnn/torch/util/exception_helper.py +7 -1
  33. returnn/util/basic.py +6 -7
  34. returnn/util/better_exchook.py +4 -0
  35. returnn/util/collect_outputs_dict.py +79 -0
  36. returnn/util/debug.py +11 -2
  37. returnn/util/file_cache.py +15 -1
  38. returnn/util/task_system.py +1 -1
  39. {returnn-1.20251013.113026.dist-info → returnn-1.20260109.93428.dist-info}/METADATA +2 -2
  40. {returnn-1.20251013.113026.dist-info → returnn-1.20260109.93428.dist-info}/RECORD +43 -42
  41. {returnn-1.20251013.113026.dist-info → returnn-1.20260109.93428.dist-info}/LICENSE +0 -0
  42. {returnn-1.20251013.113026.dist-info → returnn-1.20260109.93428.dist-info}/WHEEL +0 -0
  43. {returnn-1.20251013.113026.dist-info → returnn-1.20260109.93428.dist-info}/top_level.txt +0 -0
returnn/torch/engine.py CHANGED
@@ -134,6 +134,14 @@ class Engine(EngineBase):
         self._forward_auto_split_batch_on_oom = config.bool("forward_auto_split_batch_on_oom", False)
         self._stop_on_nonfinite_train_score = config.bool("stop_on_nonfinite_train_score", True)
 
+        if config.bool("use_tensorboard", False):
+            from torch.utils.tensorboard import SummaryWriter
+
+            self._tensorboard_writer = SummaryWriter()
+            self._tensorboard_opts = config.typed_value("tensorboard_opts", {})
+        else:
+            self._tensorboard_writer = None
+
         default_float_dtype = config.value("default_float_dtype", None)
         if default_float_dtype is not None:
             assert isinstance(default_float_dtype, str)
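
Note: the new TensorBoard support is driven entirely by the two config options read above. A minimal RETURNN config sketch (only `use_tensorboard` and `tensorboard_opts` are consulted in this diff; the SummaryWriter is created with PyTorch defaults, so events go to ./runs/):

    use_tensorboard = True
    tensorboard_opts = {"log_every_n_train_steps": 100}  # 100 is also the default used below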
@@ -257,6 +265,9 @@ class Engine(EngineBase):
             self.init_train_epoch()
             self.train_epoch()
 
+        if self._tensorboard_writer:
+            self._tensorboard_writer.close()
+
         print(f"Finished training at epoch {self.epoch}, global train step {self.global_train_step}", file=log.v3)
 
     def init_train_epoch(self):
@@ -513,6 +524,18 @@ class Engine(EngineBase):
                 batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
                 log_memory_usage_device=self._device if self._log_memory_usage else None,
             )
+            if (
+                self._tensorboard_writer
+                and self.global_train_step % self._tensorboard_opts.get("log_every_n_train_steps", 100) == 0
+            ):
+                # write losses/errors to tensorboard
+                for key, val in eval_info.items():
+                    self._tensorboard_writer.add_scalar(f"train/{key}", val, global_step=self.global_train_step)
+                self._tensorboard_writer.add_scalar(
+                    "train/learning_rate",
+                    self._updater.get_effective_learning_rate(),
+                    global_step=self.global_train_step,
+                )
 
             if self._stop_on_nonfinite_train_score:
                 if any(np.isinf(v) or np.isnan(v) for v in accumulated_losses_dict.values()):
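
The training scalars written here land under tags train/<loss-or-error-key> plus train/learning_rate. Assuming the default ./runs event dir, they can be inspected with the standard TensorBoard CLI:

    tensorboard --logdir runs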
@@ -702,12 +725,20 @@ class Engine(EngineBase):
                         start_elapsed=step_end_time - eval_start_time,
                         log_memory_usage_device=self._device if self._log_memory_usage else None,
                     )
+
                     step_idx += 1
 
             assert step_idx > 0, f"No data in dataset {dataset_name!r}."
             accumulated_losses_dict = accumulated_losses_dict / accumulated_inv_norm_factors_dict
             accumulated_losses_dict = self._maybe_extend_losses_info(accumulated_losses_dict)
 
+            if self._tensorboard_writer:
+                # write losses/errors to tensorboard
+                for key, val in accumulated_losses_dict.items():
+                    self._tensorboard_writer.add_scalar(
+                        f"{dataset_name}/{key}", val, global_step=self.global_train_step
+                    )
+
             self.learning_rate_control.set_epoch_error(
                 self.epoch, {f"{dataset_name}_loss_{k}": v for k, v in accumulated_losses_dict.items()}
             )
@@ -899,7 +930,7 @@ class Engine(EngineBase):
         if not os.path.exists(filename) and os.path.exists(model_epoch_filename):
             filename = model_epoch_filename
         print("Load model %s" % (filename,), file=log.v4)
-        checkpoint_state = torch.load(filename, map_location=self._device)
+        checkpoint_state = _torch_load(filename, device=self._device)
         if epoch is None:
             epoch = checkpoint_state.get("epoch", self._start_epoch or 1)
         step = checkpoint_state.get("step", 1)
@@ -999,7 +1030,7 @@ class Engine(EngineBase):
                     print("(No relevant parameters matching.)", file=log.v3)
                     continue
                 print(f"Pre-load weights for key '{preload_key}' from {opts['filename']}", file=log.v3)
-                preload_model_state = torch.load(opts["filename"], map_location=self._device)
+                preload_model_state = _torch_load(opts["filename"], device=self._device)
                 if opts.get("checkpoint_key", "model") is not None:
                     # This can be used if an external checkpoint saves a checkpoint a different structure that just the
                     # model state dict. E.g., if a checkpoint is created using
@@ -1032,6 +1063,28 @@ class Engine(EngineBase):
                 preload_model_state_keys = set(preload_model_state.keys())
                 loaded_state_keys.update(preload_model_state.keys())
                 missing_keys.difference_update(preload_model_state.keys())
+
+                custom_missing_load_func = opts.get("custom_missing_load_func")
+                if custom_missing_load_func:
+                    custom_missing_vars_map = {}
+                    for var_name in missing_keys_preload:
+                        var_shape = self._pt_model.state_dict()[var_name].shape
+                        var_val = custom_missing_load_func(
+                            name=var_name,
+                            shape=var_shape,
+                            preload_model_state=preload_model_state,
+                            **util.get_fwd_compat_kwargs(),
+                        )
+                        if var_val is not None:
+                            assert var_val.shape == var_shape
+                            custom_missing_vars_map[var_name] = var_val
+                    preload_model_state.update(custom_missing_vars_map)
+                    missing_keys_preload, unexpected_keys_preload = self._pt_model.load_state_dict(
+                        preload_model_state, strict=False
+                    )
+                    loaded_state_keys.update(preload_model_state.keys())
+                    missing_keys.difference_update(preload_model_state.keys())
+
                 del preload_model_state
                 gc.collect()
 
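The new custom_missing_load_func hook is called once per parameter still missing after the regular preload. Judging only from the call site above, a config-side callback could look like this sketch (the option name preload_from_files is assumed to be the existing RETURNN config option these opts dicts come from; the zero-init policy and names are illustrative only; note the **kwargs catch-all for the forward-compat kwargs passed via util.get_fwd_compat_kwargs()):

    import torch

    def missing_load_func(*, name: str, shape, preload_model_state: dict, **_fwd_compat):
        """Return a value for a missing param, or None to leave it missing."""
        if name.endswith(".bias"):
            return torch.zeros(shape)  # example policy: zero-init missing biases
        return None

    preload_from_files = {
        "base": {
            "filename": "/path/to/base-checkpoint.pt",
            "custom_missing_load_func": missing_load_func,
        },
    }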
@@ -1669,3 +1722,15 @@ def _get_total_grad_norm(model: torch.nn.Module, p: float) -> float:
             p=p,
         ).item()
     )
+
+
+def _torch_load(filename: Union[str, os.PathLike], *, device: str) -> Dict[str, Any]:
+    # Might resolve PtCheckpoint or Sisyphus Path objects or so.
+    filename = os.fspath(filename)
+
+    if filename.endswith(".safetensors"):
+        from safetensors.torch import load_file as safetensors_load
+
+        return safetensors_load(filename, device=device)
+
+    return torch.load(filename, map_location=device)
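
With _torch_load in place, any checkpoint path ending in .safetensors is loaded via safetensors instead of torch.load. A safetensors file is a flat name-to-tensor dict, so metadata lookups like checkpoint_state.get("epoch", ...) above will fall back to their defaults. Producing such a file uses the standard safetensors API:

    import torch
    from safetensors.torch import save_file

    model = torch.nn.Linear(4, 4)
    save_file(model.state_dict(), "model.safetensors")  # tensors only, no epoch/step metadata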
returnn/torch/frontend/_backend.py CHANGED
@@ -1166,20 +1166,29 @@ class TorchBackend(Backend[torch.Tensor]):
         if start is None:
             start = 0
         if isinstance(size, Dim):
+            assert end is None
             size = size.get_dim_value()
         elif isinstance(size, Tensor):
+            assert end is None
             assert size.dims == ()  # scalar
             size = size.raw_tensor
-        if size is not None:
-            assert end is None
-            out.raw_tensor = torch.narrow(source.raw_tensor, dim=axis_int, start=start, length=size)
-        else:
+        elif isinstance(size, int):
+            pass
+        elif size is None:
             if isinstance(end, Tensor):
                 assert end.dims == ()
                 end = end.raw_tensor
-            if end is None:
+            elif isinstance(end, int):
+                if end < 0:
+                    end += axis.get_dim_value()
+            elif end is None:
                 end = axis.get_dim_value()
-            out.raw_tensor = torch.narrow(source.raw_tensor, dim=axis_int, start=start, length=end - start)
+            else:
+                raise TypeError(f"slice: unsupported type for end: {type(end)}")
+            size = end - start
+        else:
+            raise TypeError(f"slice: unsupported type for size: {type(size)}")
+        out.raw_tensor = torch.narrow(source.raw_tensor, dim=axis_int, start=start, length=size)
         return out
 
     @staticmethod
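
The rewrite folds every start/end/size combination into a single torch.narrow call; newly supported is an int end, including negative values resolved against the dim size. A standalone illustration of that semantics in plain torch:

    import torch

    x = torch.arange(10)
    start, end = 2, -1
    if end < 0:
        end += x.shape[0]  # -1 -> 9, as in the new int-end branch above
    print(torch.narrow(x, dim=0, start=start, length=end - start))
    # tensor([2, 3, 4, 5, 6, 7, 8])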
@@ -1352,12 +1361,24 @@ class TorchBackend(Backend[torch.Tensor]):
         a_dims = a.dims
         b_dims = b.dims
 
-        assert all(dim in a_dims for dim in reduce), (
-            f"'a' does not have the specified reduce dim(s) {reduce} (a dims: {a_dims})"
-        )
-        assert all(dim in b_dims for dim in reduce), (
-            f"'b' does not have the specified reduce dim(s) {reduce} (b dims: {b_dims})"
-        )
+        if not all(dim in a_dims for dim in reduce) or not all(dim in b_dims for dim in reduce):
+            # revert to the generic einsum implementation
+            assert all(dim in a_dims + b_dims for dim in reduce), "Some reduce Dims not in a or b."
+            result_dims = [dim for dim in a_dims if dim not in reduce] + [
+                dim for dim in b_dims if dim not in reduce and dim not in a_dims
+            ]
+            map_to_letter = {}
+            for dim in a_dims + b_dims:
+                if dim not in map_to_letter:
+                    map_to_letter[dim] = chr(97 + len(map_to_letter))  # 'a', 'b', 'c', ...
+            a_subscript = "".join(map_to_letter[dim] for dim in a_dims)
+            b_subscript = "".join(map_to_letter[dim] for dim in b_dims)
+            out_subscript = "".join(map_to_letter[dim] for dim in result_dims)
+            raw_result = torch.einsum(f"{a_subscript},{b_subscript}->{out_subscript}", a.raw_tensor, b.raw_tensor)
+            result_tensor = Tensor(
+                "einsum", dims=result_dims, raw_tensor=raw_result, dtype=TorchBackend.get_dtype_name_raw(raw_result)
+            )
+            return result_tensor
 
         if len(reduce) > 1:
             reduce = list(reduce)
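
Where matmul previously asserted that every reduce dim occurs in both operands, dims occurring in only one operand now fall back to a generic einsum, which simply sums them out on that side. A standalone sketch of the subscript construction:

    import torch

    # a has dims (B, T, F) -> subscripts "abc"; b has dims (F, K) -> "cd".
    # reduce = [T, F]: T appears only in a, F in both; result dims (B, K) -> "ad".
    a = torch.randn(2, 3, 4)
    b = torch.randn(4, 5)
    out = torch.einsum("abc,cd->ad", a, b)  # T summed out of a, F contracted
    print(out.shape)  # torch.Size([2, 5])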
@@ -1767,6 +1788,9 @@ class TorchBackend(Backend[torch.Tensor]):
         remaining_dims = [d for d in tensor.dims if d not in mask.dims]
         tensor_templ_dims = tuple(dims) + tuple(remaining_dims)
         in_raw = tensor.copy_compatible_to_dims_raw(tensor_templ_dims)
+        if any(in_raw.shape[i] == 1 < d.get_dim_value() for i, d in enumerate(dims)):
+            # unbroadcast
+            in_raw = in_raw.expand([d.get_dim_value() for d in tensor_templ_dims])
         if mask.raw_tensor.device.type == "meta":
             # This is not supported, but also, we would anyway not know the out shape.
             # However, instead of erroring, just assume some dummy mask.
@@ -1920,7 +1944,7 @@ class TorchBackend(Backend[torch.Tensor]):
         if not out_spatial_dims:
             out_spatial_dims = rf.make_conv_out_spatial_dims(
                 in_spatial_dims=in_spatial_dims,
-                filter_size=[d.dimension for d in filter_size],
+                filter_size=filter_size,
                 strides=strides or 1,
                 dilation_rate=dilation_rate or 1,
                 padding=padding,
@@ -2033,6 +2057,104 @@ class TorchBackend(Backend[torch.Tensor]):
         out.feature_dim = out_dim
         return out, out_spatial_dims
 
+    # noinspection PyShadowingBuiltins
+    @staticmethod
+    def transposed_conv(
+        source: Tensor,
+        *,
+        in_dim: Dim,
+        out_dim: Dim,
+        in_spatial_dims: Sequence[Dim],
+        out_spatial_dims: Optional[Sequence[Dim]] = None,
+        filter: Tensor,
+        filter_size: Sequence[Dim],
+        padding: str,
+        remove_padding: Union[Sequence[int], int] = 0,
+        output_padding: Optional[Union[Sequence[Optional[int]], int]] = None,
+        strides: Optional[Sequence[int]] = None,
+        bias: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Sequence[Dim]]:
+        """transposed convolution"""
+        if not out_spatial_dims:
+            out_spatial_dims = rf.make_transposed_conv_out_spatial_dims(
+                in_spatial_dims=in_spatial_dims,
+                filter_size=filter_size,
+                strides=strides,
+                padding=padding,
+                output_padding=output_padding,
+            )
+        assert remove_padding == 0  # not implemented yet otherwise...
+        if strides is None:
+            strides = [fs.dimension for fs in filter_size]
+        filter_dims = (in_dim, out_dim) + tuple(filter_size)
+        filter = filter.copy_transpose(filter_dims)
+        batch_dims = [d for d in source.dims if d not in (in_dim,) + tuple(in_spatial_dims)]
+        # Torch conv expects (N,C,<spatial dims>) as shape.
+        source = source.copy_transpose(batch_dims + [in_dim] + list(in_spatial_dims))
+        if len(batch_dims) == 1:
+            src_raw = source.raw_tensor
+        else:
+            src_raw = torch.reshape(
+                source.raw_tensor,
+                # potentially merge batch dims all together
+                [-1, in_dim.get_dim_value()] + [d.get_dim_value() for d in in_spatial_dims],
+            )
+        if padding == "same":
+            raise NotImplementedError("transposed_conv with padding='same' not implemented")
+        if padding == "valid":
+            padding_val = 0
+        else:
+            raise ValueError(f"invalid padding {padding!r}, expected 'same' or 'valid'")
+        if len(filter_size) == 1:
+            out_raw = torch.nn.functional.conv_transpose1d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        elif len(filter_size) == 2:
+            out_raw = torch.nn.functional.conv_transpose2d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        elif len(filter_size) == 3:
+            out_raw = torch.nn.functional.conv_transpose3d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        else:
+            raise ValueError(f"invalid number of filter dims {filter_size}, expected 1, 2, or 3")
+        if remove_padding:
+            if isinstance(remove_padding, int):
+                remove_padding = [remove_padding] * len(out_spatial_dims)
+            assert len(remove_padding) == len(out_spatial_dims)
+            slices = [slice(None)] * out_raw.ndim
+            for i, pad in enumerate(remove_padding):
+                if pad > 0:
+                    slices[2 + i] = slice(0, -pad)
+            out_raw = out_raw[tuple(slices)]
+        out = Tensor(
+            "transposed_conv",
+            dims=batch_dims + [out_dim] + list(out_spatial_dims),
+            dtype=TorchBackend.get_dtype_name_raw(out_raw),
+        )
+        if len(batch_dims) == 1:
+            out.raw_tensor = out_raw
+        else:
+            out.raw_tensor = torch.reshape(out_raw, [d.get_dim_value() for d in out.dims])
+        out.feature_dim = out_dim
+        return out, out_spatial_dims
+
     @staticmethod
     def pool(
         source: Tensor,
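
Since padding='valid' maps to padding_val=0, the output length per spatial dim follows PyTorch's transposed-conv formula, L_out = (L_in - 1) * stride + dilation * (kernel - 1) + output_padding + 1 (dilation is 1 here). A quick check with the same functional call and the (in, out, *spatial) filter layout the code transposes to:

    import torch

    x = torch.randn(1, 4, 10)  # (N, C_in, T)
    w = torch.randn(4, 8, 3)   # (C_in, C_out, kernel)
    y = torch.nn.functional.conv_transpose1d(x, w, stride=2, padding=0)
    print(y.shape)  # torch.Size([1, 8, 21]): (10 - 1) * 2 + (3 - 1) + 1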
returnn/torch/frontend/bridge.py CHANGED
@@ -136,6 +136,15 @@ class RFModuleAsPTModule(torch.nn.Module):
     def _get_name(self):
         return self._rf_module.__class__.__name__ + "[RF→PT]"
 
+    def __repr__(self) -> str:
+        """
+        Return a custom repr for Sequential/ModuleList that compresses repeated module representations if possible,
+        otherwise fallback to default behavior.
+        """
+        if _can_use_compact_repr(self):
+            return _repr_compact(self)
+        return super().__repr__()
+
     @property
     def rf_module(self) -> rf.Module:
         """RF module"""
@@ -193,3 +202,55 @@ class RFModuleAsPTModule(torch.nn.Module):
             # See similar logic in torch.nn.Module._apply.
             pt_param = torch.nn.Parameter(tensor, tensor.requires_grad)
             rf_param.raw_tensor = pt_param
+
+
+def _can_use_compact_repr(self: RFModuleAsPTModule) -> bool:
+    return list(self._modules.keys()) == [str(i) for i in range(len(self._modules))]
+
+
+def _repr_compact(self: RFModuleAsPTModule) -> str:
+    """
+    Return a custom repr for Sequential/ModuleList that compresses repeated module representations.
+    Code copied and adapted from torch.nn.ModuleList.__repr__.
+    """
+    list_of_reprs = [repr(item) for item in self._modules.values()]
+    if len(list_of_reprs) == 0:
+        return self._get_name() + "()"
+
+    start_end_indices = [[0, 0]]
+    repeated_blocks = [list_of_reprs[0]]
+    for i, r in enumerate(list_of_reprs[1:], 1):
+        if r == repeated_blocks[-1]:
+            start_end_indices[-1][1] += 1
+            continue
+
+        start_end_indices.append([i, i])
+        repeated_blocks.append(r)
+
+    lines = []
+    main_str = self._get_name() + "("
+    for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
+        local_repr = f"({start_id}): {b}"  # default repr
+
+        if start_id != end_id:
+            n = end_id - start_id + 1
+            local_repr = f"({start_id}-{end_id}): {n} x {b}"
+
+        local_repr = _add_indent(local_repr, 2)
+        lines.append(local_repr)
+
+    main_str += "\n  " + "\n  ".join(lines) + "\n"
+    main_str += ")"
+    return main_str
+
+
+def _add_indent(s_: str, num_spaces: int) -> str:
+    s = s_.split("\n")
+    # don't do anything for single-line stuff
+    if len(s) == 1:
+        return s_
+    first = s.pop(0)
+    s = [(num_spaces * " ") + line for line in s]
+    s = "\n".join(s)
+    s = first + "\n" + s
+    return s
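
This mirrors the compact repr that torch.nn.ModuleList itself uses (which the docstring says the code was adapted from); for reference, the upstream behavior in recent PyTorch versions:

    import torch

    layers = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(12)])
    print(layers)
    # ModuleList(
    #   (0-11): 12 x Linear(in_features=4, out_features=4, bias=True)
    # )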
returnn/torch/util/exception_helper.py CHANGED
@@ -71,7 +71,13 @@ def help_on_torch_exception(
     if not count_frames:
         exc_ext.append("(No module call frames.)")
 
-    if len(exc.args) == 1 and isinstance(exc.args[0], str) and not always_direct_print:
+    if (
+        # KeyError formatting would be wrong, showing `KeyError: "enc_spatial_dim\n\nStep idx: 0\..."`
+        not isinstance(exc, KeyError)
+        and len(exc.args) == 1
+        and isinstance(exc.args[0], str)
+        and not always_direct_print
+    ):
         exc.args = ("\n".join([exc.args[0], ""] + exc_ext),)
     else:
         for msg in exc_ext:
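
The KeyError exclusion matters because str() of a KeyError prints the repr of its argument, so the multi-line help text appended to exc.args would render escaped on a single line:

    >>> str(KeyError("enc_spatial_dim\n\nStep idx: 0"))
    "'enc_spatial_dim\\n\\nStep idx: 0'"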
returnn/util/basic.py CHANGED
@@ -365,12 +365,9 @@ def get_checkpoint_filepattern(filepath):
     :return: CheckpointLoader compatible filepattern
     :rtype: str
     """
-    if filepath.endswith(".meta"):
-        return filepath[: -len(".meta")]
-    elif filepath.endswith(".index"):
-        return filepath[: -len(".index")]
-    elif filepath.endswith(".pt"):
-        return filepath[: -len(".pt")]
+    for ext in [".meta", ".index", ".pt"]:
+        if filepath.endswith(ext):
+            return filepath[: -len(ext)]
     return filepath
 
 
@@ -557,7 +554,9 @@ def get_tensorflow_version_tuple() -> Tuple[int, ...]:
     import tensorflow as tf  # noqa
     import re
 
-    return tuple([int(re.sub("(-rc[0-9]|-dev[0-9]*)", "", s)) for s in tf.__version__.split(".")])
+    # Remove unwanted suffixes from the TF version string (e.g. "2.20.0-dev0+selfbuilt")
+    filtered_version = [re.sub("(-rc[0-9]|-dev[0-9]*)(\\+selfbuilt)?", "", s) for s in tf.__version__.split(".")]
+    return tuple(int(v) for v in filtered_version)
 
 
 class ReportImportedDevModules:
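
Worked example of the extended regex on the self-built TF version string from the comment:

    import re

    version = "2.20.0-dev0+selfbuilt"
    parts = [re.sub("(-rc[0-9]|-dev[0-9]*)(\\+selfbuilt)?", "", s) for s in version.split(".")]
    print(tuple(int(v) for v in parts))  # (2, 20, 0)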
@@ -1093,6 +1093,7 @@ def format_tb(
     with_color=None,
     with_vars=None,
     clear_frames=True,
+    colorize=None,
 ):
     """
     Formats a traceback into a list of strings, each corresponding to one frame.
@@ -1110,11 +1111,14 @@ def format_tb(
         That will potentially fix some mem leaks regarding locals, so it can be important.
         Also see https://github.com/python/cpython/issues/113939.
         However, any further access to frame locals will not work (e.g., if you want to use a debugger afterward).
+    :param colorize: for compat with Python >=3.13, currently ignored
     :return: list of strings, each corresponding to one frame in the traceback.
         Each string contains the file name, line number, function name, source code line, maybe relevant variables,
         etc., and a final newline.
     :rtype: list[str]
     """
+    if colorize is not None and with_color is None:
+        with_color = colorize
     color = Color(enable=with_color)
     output = _OutputLinesCollector(color=color)
 
returnn/util/collect_outputs_dict.py ADDED
@@ -0,0 +1,79 @@
+"""
+Customized (derived) dict to pass as ``collected_outputs`` to some of the RF modules,
+or potential other use cases.
+
+You can predefine (by pattern) what kind of outputs you want to collect and store in this dict.
+"""
+
+from typing import Optional, Union, Sequence
+import fnmatch
+
+
+class CollectOutputsDict(dict):
+    """
+    Customized (derived) dict, where you can predefine (by key pattern)
+    what kind of keys you want to collect and store in this dict.
+    Other keys will be ignored.
+    """
+
+    def __init__(self, *args, allowed_key_patterns: Optional[Sequence[str]] = None, **kwargs):
+        """
+        Initialize the CollectOutputsDict.
+
+        :param allowed_key_patterns:
+            List of key patterns (with wildcards) that are allowed to be stored in the dict.
+            If None, all keys are allowed.
+        """
+        super().__init__(*args, **kwargs)
+        self.allowed_key_patterns = allowed_key_patterns
+
+    def __setitem__(self, key, value):
+        """
+        Set an item in the dict if the key matches allowed patterns.
+        """
+        if self.is_key_allowed(key):
+            super().__setitem__(key, value)
+
+    def setdefault(self, key, default=None):
+        """
+        Set default value for a key if it matches allowed patterns.
+        """
+        if self.is_key_allowed(key):
+            return super().setdefault(key, default)
+        return None
+
+    def update(self, mapping, **kwargs):
+        """
+        Update the dict with another mapping, only adding allowed keys.
+        """
+        assert not kwargs
+        for key, value in mapping.items():
+            if self.is_key_allowed(key):
+                super().__setitem__(key, value)
+
+    def is_key_allowed(self, key: str) -> bool:
+        """
+        Check if the key matches any of the allowed patterns.
+
+        :param key:
+        :return: True if the key is allowed, False otherwise.
+        """
+        if self.allowed_key_patterns is None:
+            return True  # If no patterns defined, allow all keys
+        for pattern in self.allowed_key_patterns:
+            if fnmatch.fnmatch(key, pattern):
+                return True
+        return False
+
+
+def is_key_allowed_in_collect_outputs_dict(collect_outputs: Union[CollectOutputsDict, dict], key: str) -> bool:
+    """
+    Check if a key is allowed in the given CollectOutputsDict.
+
+    :param collect_outputs:
+    :param key:
+    :return: True if the key is allowed, False otherwise.
+    """
+    if isinstance(collect_outputs, CollectOutputsDict):
+        return collect_outputs.is_key_allowed(key)
+    return True  # If it's a regular dict, all keys are allowed
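
Minimal usage sketch for the new helper (the key names and patterns here are made up; patterns use fnmatch-style wildcards per the docstring):

    from returnn.util.collect_outputs_dict import CollectOutputsDict

    outputs = CollectOutputsDict(allowed_key_patterns=["enc_*", "att_weights"])
    outputs["enc_layer_3"] = 1.0  # kept: matches "enc_*"
    outputs["dec_state"] = 2.0    # silently dropped: no pattern matches
    assert "enc_layer_3" in outputs and "dec_state" not in outputs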
returnn/util/debug.py CHANGED
@@ -704,7 +704,7 @@ def check_py_traces_rf_to_pt_equal(
     """
     import random
     import torch
-    from returnn.tensor import Tensor, Dim
+    from returnn.tensor import Dim
    import returnn.frontend as rf
 
     # noinspection PyProtectedMember
@@ -715,9 +715,18 @@ def check_py_traces_rf_to_pt_equal(
     def _get_entry(trace, func, i, name, j):
         return trace[func][i][name][j]
 
+    def _get_entry_attr(trace, func, i, name, j):
+        name, attr = name.split(".", 1)
+        obj = trace[func][i][name][j]
+        return eval(f"{name}.{attr}", {name: obj})
+
     def _resolve_dim(dim: Union[Dim, str]) -> Dim:
         if isinstance(dim, Dim):
             return dim
+        elif isinstance(dim, str) and "." in dim:
+            dim = _get_entry_attr(trace_rf, *check_rf[:2], dim, -1)
+            assert isinstance(dim, Dim)
+            return dim
         elif isinstance(dim, str):
             dim = _get_entry(trace_rf, *check_rf[:2], dim, -1)
             assert isinstance(dim, Dim)
@@ -763,7 +772,7 @@ def check_py_traces_rf_to_pt_equal(
                 if len(indices) > 5:
                     msgs.append(" non-matching ...")
                 non_matching.append("\n".join(msgs_prefix + msgs))
-                print(f" mismatch!")
+                print(" mismatch!")
                 for msg in msgs:
                     print(msg)
 
returnn/util/file_cache.py CHANGED
@@ -426,7 +426,21 @@ class FileCache:
         orig_mtime_ns = os.stat(src_filename).st_mtime_ns
         FileInfo(mtime_ns=orig_mtime_ns).save(info_file_name)
 
-        _copy_with_prealloc(src_filename, dst_tmp_filename)
+        try:
+            _copy_with_prealloc(src_filename, dst_tmp_filename)
+        except Exception:
+            # Cleanup if it was created already.
+            # That avoids some of the ambiguity of the existence of the .copy file.
+            # https://github.com/rwth-i6/returnn/issues/1785
+            try:
+                os.remove(dst_tmp_filename)
+            except FileNotFoundError:
+                pass
+            try:
+                os.remove(info_file_name)
+            except FileNotFoundError:  # not really expected here, but safe to ignore
+                pass
+            raise
         os.rename(dst_tmp_filename, dst_filename)
 
     @staticmethod
returnn/util/task_system.py CHANGED
@@ -671,7 +671,7 @@ class Pickler(_BasePickler):
             return
         # For some reason, Numpy fromstring/tostring is faster than Numpy loads/dumps.
         self.save(make_numpy_ndarray_fromstring)
-        self.save((obj.tostring(), str(obj.dtype), obj.shape))
+        self.save((obj.tobytes(), str(obj.dtype), obj.shape))
         self.write(pickle.REDUCE)
 
     dispatch[numpy.ndarray] = save_ndarray
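
ndarray.tostring() was a deprecated alias of tobytes() and was removed in NumPy 2.0, so this is a drop-in rename; the byte round-trip the unpickler relies on is unchanged:

    import numpy

    a = numpy.arange(4, dtype="int32")
    b = numpy.frombuffer(a.tobytes(), dtype=a.dtype).reshape(a.shape)
    assert (a == b).all()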
{returnn-1.20251013.113026.dist-info → returnn-1.20260109.93428.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20251013.113026
+Version: 1.20260109.93428
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
@@ -36,7 +36,7 @@ Welcome to RETURNN
 `RETURNN paper 2018 <https://arxiv.org/abs/1805.05225>`_.
 
 RETURNN - RWTH extensible training framework for universal recurrent neural networks,
-is a Theano/TensorFlow-based implementation of modern recurrent neural network architectures.
+is a PyTorch/TensorFlow-based implementation of modern recurrent neural network architectures.
 It is optimized for fast and reliable training of recurrent neural networks in a multi-GPU environment.
 
 The high-level features and goals of RETURNN are: