returnn 1.20250901.123052-py3-none-any.whl → 1.20260105.192646-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +2 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/config.py +1 -1
- returnn/datasets/basic.py +29 -13
- returnn/datasets/distrib_files.py +61 -3
- returnn/datasets/generating.py +12 -21
- returnn/datasets/huggingface.py +434 -0
- returnn/datasets/lm.py +20 -0
- returnn/datasets/meta.py +179 -60
- returnn/datasets/multi_proc.py +1 -1
- returnn/datasets/postprocessing.py +597 -108
- returnn/datasets/text_dict.py +1 -1
- returnn/datasets/util/vocabulary.py +90 -0
- returnn/frontend/_backend.py +7 -0
- returnn/frontend/array_.py +54 -1
- returnn/frontend/attention.py +54 -20
- returnn/frontend/conv.py +273 -54
- returnn/frontend/decoder/transformer.py +36 -17
- returnn/frontend/encoder/conformer.py +1 -0
- returnn/frontend/encoder/transformer.py +2 -0
- returnn/frontend/loss.py +40 -1
- returnn/frontend/module.py +8 -1
- returnn/frontend/nested.py +9 -0
- returnn/native_op.cpp +80 -0
- returnn/sprint/cache.py +12 -13
- returnn/tensor/_dim_extra.py +51 -29
- returnn/tensor/_tensor_extra.py +6 -1
- returnn/tensor/utils.py +7 -4
- returnn/tf/frontend_layers/_backend.py +11 -2
- returnn/tf/frontend_low_level/_backend.py +15 -0
- returnn/tf/layers/basic.py +16 -38
- returnn/tf/native_op.py +11 -58
- returnn/tf/network.py +1 -1
- returnn/tf/util/basic.py +19 -0
- returnn/torch/data/returnn_dataset_wrapper.py +9 -3
- returnn/torch/engine.py +67 -2
- returnn/torch/frontend/_backend.py +119 -7
- returnn/torch/util/diagnose_gpu.py +65 -31
- returnn/torch/util/exception_helper.py +7 -1
- returnn/util/basic.py +6 -7
- returnn/util/better_exchook.py +4 -0
- returnn/util/collect_outputs_dict.py +79 -0
- returnn/util/debug.py +11 -2
- returnn/util/file_cache.py +42 -4
- returnn/util/task_system.py +1 -1
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/METADATA +2 -2
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/RECORD +50 -48
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/LICENSE +0 -0
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/WHEEL +0 -0
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/top_level.txt +0 -0
returnn/tf/util/basic.py
CHANGED
@@ -2784,6 +2784,10 @@ class CudaEnv:
             self.cuda_path = None
             if self.verbose_find_cuda:
                 print("CUDA disabled via env DISABLE_CUDA.")
+        elif os.environ.get("CUDA_VISIBLE_DEVICES", None) in ["", "-1"]:
+            self.cuda_path = None
+            if self.verbose_find_cuda:
+                print(f"CUDA disabled via env CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']!r}.")
         else:
             self.cuda_path = self._find_cuda_path()
             if self.verbose_find_cuda:
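For context, setting CUDA_VISIBLE_DEVICES to an empty string or "-1" is the usual convention to hide all GPUs from CUDA. A minimal sketch of the check the new branch performs (helper name is ours, not from the package):

import os

def cuda_disabled_via_env() -> bool:
    # Mirrors the new branch above: an empty string or "-1" hides all GPUs,
    # so probing for a CUDA installation can be skipped entirely.
    return os.environ.get("CUDA_VISIBLE_DEVICES", None) in ["", "-1"]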
@@ -3020,6 +3024,21 @@ class OpCodeCompiler(NativeCodeCompiler):
            ld_flags += tf.sysconfig.get_link_flags()
        elif have_min_tf_version((1, 4)):
            ld_flags += ["-L%s" % tf.sysconfig.get_lib(), "-ltensorflow_framework"]
+        if have_min_tf_version((2, 20)):
+            # TF 2.20 removed TF_MAJOR_VERSION and co from version.h,
+            # and one is supposed to define these macros externally.
+            # Also, release_version.h was added to define TF_VERSION_STRING based on this (if needed).
+            # https://github.com/tensorflow/tensorflow/commit/c8f0e0620e5678d0f165a07e64114024a966ab7f
+            major, minor, patch = tf.__version__.split(".", 2)
+            patch, suffix = patch.split("-", 1) if "-" in patch else (patch, "")
+            c_macro_defines.update(
+                {
+                    "TF_MAJOR_VERSION": major,
+                    "TF_MINOR_VERSION": minor,
+                    "TF_PATCH_VERSION": patch,
+                    "TF_VERSION_SUFFIX": suffix,
+                }
+            )
        use_cxx11_abi = getattr(getattr(tf, "sysconfig", tf), "CXX11_ABI_FLAG", getattr(tf, "CXX11_ABI_FLAG", False))
        super(OpCodeCompiler, self).__init__(
            include_paths=include_paths,
returnn/torch/data/returnn_dataset_wrapper.py
CHANGED
@@ -20,12 +20,18 @@ ResetCallbackT = Callable[[], None]
 class ReturnnDatasetResetDefaultEpochCounterCallback:
     """
     Default for reset_callback.
-    Has an internal counter for the epoch, starting at epoch 1 (RETURNN convention).
+    Has an internal counter for the epoch, starting by default at epoch 1 (RETURNN convention).
     """

-    def __init__(self, dataset: ReturnnDataset):
+    def __init__(self, dataset: ReturnnDataset, *, epoch0: int = 0):
+        """
+        :param dataset: RETURNN dataset.
+        :param epoch0: Epoch from which the dataset sequence ordering should start.
+            It will actually be epoch0+1 for the first epoch, since :func:`__call__` will increment it.
+            By default 0, so the next :func:`__call__` increments it and we start at epoch 1.
+        """
         self.dataset = dataset
-        self.epoch = 0
+        self.epoch = epoch0

     def __call__(self):
         # dataset is likely a copy of the original dataset, either in the main process or in a worker process
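A hedged usage sketch of the new epoch0 parameter, assuming __call__ increments the counter and re-initializes the dataset sequence order as the docstring describes, with `dataset` being some RETURNN dataset instance:

# Start the sequence ordering at epoch 11 instead of 1, e.g. when resuming:
callback = ReturnnDatasetResetDefaultEpochCounterCallback(dataset, epoch0=10)
callback()  # increments the internal counter and re-inits the dataset
assert callback.epoch == 11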
returnn/torch/engine.py
CHANGED
@@ -134,6 +134,14 @@ class Engine(EngineBase):
         self._forward_auto_split_batch_on_oom = config.bool("forward_auto_split_batch_on_oom", False)
         self._stop_on_nonfinite_train_score = config.bool("stop_on_nonfinite_train_score", True)

+        if config.bool("use_tensorboard", False):
+            from torch.utils.tensorboard import SummaryWriter
+
+            self._tensorboard_writer = SummaryWriter()
+            self._tensorboard_opts = config.typed_value("tensorboard_opts", {})
+        else:
+            self._tensorboard_writer = None
+
         default_float_dtype = config.value("default_float_dtype", None)
         if default_float_dtype is not None:
             assert isinstance(default_float_dtype, str)
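A minimal RETURNN config sketch enabling this, using the option names read by the code above (values are illustrative; the log_every_n_train_steps key is consumed further below):

# In the RETURNN config (a Python file):
use_tensorboard = True
tensorboard_opts = {"log_every_n_train_steps": 100}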
@@ -257,6 +265,9 @@ class Engine(EngineBase):
             self.init_train_epoch()
             self.train_epoch()

+        if self._tensorboard_writer:
+            self._tensorboard_writer.close()
+
         print(f"Finished training at epoch {self.epoch}, global train step {self.global_train_step}", file=log.v3)

     def init_train_epoch(self):
@@ -513,6 +524,18 @@ class Engine(EngineBase):
                 batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
                 log_memory_usage_device=self._device if self._log_memory_usage else None,
             )
+            if (
+                self._tensorboard_writer
+                and self.global_train_step % self._tensorboard_opts.get("log_every_n_train_steps", 100) == 0
+            ):
+                # write losses/errors to tensorboard
+                for key, val in eval_info.items():
+                    self._tensorboard_writer.add_scalar(f"train/{key}", val, global_step=self.global_train_step)
+                self._tensorboard_writer.add_scalar(
+                    "train/learning_rate",
+                    self._updater.get_effective_learning_rate(),
+                    global_step=self.global_train_step,
+                )

         if self._stop_on_nonfinite_train_score:
             if any(np.isinf(v) or np.isnan(v) for v in accumulated_losses_dict.values()):
@@ -702,12 +725,20 @@ class Engine(EngineBase):
                     start_elapsed=step_end_time - eval_start_time,
                     log_memory_usage_device=self._device if self._log_memory_usage else None,
                 )
+
                 step_idx += 1

         assert step_idx > 0, f"No data in dataset {dataset_name!r}."
         accumulated_losses_dict = accumulated_losses_dict / accumulated_inv_norm_factors_dict
         accumulated_losses_dict = self._maybe_extend_losses_info(accumulated_losses_dict)

+        if self._tensorboard_writer:
+            # write losses/errors to tensorboard
+            for key, val in accumulated_losses_dict.items():
+                self._tensorboard_writer.add_scalar(
+                    f"{dataset_name}/{key}", val, global_step=self.global_train_step
+                )
+
         self.learning_rate_control.set_epoch_error(
             self.epoch, {f"{dataset_name}_loss_{k}": v for k, v in accumulated_losses_dict.items()}
         )
@@ -899,7 +930,7 @@ class Engine(EngineBase):
         if not os.path.exists(filename) and os.path.exists(model_epoch_filename):
             filename = model_epoch_filename
         print("Load model %s" % (filename,), file=log.v4)
-        checkpoint_state = …
+        checkpoint_state = _torch_load(filename, device=self._device)
         if epoch is None:
             epoch = checkpoint_state.get("epoch", self._start_epoch or 1)
         step = checkpoint_state.get("step", 1)
@@ -999,7 +1030,7 @@ class Engine(EngineBase):
                 print("(No relevant parameters matching.)", file=log.v3)
                 continue
             print(f"Pre-load weights for key '{preload_key}' from {opts['filename']}", file=log.v3)
-            preload_model_state = …
+            preload_model_state = _torch_load(opts["filename"], device=self._device)
             if opts.get("checkpoint_key", "model") is not None:
                 # This can be used if an external checkpoint saves a checkpoint with a different structure than just
                 # the model state dict. E.g., if a checkpoint is created using
@@ -1032,6 +1063,28 @@ class Engine(EngineBase):
            preload_model_state_keys = set(preload_model_state.keys())
            loaded_state_keys.update(preload_model_state.keys())
            missing_keys.difference_update(preload_model_state.keys())
+
+            custom_missing_load_func = opts.get("custom_missing_load_func")
+            if custom_missing_load_func:
+                custom_missing_vars_map = {}
+                for var_name in missing_keys_preload:
+                    var_shape = self._pt_model.state_dict()[var_name].shape
+                    var_val = custom_missing_load_func(
+                        name=var_name,
+                        shape=var_shape,
+                        preload_model_state=preload_model_state,
+                        **util.get_fwd_compat_kwargs(),
+                    )
+                    if var_val is not None:
+                        assert var_val.shape == var_shape
+                        custom_missing_vars_map[var_name] = var_val
+                preload_model_state.update(custom_missing_vars_map)
+                missing_keys_preload, unexpected_keys_preload = self._pt_model.load_state_dict(
+                    preload_model_state, strict=False
+                )
+                loaded_state_keys.update(preload_model_state.keys())
+                missing_keys.difference_update(preload_model_state.keys())
+
            del preload_model_state
            gc.collect()
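A hedged sketch of what such a custom_missing_load_func could look like, with the keyword signature inferred from the call site above (**kwargs absorbs the forward-compat arguments injected via util.get_fwd_compat_kwargs()); the zero-init policy here is purely illustrative:

import torch

def zero_init_missing(*, name: str, shape, preload_model_state: dict, **kwargs):
    # Synthesize missing bias params as zeros; leave everything else missing.
    if name.endswith(".bias"):
        return torch.zeros(shape)
    return None

Returning None leaves the parameter missing, matching the `if var_val is not None` check above; the function would be passed via the per-key preload opts dict as opts["custom_missing_load_func"].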
@@ -1669,3 +1722,15 @@ def _get_total_grad_norm(model: torch.nn.Module, p: float) -> float:
            p=p,
        ).item()
    )
+
+
+def _torch_load(filename: Union[str, os.PathLike], *, device: str) -> Dict[str, Any]:
+    # Might resolve PtCheckpoint or Sisyphus Path objects or so.
+    filename = os.fspath(filename)
+
+    if filename.endswith(".safetensors"):
+        from safetensors.torch import load_file as safetensors_load
+
+        return safetensors_load(filename, device=device)
+
+    return torch.load(filename, map_location=device)
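Checkpoint loading thus dispatches on the file extension; an illustrative call (paths hypothetical):

state = _torch_load("model.safetensors", device="cpu")  # via safetensors.torch.load_file
state = _torch_load("epoch.042.pt", device="cpu")       # via torch.load(..., map_location=...)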
returnn/torch/frontend/_backend.py
CHANGED
@@ -1166,20 +1166,29 @@ class TorchBackend(Backend[torch.Tensor]):
         if start is None:
             start = 0
         if isinstance(size, Dim):
+            assert end is None
             size = size.get_dim_value()
         elif isinstance(size, Tensor):
+            assert end is None
             assert size.dims == ()  # scalar
             size = size.raw_tensor
-            … (3 removed lines truncated in this diff view)
-        else:
+        elif isinstance(size, int):
+            pass
+        elif size is None:
             if isinstance(end, Tensor):
                 assert end.dims == ()
                 end = end.raw_tensor
-            …
+            elif isinstance(end, int):
+                if end < 0:
+                    end += axis.get_dim_value()
+            elif end is None:
                 end = axis.get_dim_value()
-            …
+            else:
+                raise TypeError(f"slice: unsupported type for end: {type(end)}")
+            size = end - start
+        else:
+            raise TypeError(f"slice: unsupported type for size: {type(size)}")
+        out.raw_tensor = torch.narrow(source.raw_tensor, dim=axis_int, start=start, length=size)
         return out

     @staticmethod
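For reference, torch.narrow (which the rewritten branches all funnel into) returns a view of `length` elements starting at `start` along dimension `dim`:

import torch

t = torch.arange(10)
assert torch.narrow(t, dim=0, start=2, length=5).tolist() == [2, 3, 4, 5, 6]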
@@ -1572,6 +1581,7 @@ class TorchBackend(Backend[torch.Tensor]):
             indices_out_raw = indices_raw % a.dimension
             indices_raw = indices_raw // a.dimension
             indices = values.copy_template(name=f"top_k_indices_{a.name or i}")
+            indices.feature_dim = None
             indices.dtype = TorchBackend.get_dtype_name_raw(indices_out_raw)
             indices.sparse_dim = a
             indices.raw_tensor = indices_out_raw
@@ -1588,6 +1598,7 @@ class TorchBackend(Backend[torch.Tensor]):
         values = source.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=k_dim, name="top_k_values")
         values.raw_tensor = values_raw
         indices = source.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=k_dim, name="top_k_indices")
+        indices.feature_dim = None
         indices.dtype = TorchBackend.get_dtype_name_raw(indices_raw)
         indices.sparse_dim = axis
         indices.raw_tensor = indices_raw
@@ -1639,6 +1650,8 @@ class TorchBackend(Backend[torch.Tensor]):
             name=f"random_{distribution}", dims=dims, dtype=dtype, sparse_dim=sparse_dim, feature_dim=feature_dim
         )
         out.raw_tensor = torch.empty(shape, dtype=dtype_, device=device or rf.get_default_device())
+        if out.raw_tensor.device.type == "meta":
+            return out  # nothing more to do
         assert explicit_state is None  # not implemented otherwise
         generator = None  # using the global default from PT
         assert isinstance(static, bool)
@@ -1787,6 +1800,7 @@ class TorchBackend(Backend[torch.Tensor]):
             dims=(out_dim,) + tuple(remaining_dims),
             dtype=tensor.dtype,
             sparse_dim=tensor.sparse_dim,
+            feature_dim=tensor.feature_dim,
             raw_tensor=out_raw,
         )
         return out, out_dim
@@ -1915,7 +1929,7 @@ class TorchBackend(Backend[torch.Tensor]):
         if not out_spatial_dims:
             out_spatial_dims = rf.make_conv_out_spatial_dims(
                 in_spatial_dims=in_spatial_dims,
-                filter_size=…
+                filter_size=filter_size,
                 strides=strides or 1,
                 dilation_rate=dilation_rate or 1,
                 padding=padding,
@@ -2028,6 +2042,104 @@ class TorchBackend(Backend[torch.Tensor]):
         out.feature_dim = out_dim
         return out, out_spatial_dims

+    # noinspection PyShadowingBuiltins
+    @staticmethod
+    def transposed_conv(
+        source: Tensor,
+        *,
+        in_dim: Dim,
+        out_dim: Dim,
+        in_spatial_dims: Sequence[Dim],
+        out_spatial_dims: Optional[Sequence[Dim]] = None,
+        filter: Tensor,
+        filter_size: Sequence[Dim],
+        padding: str,
+        remove_padding: Union[Sequence[int], int] = 0,
+        output_padding: Optional[Union[Sequence[Optional[int]], int]] = None,
+        strides: Optional[Sequence[int]] = None,
+        bias: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Sequence[Dim]]:
+        """transposed convolution"""
+        if not out_spatial_dims:
+            out_spatial_dims = rf.make_transposed_conv_out_spatial_dims(
+                in_spatial_dims=in_spatial_dims,
+                filter_size=filter_size,
+                strides=strides,
+                padding=padding,
+                output_padding=output_padding,
+            )
+        assert remove_padding == 0  # not implemented yet otherwise...
+        if strides is None:
+            strides = [fs.dimension for fs in filter_size]
+        filter_dims = (in_dim, out_dim) + tuple(filter_size)
+        filter = filter.copy_transpose(filter_dims)
+        batch_dims = [d for d in source.dims if d not in (in_dim,) + tuple(in_spatial_dims)]
+        # Torch conv expects (N,C,<spatial dims>) as shape.
+        source = source.copy_transpose(batch_dims + [in_dim] + list(in_spatial_dims))
+        if len(batch_dims) == 1:
+            src_raw = source.raw_tensor
+        else:
+            src_raw = torch.reshape(
+                source.raw_tensor,
+                # potentially merge batch dims all together
+                [-1, in_dim.get_dim_value()] + [d.get_dim_value() for d in in_spatial_dims],
+            )
+        if padding == "same":
+            raise NotImplementedError("transposed_conv with padding='same' not implemented")
+        if padding == "valid":
+            padding_val = 0
+        else:
+            raise ValueError(f"invalid padding {padding!r}, expected 'same' or 'valid'")
+        if len(filter_size) == 1:
+            out_raw = torch.nn.functional.conv_transpose1d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        elif len(filter_size) == 2:
+            out_raw = torch.nn.functional.conv_transpose2d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        elif len(filter_size) == 3:
+            out_raw = torch.nn.functional.conv_transpose3d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        else:
+            raise ValueError(f"invalid number of filter dims {filter_size}, expected 1, 2, or 3")
+        if remove_padding:
+            if isinstance(remove_padding, int):
+                remove_padding = [remove_padding] * len(out_spatial_dims)
+            assert len(remove_padding) == len(out_spatial_dims)
+            slices = [slice(None)] * out_raw.ndim
+            for i, pad in enumerate(remove_padding):
+                if pad > 0:
+                    slices[2 + i] = slice(0, -pad)
+            out_raw = out_raw[tuple(slices)]
+        out = Tensor(
+            "transposed_conv",
+            dims=batch_dims + [out_dim] + list(out_spatial_dims),
+            dtype=TorchBackend.get_dtype_name_raw(out_raw),
+        )
+        if len(batch_dims) == 1:
+            out.raw_tensor = out_raw
+        else:
+            out.raw_tensor = torch.reshape(out_raw, [d.get_dim_value() for d in out.dims])
+        out.feature_dim = out_dim
+        return out, out_spatial_dims
+
     @staticmethod
     def pool(
         source: Tensor,
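The new method maps onto torch.nn.functional.conv_transpose{1,2,3}d. A plain-PyTorch sketch of the 1D shape rule it relies on (valid padding, no output padding: L_out = (L_in - 1) * stride + kernel):

import torch

x = torch.randn(1, 4, 10)  # (batch, in_channels, length)
w = torch.randn(4, 8, 3)   # (in_channels, out_channels, kernel) for conv_transpose1d
y = torch.nn.functional.conv_transpose1d(x, w, stride=2)
assert y.shape == (1, 8, 21)  # (10 - 1) * 2 + 3 = 21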
returnn/torch/util/diagnose_gpu.py
CHANGED
@@ -8,6 +8,10 @@ import os
 import sys
 import gc
 import subprocess
+import signal
+import time
+import contextlib
+import multiprocessing
 import torch
 from returnn.util.better_exchook import better_exchook
 from returnn.util.basic import human_bytes_size
@@ -26,36 +30,39 @@ def print_available_devices(*, file: Optional[TextIO] = None):
         print("CUDA_VISIBLE_DEVICES is set to %r." % os.environ["CUDA_VISIBLE_DEVICES"], file=file)
         cuda_visible_devs = dict(enumerate([int(d) for d in os.environ["CUDA_VISIBLE_DEVICES"].split(",") if d]))
     else:
-        … (23 removed lines truncated in this diff view)
+        with timeout("torch.cuda.is_available()"):
+            if torch.cuda.is_available():
+                print("CUDA_VISIBLE_DEVICES is not set.", file=file)
+
+    with timeout("torch.cuda.is_available()"):
+        if not torch.cuda.is_available():
+            print("(CUDA not available)", file=file)
+            return
+
+    print("Available CUDA devices:", file=file)
+    count = torch.cuda.device_count()
+    if cuda_visible_devs is not None and len(cuda_visible_devs) != count:
+        print(
+            f"(Mismatch between CUDA device count {count}"
+            f" and CUDA_VISIBLE_DEVICES {cuda_visible_devs} count {len(cuda_visible_devs)}?)",
+            file=file,
+        )
+    for i in range(count):
+        print(f"  {i + 1}/{count}: cuda:{i}", file=file)
+        props = torch.cuda.get_device_properties(i)
+        print(f"    name: {props.name}", file=file)
+        print(f"    total_memory: {human_bytes_size(props.total_memory)}", file=file)
+        print(f"    capability: {props.major}.{props.minor}", file=file)
+        if cuda_visible_devs is not None:
+            if len(cuda_visible_devs) == count:
+                dev_idx_s = cuda_visible_devs[i]
             else:
-                dev_idx_s = …
-                … (5 more removed lines truncated in this diff view)
+                dev_idx_s = "?"
+        else:
+            dev_idx_s = i
+        print(f"    device_index: {dev_idx_s}", file=file)
+    if not count:
+        print("  (None)", file=file)


 def print_using_cuda_device_report(dev: Union[str, torch.device], *, file: Optional[TextIO] = None):
@@ -108,7 +115,7 @@ def diagnose_no_gpu() -> List[str]:
     except Exception as exc:
         print("nvidia-smi failed:", exc)
         better_exchook(*sys.exc_info(), debugshell=False)
-        res.append(…)
+        res.append("nvidia-smi failed")

     return res
@@ -152,4 +159,31 @@ def garbage_collect():
        f"alloc {human_bytes_size(torch.cuda.memory_allocated())}",
        f"reserved {human_bytes_size(torch.cuda.memory_reserved())}",
    ]
-    print(…)
+    print("CUDA memory usage after triggered GC:", " ".join(stats))
+
+
+@contextlib.contextmanager
+def timeout(info: str, *, seconds: int = 30):
+    """
+    Note: don't use signal handlers (e.g. signal.alarm) because unfortunately
+    potentially hanging funcs will block the main thread and thus block the signal handler from executing.
+    Thus, we use a subprocess.
+
+    :param seconds:
+    :param info:
+    """
+    proc = multiprocessing.Process(
+        target=_timeout_handler, kwargs={"seconds": seconds, "proc_id": os.getpid(), "info": info}
+    )
+    proc.start()
+    try:
+        yield
+    finally:
+        proc.terminate()
+        proc.join()
+
+
+def _timeout_handler(*, seconds: Union[float, int], proc_id: int, info: str):
+    time.sleep(seconds)
+    print(f"ERROR: {info}: Timeout handler after {seconds} seconds, killing proc {proc_id}.", file=sys.stderr)
+    os.kill(proc_id, signal.SIGABRT)
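Usage sketch of the watchdog: the helper process sleeps, and if the guarded call has not finished (and terminated the helper) within `seconds`, it SIGABRTs the calling process:

with timeout("torch.cuda.is_available()", seconds=30):
    torch.cuda.is_available()  # if this hangs, the whole process aborts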
returnn/torch/util/exception_helper.py
CHANGED
@@ -71,7 +71,13 @@ def help_on_torch_exception(
     if not count_frames:
         exc_ext.append("(No module call frames.)")

-    if …
+    if (
+        # KeyError formatting would be wrong, showing `KeyError: "enc_spatial_dim\n\nStep idx: 0\..."`
+        not isinstance(exc, KeyError)
+        and len(exc.args) == 1
+        and isinstance(exc.args[0], str)
+        and not always_direct_print
+    ):
         exc.args = ("\n".join([exc.args[0], ""] + exc_ext),)
     else:
         for msg in exc_ext:
returnn/util/basic.py
CHANGED
@@ -365,12 +365,9 @@ def get_checkpoint_filepattern(filepath):
     :return: CheckpointLoader compatible filepattern
     :rtype: str
     """
-    if filepath.endswith(".meta"):
-        return filepath[: -len(".meta")]
-    elif filepath.endswith(".index"):
-        return filepath[: -len(".index")]
-    elif filepath.endswith(".pt"):
-        return filepath[: -len(".pt")]
+    for ext in [".meta", ".index", ".pt"]:
+        if filepath.endswith(ext):
+            return filepath[: -len(ext)]
     return filepath
@@ -557,7 +554,9 @@ def get_tensorflow_version_tuple() -> Tuple[int, ...]:
     import tensorflow as tf  # noqa
     import re

-    … (1 removed line truncated in this diff view)
+    # Remove unwanted suffixes from the TF version string (e.g. "2.20.0-dev0+selfbuilt")
+    filtered_version = [re.sub("(-rc[0-9]|-dev[0-9]*)(\\+selfbuilt)?", "", s) for s in tf.__version__.split(".")]
+    return tuple(int(v) for v in filtered_version)


 class ReportImportedDevModules:
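Illustration of the suffix stripping on a hypothetical self-built version string:

import re

parts = "2.20.0-dev0+selfbuilt".split(".")
cleaned = [re.sub("(-rc[0-9]|-dev[0-9]*)(\\+selfbuilt)?", "", s) for s in parts]
assert tuple(int(v) for v in cleaned) == (2, 20, 0)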
returnn/util/better_exchook.py
CHANGED
@@ -1093,6 +1093,7 @@ def format_tb(
     with_color=None,
     with_vars=None,
     clear_frames=True,
+    colorize=None,
 ):
     """
     Formats a traceback into a list of strings, each corresponding to one frame.
@@ -1110,11 +1111,14 @@ def format_tb(
         That will potentially fix some mem leaks regarding locals, so it can be important.
         Also see https://github.com/python/cpython/issues/113939.
         However, any further access to frame locals will not work (e.g., if you want to use a debugger afterward).
+    :param colorize: for compat with Python >=3.13, currently ignored
     :return: list of strings, each corresponding to one frame in the traceback.
         Each string contains the file name, line number, function name, source code line, maybe relevant variables,
         etc., and a final newline.
     :rtype: list[str]
     """
+    if colorize is not None and with_color is None:
+        with_color = colorize
     color = Color(enable=with_color)
     output = _OutputLinesCollector(color=color)
returnn/util/collect_outputs_dict.py
ADDED
@@ -0,0 +1,79 @@
+"""
+Customized (derived) dict to pass as ``collected_outputs`` to some of the RF modules,
+or potential other use cases.
+
+You can predefine (by pattern) what kind of outputs you want to collect and store in this dict.
+"""
+
+from typing import Optional, Union, Sequence
+import fnmatch
+
+
+class CollectOutputsDict(dict):
+    """
+    Customized (derived) dict, where you can predefine (by key pattern)
+    what kind of keys you want to collect and store in this dict.
+    Other keys will be ignored.
+    """
+
+    def __init__(self, *args, allowed_key_patterns: Optional[Sequence[str]] = None, **kwargs):
+        """
+        Initialize the CollectOutputsDict.
+
+        :param allowed_key_patterns:
+            List of key patterns (with wildcards) that are allowed to be stored in the dict.
+            If None, all keys are allowed.
+        """
+        super().__init__(*args, **kwargs)
+        self.allowed_key_patterns = allowed_key_patterns
+
+    def __setitem__(self, key, value):
+        """
+        Set an item in the dict if the key matches allowed patterns.
+        """
+        if self.is_key_allowed(key):
+            super().__setitem__(key, value)
+
+    def setdefault(self, key, default=None):
+        """
+        Set default value for a key if it matches allowed patterns.
+        """
+        if self.is_key_allowed(key):
+            return super().setdefault(key, default)
+        return None
+
+    def update(self, mapping, **kwargs):
+        """
+        Update the dict with another mapping, only adding allowed keys.
+        """
+        assert not kwargs
+        for key, value in mapping.items():
+            if self.is_key_allowed(key):
+                super().__setitem__(key, value)
+
+    def is_key_allowed(self, key: str) -> bool:
+        """
+        Check if the key matches any of the allowed patterns.
+
+        :param key:
+        :return: True if the key is allowed, False otherwise.
+        """
+        if self.allowed_key_patterns is None:
+            return True  # If no patterns defined, allow all keys
+        for pattern in self.allowed_key_patterns:
+            if fnmatch.fnmatch(key, pattern):
+                return True
+        return False
+
+
+def is_key_allowed_in_collect_outputs_dict(collect_outputs: Union[CollectOutputsDict, dict], key: str) -> bool:
+    """
+    Check if a key is allowed in the given CollectOutputsDict.
+
+    :param collect_outputs:
+    :param key:
+    :return: True if the key is allowed, False otherwise.
+    """
+    if isinstance(collect_outputs, CollectOutputsDict):
+        return collect_outputs.is_key_allowed(key)
+    return True  # If it's a regular dict, all keys are allowed
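Usage sketch of the new class (patterns are fnmatch-style wildcards; key names here are illustrative):

from returnn.util.collect_outputs_dict import CollectOutputsDict

outs = CollectOutputsDict(allowed_key_patterns=["enc_*", "att_weights"])
outs["enc_layer_1"] = 1.0  # kept: matches "enc_*"
outs["dec_layer_1"] = 2.0  # silently dropped: matches no pattern
assert "enc_layer_1" in outs and "dec_layer_1" not in outs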
returnn/util/debug.py
CHANGED
@@ -704,7 +704,7 @@ def check_py_traces_rf_to_pt_equal(
     """
     import random
     import torch
-    from returnn.tensor import …
+    from returnn.tensor import Dim
     import returnn.frontend as rf

     # noinspection PyProtectedMember
@@ -715,9 +715,18 @@ def check_py_traces_rf_to_pt_equal(
     def _get_entry(trace, func, i, name, j):
         return trace[func][i][name][j]

+    def _get_entry_attr(trace, func, i, name, j):
+        name, attr = name.split(".", 1)
+        obj = trace[func][i][name][j]
+        return eval(f"{name}.{attr}", {name: obj})
+
     def _resolve_dim(dim: Union[Dim, str]) -> Dim:
         if isinstance(dim, Dim):
             return dim
+        elif isinstance(dim, str) and "." in dim:
+            dim = _get_entry_attr(trace_rf, *check_rf[:2], dim, -1)
+            assert isinstance(dim, Dim)
+            return dim
         elif isinstance(dim, str):
             dim = _get_entry(trace_rf, *check_rf[:2], dim, -1)
             assert isinstance(dim, Dim)
@@ -763,7 +772,7 @@ def check_py_traces_rf_to_pt_equal(
         if len(indices) > 5:
             msgs.append("  non-matching ...")
         non_matching.append("\n".join(msgs_prefix + msgs))
-        print(…)
+        print(" mismatch!")
         for msg in msgs:
             print(msg)
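A minimal sketch of the _get_entry_attr resolution added above: a dotted name like "x.dims[1]" is split once, the base object is looked up from the trace, and the remaining attribute path is evaluated against it (toy stand-ins here, not the real trace structure):

class _Traced:  # toy stand-in for a traced tensor
    dims = ("batch", "time")

name, attr = "x.dims[1]".split(".", 1)
obj = _Traced()
assert eval(f"{name}.{attr}", {name: obj}) == "time"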