returnn 1.20251027.224345-py3-none-any.whl → 1.20260113.134416-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +2 -2
- returnn/__old_mod_loader__.py +26 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/config.py +1 -1
- returnn/datasets/lm.py +130 -42
- returnn/datasets/meta.py +93 -43
- returnn/datasets/postprocessing.py +597 -108
- returnn/datasets/util/vocabulary.py +90 -0
- returnn/frontend/_native/__init__.py +22 -0
- returnn/frontend/_utils.py +1 -1
- returnn/frontend/array_.py +48 -2
- returnn/frontend/attention.py +54 -20
- returnn/frontend/conv.py +273 -54
- returnn/frontend/device.py +14 -1
- returnn/frontend/encoder/conformer.py +20 -0
- returnn/frontend/encoder/transformer.py +2 -0
- returnn/frontend/loss.py +40 -1
- returnn/frontend/math_.py +54 -14
- returnn/native_op.cpp +80 -0
- returnn/sprint/cache.py +12 -13
- returnn/tensor/_dim_extra.py +7 -7
- returnn/tensor/_tensor_extra.py +10 -10
- returnn/tensor/utils.py +7 -4
- returnn/tf/frontend_layers/_backend.py +4 -3
- returnn/tf/layers/basic.py +15 -39
- returnn/tf/native_op.py +11 -58
- returnn/tf/network.py +1 -1
- returnn/tf/util/basic.py +19 -0
- returnn/torch/engine.py +157 -6
- returnn/torch/frontend/_backend.py +137 -15
- returnn/torch/frontend/bridge.py +61 -0
- returnn/torch/frontend/compile_helper.py +106 -0
- returnn/torch/util/exception_helper.py +7 -1
- returnn/util/basic.py +5 -6
- returnn/util/better_exchook.py +4 -0
- returnn/util/debug.py +12 -2
- returnn/util/file_cache.py +15 -1
- returnn/util/task_system.py +1 -1
- {returnn-1.20251027.224345.dist-info → returnn-1.20260113.134416.dist-info}/METADATA +2 -2
- {returnn-1.20251027.224345.dist-info → returnn-1.20260113.134416.dist-info}/RECORD +43 -42
- {returnn-1.20251027.224345.dist-info → returnn-1.20260113.134416.dist-info}/LICENSE +0 -0
- {returnn-1.20251027.224345.dist-info → returnn-1.20260113.134416.dist-info}/WHEEL +0 -0
- {returnn-1.20251027.224345.dist-info → returnn-1.20260113.134416.dist-info}/top_level.txt +0 -0
returnn/torch/engine.py
CHANGED

@@ -3,9 +3,11 @@ Main engine for PyTorch
 """
 
 from __future__ import annotations
+
 from typing import Optional, Any, Union, Callable, Dict, Set
 from contextlib import nullcontext, ExitStack, contextmanager
 
+import sys
 import gc
 import os
 import time
@@ -20,6 +22,7 @@ from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader
 from torch import autocast
 from torch.cuda import amp
+from torch.profiler import record_function
 import numpy as np
 
 import returnn
@@ -404,10 +407,14 @@ class Engine(EngineBase):
         total_data_size_packed = NumbersDict()
         total_data_size_padded = NumbersDict()
 
+        prof = _opt_torch_profiler_from_opts(self.config.opt_typed_value("torch_profile"))
+        if prof:
+            prof.__enter__()
+
         report_prefix = f"ep {self.epoch} train"
         try:
             while True:
-                with torch.no_grad():
+                with torch.no_grad(), record_function("data_loading"):
                     extern_data_raw = next(data_iter, None)
 
                 step_begin_time = time.monotonic()
@@ -485,7 +492,8 @@
                 with (
                     self._ddp_pt_model.no_sync()
                     if (self._ddp_pt_model is not None and not perform_update_step)
-                    else nullcontext()
+                    else nullcontext(),
+                    record_function("backward"),
                 ):
                     if self._grad_scaler is not None:
                         self._grad_scaler.scale(total_loss.raw_tensor).backward()
@@ -500,7 +508,8 @@
 
                 # only update the weights when every gradient accumulation loop ends
                 if perform_update_step:
-                    self._updater.step(grad_scaler=self._grad_scaler)
+                    with record_function("optimizer_step"):
+                        self._updater.step(grad_scaler=self._grad_scaler)
                 zero_grad_next_step = perform_update_step
 
                 if self._torch_distributed_ctx:
@@ -532,7 +541,7 @@
                     for key, val in eval_info.items():
                         self._tensorboard_writer.add_scalar(f"train/{key}", val, global_step=self.global_train_step)
                     self._tensorboard_writer.add_scalar(
-
+                        "train/learning_rate",
                         self._updater.get_effective_learning_rate(),
                         global_step=self.global_train_step,
                     )
@@ -582,10 +591,19 @@
                 self._updater.set_current_train_step(
                     global_train_step=self.global_train_step, epoch=self.epoch, epoch_continuous=epoch_continuous
                 )
+
+                if prof:
+                    prof.step()
+
         except Exception as exc:
+            if prof:
+                prof.__exit__(type(exc), exc, exc.__traceback__)
             help_on_torch_exception(exc, step_idx=step_idx, model=self._orig_model, extern_data=extern_data)
             raise
 
+        if prof:
+            prof.__exit__(None, None, None)
+
         elapsed = time.monotonic() - epoch_start_time
         elapsed_computation_percentage = elapsed_computation_time / elapsed
         total_padding_ratio = NumbersDict.constant_like(1.0, total_data_size_packed) - (
@@ -885,6 +903,7 @@ class Engine(EngineBase):
             if self._default_float_dtype:
                 stack.enter_context(rf.set_default_float_dtype_ctx(str(self._default_float_dtype).split(".")[-1]))
                 stack.enter_context(_set_torch_default_dtype_ctx_mgr(self._default_float_dtype))
+            stack.enter_context(record_function("model_step"))
             yield
 
     def _run_step(
@@ -930,7 +949,7 @@
         if not os.path.exists(filename) and os.path.exists(model_epoch_filename):
             filename = model_epoch_filename
         print("Load model %s" % (filename,), file=log.v4)
-        checkpoint_state =
+        checkpoint_state = _torch_load(filename, device=self._device)
         if epoch is None:
             epoch = checkpoint_state.get("epoch", self._start_epoch or 1)
         step = checkpoint_state.get("step", 1)
@@ -1030,7 +1049,7 @@
                 print("(No relevant parameters matching.)", file=log.v3)
                 continue
             print(f"Pre-load weights for key '{preload_key}' from {opts['filename']}", file=log.v3)
-            preload_model_state =
+            preload_model_state = _torch_load(opts["filename"], device=self._device)
             if opts.get("checkpoint_key", "model") is not None:
                 # This can be used if an external checkpoint saves a checkpoint a different structure that just the
                 # model state dict. E.g., if a checkpoint is created using
@@ -1063,6 +1082,28 @@
             preload_model_state_keys = set(preload_model_state.keys())
             loaded_state_keys.update(preload_model_state.keys())
             missing_keys.difference_update(preload_model_state.keys())
+
+            custom_missing_load_func = opts.get("custom_missing_load_func")
+            if custom_missing_load_func:
+                custom_missing_vars_map = {}
+                for var_name in missing_keys_preload:
+                    var_shape = self._pt_model.state_dict()[var_name].shape
+                    var_val = custom_missing_load_func(
+                        name=var_name,
+                        shape=var_shape,
+                        preload_model_state=preload_model_state,
+                        **util.get_fwd_compat_kwargs(),
+                    )
+                    if var_val is not None:
+                        assert var_val.shape == var_shape
+                        custom_missing_vars_map[var_name] = var_val
+                preload_model_state.update(custom_missing_vars_map)
+                missing_keys_preload, unexpected_keys_preload = self._pt_model.load_state_dict(
+                    preload_model_state, strict=False
+                )
+                loaded_state_keys.update(preload_model_state.keys())
+                missing_keys.difference_update(preload_model_state.keys())
+
             del preload_model_state
             gc.collect()
 
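The `custom_missing_load_func` hook added above is read from the per-key preload opts (`preload_from_files` in the config) and is called once for every parameter that the preloaded checkpoint does not provide, with keyword args `name`, `shape`, `preload_model_state` plus forward-compat kwargs; returning None leaves the parameter untouched. A minimal sketch of how such a function could look in a user config (the zero-init strategy, the `**_other` catch-all, and the `"base"` key name are illustrative assumptions, not part of this diff):

import torch

def custom_missing_load_func(*, name, shape, preload_model_state, **_other):
    # Illustrative only: zero-initialize any parameter missing from the preloaded checkpoint.
    # Returning None instead would leave the parameter as initialized by the model.
    print(f"preload: filling missing parameter {name} with zeros, shape {tuple(shape)}")
    return torch.zeros(shape)

preload_from_files = {
    "base": {  # arbitrary key, as usual for preload_from_files
        "filename": "/path/to/external-checkpoint.pt",
        "custom_missing_load_func": custom_missing_load_func,
    },
}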
@@ -1700,3 +1741,113 @@ def _get_total_grad_norm(model: torch.nn.Module, p: float) -> float:
             p=p,
         ).item()
     )
+
+
+def _torch_load(filename: Union[str, os.PathLike], *, device: str) -> Dict[str, Any]:
+    # Might resolve PtCheckpoint or Sisyphus Path objects or so.
+    filename = os.fspath(filename)
+
+    if filename.endswith(".safetensors"):
+        from safetensors.torch import load_file as safetensors_load
+
+        return safetensors_load(filename, device=device)
+
+    return torch.load(filename, map_location=device)
+
+
+class _TorchProfiler:
+    def __init__(self, profiler: torch.profiler.profile, max_step: Optional[int]):
+        self.profiler = profiler
+        self.max_step = max_step
+        self.entered = False
+
+    def __enter__(self):
+        self.profiler.__enter__()
+        self.entered = True
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not self.entered:
+            return
+        self.entered = False
+        self.profiler.__exit__(exc_type, exc_val, exc_tb)
+
+        if exc_type is None:
+            print(
+                "Torch profiling finished, exporting Chrome trace to torch_profile.json,"
+                " memory timeline to torch_memory_profile.html...",
+                file=log.v2,
+            )
+            self.profiler.export_chrome_trace("torch_profile.json")
+            self.profiler.export_memory_timeline("torch_memory_profile.html")
+
+            print("Exiting program after Torch profiling.", file=log.v2)
+            sys.exit(0)
+
+    def step(self):
+        """step"""
+        self.profiler.step()
+        if self.max_step is not None and self.profiler.step_num > self.max_step:
+            print(f"Reached max profiling step {self.max_step}, stopping Torch profiler.", file=log.v2)
+            self.profiler.stop()
+            self.__exit__(None, None, None)
+
+
+def _opt_torch_profiler_from_opts(
+    opts: Union[None, int, bool, str, Dict[str, Any]],
+) -> Optional[_TorchProfiler]:
+    if isinstance(opts, str):
+        from returnn.util.basic import to_bool
+
+        opts = to_bool(opts)
+
+    if opts is None:
+        return None
+    elif isinstance(opts, (bool, int)):
+        if not opts:
+            return None
+        opts = {}
+    elif isinstance(opts, dict):
+        opts = opts.copy()
+    else:
+        raise TypeError(f"Invalid type for torch_profile {opts!r}: {type(opts)}")
+
+    from torch.profiler import profile, ProfilerActivity, schedule
+
+    print("Using Torch profiler...", file=log.v2)
+
+    prof_max_step = None
+
+    if "activities" not in opts:
+        activities = [ProfilerActivity.CPU]
+        if torch.cuda.is_available():
+            activities += [ProfilerActivity.CUDA]
+        elif torch.xpu.is_available():
+            activities += [ProfilerActivity.XPU]
+        opts["activities"] = activities
+
+    opts.setdefault("profile_memory", True)
+    opts.setdefault("record_shapes", True)
+    opts.setdefault("with_stack", True)
+    opts.setdefault("with_flops", True)
+    # Note: active*repeat are the steps we actually profile.
+    opts.setdefault("schedule", dict(skip_first=10, wait=5, warmup=3, active=3, repeat=1))
+
+    if isinstance(opts["schedule"], dict):
+        schedule_opts: Dict[str, Any] = opts["schedule"]
+        schedule_opts = schedule_opts.copy()
+        schedule_opts.setdefault("repeat", 0)
+        schedule_opts.setdefault("skip_first", 0)
+        schedule_opts.setdefault("skip_first_wait", 0)
+        opts["schedule"] = schedule(**schedule_opts)
+
+        if schedule_opts["repeat"] > 0:
+            prof_max_step = (schedule_opts["wait"] + schedule_opts["warmup"] + schedule_opts["active"]) * schedule_opts[
+                "repeat"
+            ]
+            prof_max_step += schedule_opts["skip_first"]
+            if schedule_opts["skip_first_wait"] != 0:
+                prof_max_step -= schedule_opts["wait"]
+            print(f"Profiling will stop automatically after {prof_max_step} steps.", file=log.v3)
+
+    prof = profile(**opts)
+    return _TorchProfiler(prof, prof_max_step)
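The profiling hook above is driven by a new `torch_profile` config value (read via `config.opt_typed_value("torch_profile")`). A dict is passed through to `torch.profiler.profile` after the defaults in `_opt_torch_profiler_from_opts` are applied; when profiling finishes, the run exports `torch_profile.json` and `torch_memory_profile.html` and exits. A rough sketch of how this could be set in a RETURNN config file (the dict keys mirror `torch.profiler.profile` arguments; the schedule values shown are illustrative, the diff's own default is `skip_first=10, wait=5, warmup=3, active=3, repeat=1`):

torch_profile = True  # enable with the defaults above (CPU + CUDA/XPU activities, memory, shapes, stack, FLOPs)

# or with explicit options, e.g. warm up briefly and profile one small window:
torch_profile = {
    "schedule": dict(skip_first=10, wait=0, warmup=3, active=5, repeat=1),
    "profile_memory": True,
    "with_stack": False,
}

Note that `should_write_to_disk` (see the `returnn/util/basic.py` hunk below) now returns False while `torch_profile` is enabled, so profiling runs do not write checkpoints. Checkpoint loading also gains `.safetensors` support via `_torch_load`, which dispatches on the file extension.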
returnn/torch/frontend/_backend.py
CHANGED

@@ -275,7 +275,7 @@ class TorchBackend(Backend[torch.Tensor]):
         :return: tensor
         """
         assert len(dims) >= 2
-        first_axis = min(source.dims.index(d) for d in dims)
+        first_axis = min([source.dims.index(d) for d in dims])
         pre_dims = source.dims[:first_axis]
         post_dims = [d for d in source.dims if d not in dims and d not in pre_dims]
         source = source.copy_transpose(tuple(pre_dims) + tuple(dims) + tuple(post_dims), allow_int=False)
@@ -884,7 +884,7 @@ class TorchBackend(Backend[torch.Tensor]):
         :param perm: e.g. [0, 2, 1]
         :return: permuted (transposed) raw tensor; wraps torch.permute
         """
-        if all(p == i for i, p in enumerate(perm)):
+        if all([p == i for i, p in enumerate(perm)]):
             return raw_tensor
         return torch.permute(raw_tensor, tuple(perm))
 
@@ -1166,20 +1166,29 @@ class TorchBackend(Backend[torch.Tensor]):
         if start is None:
             start = 0
         if isinstance(size, Dim):
+            assert end is None
             size = size.get_dim_value()
         elif isinstance(size, Tensor):
+            assert end is None
             assert size.dims == ()  # scalar
             size = size.raw_tensor
-
-
-
-        else:
+        elif isinstance(size, int):
+            pass
+        elif size is None:
             if isinstance(end, Tensor):
                 assert end.dims == ()
                 end = end.raw_tensor
-
+            elif isinstance(end, int):
+                if end < 0:
+                    end += axis.get_dim_value()
+            elif end is None:
                 end = axis.get_dim_value()
-
+            else:
+                raise TypeError(f"slice: unsupported type for end: {type(end)}")
+            size = end - start
+        else:
+            raise TypeError(f"slice: unsupported type for size: {type(size)}")
+        out.raw_tensor = torch.narrow(source.raw_tensor, dim=axis_int, start=start, length=size)
         return out
 
     @staticmethod
@@ -1352,12 +1361,24 @@ class TorchBackend(Backend[torch.Tensor]):
         a_dims = a.dims
         b_dims = b.dims
 
-
-
-
-
-
-
+        if not all(dim in a_dims for dim in reduce) or not all(dim in b_dims for dim in reduce):
+            # revert to the generic einsum implementation
+            assert all(dim in a_dims + b_dims for dim in reduce), "Some reduce Dims not in a or b."
+            result_dims = [dim for dim in a_dims if dim not in reduce] + [
+                dim for dim in b_dims if dim not in reduce and dim not in a_dims
+            ]
+            map_to_letter = {}
+            for dim in a_dims + b_dims:
+                if dim not in map_to_letter:
+                    map_to_letter[dim] = chr(97 + len(map_to_letter))  # 'a', 'b', 'c', ...
+            a_subscript = "".join(map_to_letter[dim] for dim in a_dims)
+            b_subscript = "".join(map_to_letter[dim] for dim in b_dims)
+            out_subscript = "".join(map_to_letter[dim] for dim in result_dims)
+            raw_result = torch.einsum(f"{a_subscript},{b_subscript}->{out_subscript}", a.raw_tensor, b.raw_tensor)
+            result_tensor = Tensor(
+                "einsum", dims=result_dims, raw_tensor=raw_result, dtype=TorchBackend.get_dtype_name_raw(raw_result)
+            )
+            return result_tensor
 
         if len(reduce) > 1:
             reduce = list(reduce)
@@ -1767,6 +1788,9 @@ class TorchBackend(Backend[torch.Tensor]):
         remaining_dims = [d for d in tensor.dims if d not in mask.dims]
         tensor_templ_dims = tuple(dims) + tuple(remaining_dims)
         in_raw = tensor.copy_compatible_to_dims_raw(tensor_templ_dims)
+        if any([in_raw.shape[i] == 1 < d.get_dim_value() for i, d in enumerate(dims)]):
+            # unbroadcast
+            in_raw = in_raw.expand([d.get_dim_value() for d in tensor_templ_dims])
         if mask.raw_tensor.device.type == "meta":
             # This is not supported, but also, we would anyway not know the out shape.
             # However, instead of erroring, just assume some dummy mask.
@@ -1920,7 +1944,7 @@ class TorchBackend(Backend[torch.Tensor]):
         if not out_spatial_dims:
             out_spatial_dims = rf.make_conv_out_spatial_dims(
                 in_spatial_dims=in_spatial_dims,
-                filter_size=
+                filter_size=filter_size,
                 strides=strides or 1,
                 dilation_rate=dilation_rate or 1,
                 padding=padding,
@@ -2033,6 +2057,104 @@ class TorchBackend(Backend[torch.Tensor]):
         out.feature_dim = out_dim
         return out, out_spatial_dims
 
+    # noinspection PyShadowingBuiltins
+    @staticmethod
+    def transposed_conv(
+        source: Tensor,
+        *,
+        in_dim: Dim,
+        out_dim: Dim,
+        in_spatial_dims: Sequence[Dim],
+        out_spatial_dims: Optional[Sequence[Dim]] = None,
+        filter: Tensor,
+        filter_size: Sequence[Dim],
+        padding: str,
+        remove_padding: Union[Sequence[int], int] = 0,
+        output_padding: Optional[Union[Sequence[Optional[int]], int]] = None,
+        strides: Optional[Sequence[int]] = None,
+        bias: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Sequence[Dim]]:
+        """transposed convolution"""
+        if not out_spatial_dims:
+            out_spatial_dims = rf.make_transposed_conv_out_spatial_dims(
+                in_spatial_dims=in_spatial_dims,
+                filter_size=filter_size,
+                strides=strides,
+                padding=padding,
+                output_padding=output_padding,
+            )
+        assert remove_padding == 0  # not implemented yet otherwise...
+        if strides is None:
+            strides = [fs.dimension for fs in filter_size]
+        filter_dims = (in_dim, out_dim) + tuple(filter_size)
+        filter = filter.copy_transpose(filter_dims)
+        batch_dims = [d for d in source.dims if d not in (in_dim,) + tuple(in_spatial_dims)]
+        # Torch conv expects (N,C,<spatial dims>) as shape.
+        source = source.copy_transpose(batch_dims + [in_dim] + list(in_spatial_dims))
+        if len(batch_dims) == 1:
+            src_raw = source.raw_tensor
+        else:
+            src_raw = torch.reshape(
+                source.raw_tensor,
+                # potentially merge batch dims all together
+                [-1, in_dim.get_dim_value()] + [d.get_dim_value() for d in in_spatial_dims],
+            )
+        if padding == "same":
+            raise NotImplementedError("transposed_conv with padding='same' not implemented")
+        if padding == "valid":
+            padding_val = 0
+        else:
+            raise ValueError(f"invalid padding {padding!r}, expected 'same' or 'valid'")
+        if len(filter_size) == 1:
+            out_raw = torch.nn.functional.conv_transpose1d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        elif len(filter_size) == 2:
+            out_raw = torch.nn.functional.conv_transpose2d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        elif len(filter_size) == 3:
+            out_raw = torch.nn.functional.conv_transpose3d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        else:
+            raise ValueError(f"invalid number of filter dims {filter_size}, expected 1, 2, or 3")
+        if remove_padding:
+            if isinstance(remove_padding, int):
+                remove_padding = [remove_padding] * len(out_spatial_dims)
+            assert len(remove_padding) == len(out_spatial_dims)
+            slices = [slice(None)] * out_raw.ndim
+            for i, pad in enumerate(remove_padding):
+                if pad > 0:
+                    slices[2 + i] = slice(0, -pad)
+            out_raw = out_raw[tuple(slices)]
+        out = Tensor(
+            "transposed_conv",
+            dims=batch_dims + [out_dim] + list(out_spatial_dims),
+            dtype=TorchBackend.get_dtype_name_raw(out_raw),
+        )
+        if len(batch_dims) == 1:
+            out.raw_tensor = out_raw
+        else:
+            out.raw_tensor = torch.reshape(out_raw, [d.get_dim_value() for d in out.dims])
+        out.feature_dim = out_dim
+        return out, out_spatial_dims
+
     @staticmethod
     def pool(
         source: Tensor,
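The matmul fallback in the `@@ -1352` hunk builds an einsum equation by assigning one subscript letter per `Dim` object and reducing over the requested dims. A self-contained sketch of the same subscript-building idea on plain torch tensors, with string dim names standing in for `Dim` objects (illustrative only, not the RETURNN code path):

import torch

def einsum_by_names(a, a_names, b, b_names, reduce_names):
    # Assign one letter per distinct dim name, in order of appearance,
    # mirroring the map_to_letter logic in the hunk above.
    letters = {}
    for n in list(a_names) + list(b_names):
        letters.setdefault(n, chr(97 + len(letters)))  # 'a', 'b', 'c', ...
    out_names = [n for n in a_names if n not in reduce_names] + [
        n for n in b_names if n not in reduce_names and n not in a_names
    ]
    eq = "{},{}->{}".format(
        "".join(letters[n] for n in a_names),
        "".join(letters[n] for n in b_names),
        "".join(letters[n] for n in out_names),
    )
    return torch.einsum(eq, a, b), out_names

# Example: reduce over the shared "in" dim.
a = torch.randn(2, 3, 4)  # (batch, time, in)
b = torch.randn(4, 5)     # (in, out)
out, names = einsum_by_names(a, ["batch", "time", "in"], b, ["in", "out"], {"in"})
print(out.shape, names)   # torch.Size([2, 3, 5]) ['batch', 'time', 'out']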
returnn/torch/frontend/bridge.py
CHANGED

@@ -136,6 +136,15 @@ class RFModuleAsPTModule(torch.nn.Module):
     def _get_name(self):
         return self._rf_module.__class__.__name__ + "[RF→PT]"
 
+    def __repr__(self) -> str:
+        """
+        Return a custom repr for Sequential/ModuleList that compresses repeated module representations if possible,
+        otherwise fallback to default behavior.
+        """
+        if _can_use_compact_repr(self):
+            return _repr_compact(self)
+        return super().__repr__()
+
     @property
     def rf_module(self) -> rf.Module:
         """RF module"""
@@ -193,3 +202,55 @@
         # See similar logic in torch.nn.Module._apply.
         pt_param = torch.nn.Parameter(tensor, tensor.requires_grad)
         rf_param.raw_tensor = pt_param
+
+
+def _can_use_compact_repr(self: RFModuleAsPTModule) -> bool:
+    return list(self._modules.keys()) == [str(i) for i in range(len(self._modules))]
+
+
+def _repr_compact(self: RFModuleAsPTModule) -> str:
+    """
+    Return a custom repr for Sequential/ModuleList that compresses repeated module representations.
+    Code copied and adapted from torch.nn.ModuleList.__repr__.
+    """
+    list_of_reprs = [repr(item) for item in self._modules.values()]
+    if len(list_of_reprs) == 0:
+        return self._get_name() + "()"
+
+    start_end_indices = [[0, 0]]
+    repeated_blocks = [list_of_reprs[0]]
+    for i, r in enumerate(list_of_reprs[1:], 1):
+        if r == repeated_blocks[-1]:
+            start_end_indices[-1][1] += 1
+            continue
+
+        start_end_indices.append([i, i])
+        repeated_blocks.append(r)
+
+    lines = []
+    main_str = self._get_name() + "("
+    for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
+        local_repr = f"({start_id}): {b}"  # default repr
+
+        if start_id != end_id:
+            n = end_id - start_id + 1
+            local_repr = f"({start_id}-{end_id}): {n} x {b}"
+
+        local_repr = _add_indent(local_repr, 2)
+        lines.append(local_repr)
+
+    main_str += "\n " + "\n ".join(lines) + "\n"
+    main_str += ")"
+    return main_str
+
+
+def _add_indent(s_: str, num_spaces: int) -> str:
+    s = s_.split("\n")
+    # don't do anything for single-line stuff
+    if len(s) == 1:
+        return s_
+    first = s.pop(0)
+    s = [(num_spaces * " ") + line for line in s]
+    s = "\n".join(s)
+    s = first + "\n" + s
+    return s
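Per the docstring above, the compact repr is adapted from what torch.nn.ModuleList already does for repeated identical children, so the target output shape can be previewed with plain PyTorch (this is stock torch behavior in recent versions, not the RF bridge itself):

import torch

mods = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(3)])
print(mods)
# Recent PyTorch versions print roughly:
# ModuleList(
#   (0-2): 3 x Linear(in_features=4, out_features=4, bias=True)
# )

With this change, an `RFModuleAsPTModule` whose children are an integer-indexed sequence of identical modules prints a similarly grouped "(0-2): 3 x ..." listing instead of one line per child.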
returnn/torch/frontend/compile_helper.py
ADDED

@@ -0,0 +1,106 @@
+"""
+Helpers to improve torch.compile on RF code.
+"""
+
+from __future__ import annotations
+from typing import Any, Iterable, List, Tuple
+
+import os
+from returnn.tensor import Tensor, Dim
+
+# noinspection PyProtectedMember
+from returnn.frontend import _native
+
+_is_set_up = False
+
+
+def setup():
+    """
+    Set up the torch.compile helpers for RF code, also including :class:`Tensor` and :class:`Dim`.
+    """
+
+    global _is_set_up
+    if _is_set_up:
+        return
+    _is_set_up = True  # only try once
+
+    assert not _native.is_set_up(), "Call this setup() as early as possible."
+    _native.set_enabled(False)
+
+    # We have lots of dynamic shapes.
+    os.environ["TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS"] = "1"
+
+    # noinspection PyProtectedMember
+    from torch.utils._pytree import register_pytree_node
+
+    register_pytree_node(Tensor, _tensor_flatten, _tensor_unflatten)
+    register_pytree_node(Dim, _dim_flatten, _dim_unflatten)
+
+    Dim.get_dim_value = _dim_get_dim_value
+
+
+def _tensor_flatten(t: Tensor) -> Tuple[List[Any], Any]:
+    """
+    Flatten the tensor for PyTree.
+    """
+    return [t.raw_tensor, t.dims, t.sparse_dim], [
+        t.name,
+        t.dtype,
+        t.version,
+        t.feature_dim_axis_or_unspecified,
+        t.time_dim_axis_or_unspecified,
+    ]
+
+
+def _tensor_unflatten(values: Iterable[Any], metadata: Any) -> Tensor:
+    """
+    Unflatten the tensor from PyTree.
+    """
+    raw_tensor, dims, sparse_dim = values
+    name, dtype, version, feature_dim_axis, time_dim_axis = metadata
+    return Tensor(
+        name=name,
+        dims=dims,
+        dtype=dtype,
+        sparse_dim=sparse_dim,
+        feature_dim_axis=feature_dim_axis,
+        time_dim_axis=time_dim_axis,
+        raw_tensor=raw_tensor,
+        version=version,
+    )
+
+
+def _dim_flatten(d: Dim) -> Tuple[List[Any], Any]:
+    """
+    Flatten the dim for PyTree.
+    """
+    return [d.dyn_size_ext], [d.name, d.dimension, d.size]
+
+
+def _dim_unflatten(values: Iterable[Any], metadata: Any) -> Dim:
+    """
+    Unflatten the dim from PyTree.
+    """
+    (dyn_size_ext,) = values
+    name, dimension, size = metadata
+    # TODO this creates a new instance... this is maybe wrong?
+    return Dim(name=name, dimension=dimension, size=size, dyn_size_ext=dyn_size_ext)
+
+
+def _dim_get_dim_value(self: Dim) -> int:
+    """
+    Infers the dim this axis should have if unbroadcasted.
+    If `self.src_data` has a placeholder, will use the shape from there.
+    Otherwise, uses `self.dimension` (if static) or `self.dyn_size` (if dynamic).
+
+    :return: max(size or dyn_size)
+    """
+    res = self.get_dim_value_tensor()
+    if isinstance(res, Tensor):
+        assert res.dims == ()
+        assert res.raw_tensor is not None
+        # Specifically PyTorch would then treat it as a SymInt in torch.compile,
+        # which is important to have for some torch functions (e.g. torch.tile and others).
+        return int(res.raw_tensor)
+    assert isinstance(res, int)
+    return res
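Since `setup()` asserts that it runs before `returnn.frontend._native` is initialized, a plausible way to wire up this new module is to call it at the very top of a config or training script; the surrounding `torch.compile` usage below is an assumption for illustration, not something this diff wires up itself:

from returnn.torch.frontend import compile_helper

# Registers Tensor/Dim as pytree nodes, disables the RF native extension,
# and sets TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 for dynamic shapes.
compile_helper.setup()

# After this, functions passing rf.Tensor / Dim across torch.compile boundaries
# can be traced, e.g. (hypothetical): compiled_train_step = torch.compile(train_step)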
returnn/torch/util/exception_helper.py
CHANGED

@@ -71,7 +71,13 @@ def help_on_torch_exception(
     if not count_frames:
         exc_ext.append("(No module call frames.)")
 
-    if
+    if (
+        # KeyError formatting would be wrong, showing `KeyError: "enc_spatial_dim\n\nStep idx: 0\..."`
+        not isinstance(exc, KeyError)
+        and len(exc.args) == 1
+        and isinstance(exc.args[0], str)
+        and not always_direct_print
+    ):
         exc.args = ("\n".join([exc.args[0], ""] + exc_ext),)
     else:
         for msg in exc_ext:
returnn/util/basic.py
CHANGED

@@ -365,12 +365,9 @@ def get_checkpoint_filepattern(filepath):
     :return: CheckpointLoader compatible filepattern
     :rtype: str
     """
-    if filepath.endswith(".meta"):
-        return filepath[: -len(".meta")]
-    elif filepath.endswith(".index"):
-        return filepath[: -len(".index")]
-    elif filepath.endswith(".pt"):
-        return filepath[: -len(".pt")]
+    for ext in [".meta", ".index", ".pt"]:
+        if filepath.endswith(ext):
+            return filepath[: -len(ext)]
     return filepath
 
 
@@ -3819,6 +3816,8 @@ def should_write_to_disk(config):
         return False
     if config.is_true("dry_run"):
        return False
+    if config.is_true("torch_profile"):
+        return False
     return True
 
 
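The consolidated extension handling in `get_checkpoint_filepattern` can be checked with a few calls; the file names below are made up for illustration:

from returnn.util.basic import get_checkpoint_filepattern

print(get_checkpoint_filepattern("net-model/network.042.index"))  # -> "net-model/network.042"
print(get_checkpoint_filepattern("net-model/network.042.meta"))   # -> "net-model/network.042"
print(get_checkpoint_filepattern("models/epoch.042.pt"))          # -> "models/epoch.042"
print(get_checkpoint_filepattern("models/epoch.042"))             # unchanged -> "models/epoch.042"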