returnn 1.20251027.232712-py3-none-any.whl → 1.20260119.15400-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +2 -2
- returnn/__old_mod_loader__.py +26 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/lm.py +130 -42
- returnn/datasets/meta.py +93 -43
- returnn/datasets/postprocessing.py +597 -108
- returnn/datasets/util/vocabulary.py +90 -0
- returnn/frontend/__init__.py +1 -0
- returnn/frontend/_backend.py +41 -0
- returnn/frontend/_native/__init__.py +22 -0
- returnn/frontend/_numpy_backend.py +7 -0
- returnn/frontend/_utils.py +1 -1
- returnn/frontend/array_.py +48 -2
- returnn/frontend/assert_.py +35 -0
- returnn/frontend/attention.py +54 -20
- returnn/frontend/conv.py +273 -54
- returnn/frontend/device.py +14 -1
- returnn/frontend/encoder/conformer.py +20 -0
- returnn/frontend/encoder/transformer.py +2 -0
- returnn/frontend/loss.py +222 -3
- returnn/frontend/math_.py +54 -14
- returnn/native_op.cpp +182 -172
- returnn/native_op.py +36 -31
- returnn/sprint/cache.py +12 -13
- returnn/tensor/_dim_extra.py +7 -7
- returnn/tensor/_tensor_extra.py +10 -10
- returnn/tensor/utils.py +8 -5
- returnn/tf/frontend_layers/_backend.py +7 -3
- returnn/tf/layers/basic.py +27 -40
- returnn/tf/native_op.py +27 -63
- returnn/tf/network.py +1 -1
- returnn/tf/util/basic.py +22 -197
- returnn/torch/engine.py +157 -6
- returnn/torch/frontend/_backend.py +280 -29
- returnn/torch/frontend/bridge.py +61 -0
- returnn/torch/frontend/compile_helper.py +106 -0
- returnn/torch/util/array_.py +30 -0
- returnn/torch/util/assert_.py +122 -0
- returnn/torch/util/exception_helper.py +7 -1
- returnn/torch/util/native_op.py +885 -0
- returnn/torch/util/native_op_code_compiler.py +308 -0
- returnn/util/basic.py +6 -7
- returnn/util/better_exchook.py +4 -0
- returnn/util/cuda_env.py +332 -0
- returnn/util/debug.py +12 -2
- returnn/util/file_cache.py +15 -1
- returnn/util/fsa.py +17 -13
- returnn/util/native_code_compiler.py +104 -47
- returnn/util/task_system.py +1 -1
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +2 -2
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +54 -48
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
returnn/tf/util/basic.py
CHANGED
@@ -17,6 +17,7 @@ from tensorflow.python.client import device_lib
 from tensorflow.python.ops import init_ops
 from returnn.util import basic as util
 from returnn.util.basic import NotSpecified, NativeCodeCompiler
+from returnn.util.cuda_env import CudaEnv as _CudaEnvBase
 from returnn.tensor import Tensor
 import returnn.tf.compat as tf_compat

@@ -2768,206 +2769,15 @@ def get_tf_gpp_path():
     return _tf_gpp_path


-class CudaEnv:
+class CudaEnv(_CudaEnvBase):
     """
-
-    Also path to ``nvcc``, the CUDA compiler.
+    Helper class to get CUDA environment for TF.
     """

-
-
-
-
-        from returnn.util.basic import to_bool
-
-        if to_bool(os.environ.get("DISABLE_CUDA", "0")):
-            self.cuda_path = None
-            if self.verbose_find_cuda:
-                print("CUDA disabled via env DISABLE_CUDA.")
-        else:
-            self.cuda_path = self._find_cuda_path()
-            if self.verbose_find_cuda:
-                print("CUDA path:", self.cuda_path)
-        self._max_compute_capability = None
-
-    @classmethod
-    def _find_nvcc_in_path(cls):
-        """
-        :return: yields full path to nvcc
-        :rtype: list[str]
-        """
-        for p in os.environ["PATH"].split(":"):
-            pp = "%s/nvcc" % p
-            if os.path.exists(pp):
-                yield pp
-
-    @classmethod
-    def _find_lib_in_ld_path(cls):
-        """
-        :return: yields full path to libcudart.so
-        :rtype: list[str]
-        """
-        from returnn.util.basic import get_ld_paths
-
-        for p in get_ld_paths():
-            pp = "%s/libcudart.so" % p
-            if os.path.exists(pp):
-                yield pp
-
-    @classmethod
-    def _get_lib_dir_name(cls, base_path):
-        """
-        :return: dir name in base path
-        :rtype: str
-        """
-        from returnn.util.basic import is_64bit_platform, get_ld_paths
-
-        for ld_path in get_ld_paths():
-            # We also want to allow "lib/x86_64-linux-gnu" for "/usr".
-            # However, this logic should not be triggered for incorrect cases.
-            # E.g. base_path="/usr" would be the prefix for most LD paths.
-            if ld_path.startswith(base_path + "/lib") and os.path.exists("%s/libcudart.so" % ld_path):
-                return ld_path[len(base_path) + 1 :]
-        if is_64bit_platform():
-            return "lib64"
-        return "lib"
-
-    @classmethod
-    def _cuda_path_candidate_via_proc_map_libcudart(cls):
-        from returnn.util.basic import find_libcudart_from_runtime
-
-        fn = find_libcudart_from_runtime()
-        if cls.verbose_find_cuda:
-            print("libcudart.so found from /proc/maps:", fn)
-        if not fn:
-            return None
-        # fn is e.g. '/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudart.so.8.0.61',
-        # or maybe '/usr/local/cuda-8.0/lib64/libcudart.so'
-        p = os.path.dirname(os.path.dirname(fn))
-        while not cls._check_valid_cuda_path(p):
-            p = os.path.dirname(p)
-            if p in ["", "/"]:
-                if cls.verbose_find_cuda:
-                    print(f"Loaded lib {fn} does not seem to be in valid CUDA path.")
-                return None
-        assert cls._check_valid_cuda_path(p)
-        return p
-
-    @classmethod
-    def _cuda_path_candidates(cls):
-        p = cls._cuda_path_candidate_via_proc_map_libcudart()
-        if p:
-            yield p
-        for p in cls._find_nvcc_in_path():
-            # Expect p == "/usr/local/cuda-8.0/bin/nvcc" or so.
-            postfix = "/bin/nvcc"
-            if cls.verbose_find_cuda:
-                print("found cuda nvcc (wanted postfix: %r): %s" % (postfix, p))
-            if not p.endswith(postfix):
-                continue
-            yield p[: -len(postfix)] or "/"
-        for p in cls._find_lib_in_ld_path():
-            # Expect p == "/usr/local/cuda-8.0/lib64/libcudart.so" or so.
-            d = "/".join(p.split("/")[:-2]) or "/"  # Get "/usr/local/cuda-8.0".
-            if cls.verbose_find_cuda:
-                print("found cuda lib: %s (path %s)" % (p, d))
-            yield d
-
-    @classmethod
-    def _check_valid_cuda_path(cls, p):
-        """
-        :param str p: path to CUDA, e.g. "/usr/local/cuda-8.0"
-        :return: whether this is a valid CUDA path, i.e. we find all what we need
-        :rtype: bool
-        """
-        if cls.verbose_find_cuda:
-            print("check valid CUDA path: %s" % p)
-        if not os.path.exists("%s/bin/nvcc" % p):
-            return False
-        if not os.path.exists("%s/include/cuda.h" % p):
-            return False
-        if not os.path.exists("%s/%s/libcudart.so" % (p, cls._get_lib_dir_name(p))):
-            return False
-        return True
-
-    @classmethod
-    def _find_cuda_path(cls):
-        """
-        :return: base CUDA path if we find one, otherwise None
-        :rtype: str|None
-        """
-        for p in cls._cuda_path_candidates():
-            if cls._check_valid_cuda_path(p):
-                return p
-        return None
-
-    def is_available(self):
-        """
-        :rtype: bool
-        """
-        return bool(self.cuda_path)
-
-    def get_max_compute_capability(self):
-        """
-        :return: the highest compute capability supported by nvcc, or float("inf") if not known
-        :rtype: float
-        """
-        if self._max_compute_capability is None:
-            cuda_occupancy_path = "%s/include/cuda_occupancy.h" % self.cuda_path
-            if os.path.exists(cuda_occupancy_path):
-                import re
-
-                major, minor = None, 0
-                for line in open(cuda_occupancy_path).read().splitlines():
-                    m = re.match("^#define\\s+__CUDA_OCC_(MAJOR|MINOR)__\\s+([0-9]+)$", line)
-                    if m:
-                        s, v = m.groups()
-                        v = int(v)
-                        if s == "MAJOR":
-                            major = v
-                        else:
-                            minor = v
-                if major:
-                    self._max_compute_capability = float(major) + float(minor) * 0.1
-            if self._max_compute_capability is None:
-                self._max_compute_capability = float("inf")
-        return self._max_compute_capability
-
-    def get_compiler_opts(self):
-        """
-        :rtype: list[str]
-        """
-        return [
-            "-ccbin",
-            get_tf_gcc_path(),
-            "-I",
-            "%s/targets/x86_64-linux/include" % self.cuda_path,
-            "-I",
-            "%s/include" % self.cuda_path,
-            "-L",
-            "%s/%s" % (self.cuda_path, self._get_lib_dir_name(self.cuda_path)),
-            "-x",
-            "cu",
-            "-v",
-        ]
-
-    def get_compiler_bin(self):
-        """
-        :return: path
-        :rtype: str
-        """
-        assert self.cuda_path
-        return "%s/bin/nvcc" % self.cuda_path
-
-    @classmethod
-    def get_instance(cls):
-        """
-        :rtype: CudaEnv
-        """
-        if cls._instance is not None:
-            return cls._instance
-        cls._instance = cls()
-        return cls._instance
+    @staticmethod
+    def get_cc_bin():
+        """compiler"""
+        return get_tf_gcc_path()


 class OpCodeCompiler(NativeCodeCompiler):
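Note: the CUDA discovery logic removed here does not disappear. Per the file list above it moves into the new shared module returnn/util/cuda_env.py (+332 lines), which this file now imports as _CudaEnvBase (see the first hunk); the TF-specific subclass is presumably left with only the get_cc_bin() override so that nvcc keeps using the compiler returned by get_tf_gcc_path(), i.e. the one matching the installed TensorFlow build.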
@@ -3020,6 +2830,21 @@ class OpCodeCompiler(NativeCodeCompiler):
             ld_flags += tf.sysconfig.get_link_flags()
         elif have_min_tf_version((1, 4)):
             ld_flags += ["-L%s" % tf.sysconfig.get_lib(), "-ltensorflow_framework"]
+        if have_min_tf_version((2, 20)):
+            # TF 2.20 removed TF_MAJOR_VERSION and co from version.h,
+            # and one is supposed to define these macros externally.
+            # Also, release_version.h was added to define TF_VERSION_STRING based on this (if needed).
+            # https://github.com/tensorflow/tensorflow/commit/c8f0e0620e5678d0f165a07e64114024a966ab7f
+            major, minor, patch = tf.__version__.split(".", 2)
+            patch, suffix = patch.split("-", 1) if "-" in patch else (patch, "")
+            c_macro_defines.update(
+                {
+                    "TF_MAJOR_VERSION": major,
+                    "TF_MINOR_VERSION": minor,
+                    "TF_PATCH_VERSION": patch,
+                    "TF_VERSION_SUFFIX": suffix,
+                }
+            )
         use_cxx11_abi = getattr(getattr(tf, "sysconfig", tf), "CXX11_ABI_FLAG", getattr(tf, "CXX11_ABI_FLAG", False))
         super(OpCodeCompiler, self).__init__(
             include_paths=include_paths,
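As a quick illustration of the new version handling, a minimal sketch (not RETURNN code itself; it only assumes that NativeCodeCompiler turns c_macro_defines entries into -D<name>=<value> compiler flags):

    # Sketch: how tf.__version__ is split into the macros defined in the hunk above.
    def tf_version_macros(version: str) -> dict:
        major, minor, patch = version.split(".", 2)
        patch, suffix = patch.split("-", 1) if "-" in patch else (patch, "")
        return {
            "TF_MAJOR_VERSION": major,
            "TF_MINOR_VERSION": minor,
            "TF_PATCH_VERSION": patch,
            "TF_VERSION_SUFFIX": suffix,
        }

    assert tf_version_macros("2.20.0") == {
        "TF_MAJOR_VERSION": "2",
        "TF_MINOR_VERSION": "20",
        "TF_PATCH_VERSION": "0",
        "TF_VERSION_SUFFIX": "",
    }
    assert tf_version_macros("2.20.0-rc1")["TF_VERSION_SUFFIX"] == "rc1"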
returnn/torch/engine.py
CHANGED
@@ -3,9 +3,11 @@ Main engine for PyTorch
 """

 from __future__ import annotations
+
 from typing import Optional, Any, Union, Callable, Dict, Set
 from contextlib import nullcontext, ExitStack, contextmanager

+import sys
 import gc
 import os
 import time
@@ -20,6 +22,7 @@ from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader
 from torch import autocast
 from torch.cuda import amp
+from torch.profiler import record_function
 import numpy as np

 import returnn
@@ -404,10 +407,14 @@ class Engine(EngineBase):
         total_data_size_packed = NumbersDict()
         total_data_size_padded = NumbersDict()

+        prof = _opt_torch_profiler_from_opts(self.config.opt_typed_value("torch_profile"))
+        if prof:
+            prof.__enter__()
+
         report_prefix = f"ep {self.epoch} train"
         try:
             while True:
-                with torch.no_grad():
+                with torch.no_grad(), record_function("data_loading"):
                     extern_data_raw = next(data_iter, None)

                 step_begin_time = time.monotonic()
@@ -485,7 +492,8 @@ class Engine(EngineBase):
                 with (
                     self._ddp_pt_model.no_sync()
                     if (self._ddp_pt_model is not None and not perform_update_step)
-                    else nullcontext()
+                    else nullcontext(),
+                    record_function("backward"),
                 ):
                     if self._grad_scaler is not None:
                         self._grad_scaler.scale(total_loss.raw_tensor).backward()
@@ -500,7 +508,8 @@ class Engine(EngineBase):

                 # only update the weights when every gradient accumulation loop ends
                 if perform_update_step:
-                    self._updater.step(grad_scaler=self._grad_scaler)
+                    with record_function("optimizer_step"):
+                        self._updater.step(grad_scaler=self._grad_scaler)
                 zero_grad_next_step = perform_update_step

                 if self._torch_distributed_ctx:
@@ -532,7 +541,7 @@ class Engine(EngineBase):
                     for key, val in eval_info.items():
                         self._tensorboard_writer.add_scalar(f"train/{key}", val, global_step=self.global_train_step)
                     self._tensorboard_writer.add_scalar(
-
+                        "train/learning_rate",
                         self._updater.get_effective_learning_rate(),
                         global_step=self.global_train_step,
                     )
@@ -582,10 +591,19 @@ class Engine(EngineBase):
                 self._updater.set_current_train_step(
                     global_train_step=self.global_train_step, epoch=self.epoch, epoch_continuous=epoch_continuous
                 )
+
+                if prof:
+                    prof.step()
+
         except Exception as exc:
+            if prof:
+                prof.__exit__(type(exc), exc, exc.__traceback__)
             help_on_torch_exception(exc, step_idx=step_idx, model=self._orig_model, extern_data=extern_data)
             raise

+        if prof:
+            prof.__exit__(None, None, None)
+
         elapsed = time.monotonic() - epoch_start_time
         elapsed_computation_percentage = elapsed_computation_time / elapsed
         total_padding_ratio = NumbersDict.constant_like(1.0, total_data_size_packed) - (
@@ -885,6 +903,7 @@ class Engine(EngineBase):
             if self._default_float_dtype:
                 stack.enter_context(rf.set_default_float_dtype_ctx(str(self._default_float_dtype).split(".")[-1]))
                 stack.enter_context(_set_torch_default_dtype_ctx_mgr(self._default_float_dtype))
+            stack.enter_context(record_function("model_step"))
             yield

     def _run_step(
@@ -930,7 +949,7 @@ class Engine(EngineBase):
         if not os.path.exists(filename) and os.path.exists(model_epoch_filename):
             filename = model_epoch_filename
         print("Load model %s" % (filename,), file=log.v4)
-        checkpoint_state =
+        checkpoint_state = _torch_load(filename, device=self._device)
         if epoch is None:
             epoch = checkpoint_state.get("epoch", self._start_epoch or 1)
         step = checkpoint_state.get("step", 1)
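The truncated removed line above presumably called torch.load directly; its replacement routes through the new _torch_load helper (defined in the last hunk of this file, below), which sends *.safetensors files to safetensors.torch.load_file and everything else to torch.load(..., map_location=device).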
@@ -1030,7 +1049,7 @@ class Engine(EngineBase):
                 print("(No relevant parameters matching.)", file=log.v3)
                 continue
             print(f"Pre-load weights for key '{preload_key}' from {opts['filename']}", file=log.v3)
-            preload_model_state =
+            preload_model_state = _torch_load(opts["filename"], device=self._device)
             if opts.get("checkpoint_key", "model") is not None:
                 # This can be used if an external checkpoint saves a checkpoint a different structure that just the
                 # model state dict. E.g., if a checkpoint is created using
@@ -1063,6 +1082,28 @@ class Engine(EngineBase):
             preload_model_state_keys = set(preload_model_state.keys())
             loaded_state_keys.update(preload_model_state.keys())
             missing_keys.difference_update(preload_model_state.keys())
+
+            custom_missing_load_func = opts.get("custom_missing_load_func")
+            if custom_missing_load_func:
+                custom_missing_vars_map = {}
+                for var_name in missing_keys_preload:
+                    var_shape = self._pt_model.state_dict()[var_name].shape
+                    var_val = custom_missing_load_func(
+                        name=var_name,
+                        shape=var_shape,
+                        preload_model_state=preload_model_state,
+                        **util.get_fwd_compat_kwargs(),
+                    )
+                    if var_val is not None:
+                        assert var_val.shape == var_shape
+                        custom_missing_vars_map[var_name] = var_val
+                preload_model_state.update(custom_missing_vars_map)
+                missing_keys_preload, unexpected_keys_preload = self._pt_model.load_state_dict(
+                    preload_model_state, strict=False
+                )
+                loaded_state_keys.update(preload_model_state.keys())
+                missing_keys.difference_update(preload_model_state.keys())
+
             del preload_model_state
             gc.collect()

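For illustration, a hedged sketch of what such a custom_missing_load_func could look like. Only the call signature (keyword arguments name, shape, preload_model_state, plus forward-compat kwargs) and the None-means-skip / shape-must-match behavior are taken from the diff; the surrounding preload_from_files config structure is assumed from the keys (filename, checkpoint_key, custom_missing_load_func) visible above.

    import torch

    def zero_init_missing_biases(*, name, shape, preload_model_state, **_fwd_compat):
        """Fill missing bias variables with zeros; leave all other missing variables alone (return None)."""
        if name.endswith(".bias"):
            return torch.zeros(shape)
        return None

    # Assumed config structure; the keys "filename" and "custom_missing_load_func" appear in the diff.
    preload_from_files = {
        "base": {
            "filename": "/path/to/old_checkpoint.pt",
            "custom_missing_load_func": zero_init_missing_biases,
        }
    }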
@@ -1700,3 +1741,113 @@ def _get_total_grad_norm(model: torch.nn.Module, p: float) -> float:
             p=p,
         ).item()
     )
+
+
+def _torch_load(filename: Union[str, os.PathLike], *, device: str) -> Dict[str, Any]:
+    # Might resolve PtCheckpoint or Sisyphus Path objects or so.
+    filename = os.fspath(filename)
+
+    if filename.endswith(".safetensors"):
+        from safetensors.torch import load_file as safetensors_load
+
+        return safetensors_load(filename, device=device)
+
+    return torch.load(filename, map_location=device)
+
+
+class _TorchProfiler:
+    def __init__(self, profiler: torch.profiler.profile, max_step: Optional[int]):
+        self.profiler = profiler
+        self.max_step = max_step
+        self.entered = False
+
+    def __enter__(self):
+        self.profiler.__enter__()
+        self.entered = True
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not self.entered:
+            return
+        self.entered = False
+        self.profiler.__exit__(exc_type, exc_val, exc_tb)
+
+        if exc_type is None:
+            print(
+                "Torch profiling finished, exporting Chrome trace to torch_profile.json,"
+                " memory timeline to torch_memory_profile.html...",
+                file=log.v2,
+            )
+            self.profiler.export_chrome_trace("torch_profile.json")
+            self.profiler.export_memory_timeline("torch_memory_profile.html")
+
+            print("Exiting program after Torch profiling.", file=log.v2)
+            sys.exit(0)
+
+    def step(self):
+        """step"""
+        self.profiler.step()
+        if self.max_step is not None and self.profiler.step_num > self.max_step:
+            print(f"Reached max profiling step {self.max_step}, stopping Torch profiler.", file=log.v2)
+            self.profiler.stop()
+            self.__exit__(None, None, None)
+
+
+def _opt_torch_profiler_from_opts(
+    opts: Union[None, int, bool, str, Dict[str, Any]],
+) -> Optional[_TorchProfiler]:
+    if isinstance(opts, str):
+        from returnn.util.basic import to_bool
+
+        opts = to_bool(opts)
+
+    if opts is None:
+        return None
+    elif isinstance(opts, (bool, int)):
+        if not opts:
+            return None
+        opts = {}
+    elif isinstance(opts, dict):
+        opts = opts.copy()
+    else:
+        raise TypeError(f"Invalid type for torch_profile {opts!r}: {type(opts)}")
+
+    from torch.profiler import profile, ProfilerActivity, schedule
+
+    print("Using Torch profiler...", file=log.v2)
+
+    prof_max_step = None
+
+    if "activities" not in opts:
+        activities = [ProfilerActivity.CPU]
+        if torch.cuda.is_available():
+            activities += [ProfilerActivity.CUDA]
+        elif torch.xpu.is_available():
+            activities += [ProfilerActivity.XPU]
+        opts["activities"] = activities
+
+    opts.setdefault("profile_memory", True)
+    opts.setdefault("record_shapes", True)
+    opts.setdefault("with_stack", True)
+    opts.setdefault("with_flops", True)
+    # Note: active*repeat are the steps we actually profile.
+    opts.setdefault("schedule", dict(skip_first=10, wait=5, warmup=3, active=3, repeat=1))
+
+    if isinstance(opts["schedule"], dict):
+        schedule_opts: Dict[str, Any] = opts["schedule"]
+        schedule_opts = schedule_opts.copy()
+        schedule_opts.setdefault("repeat", 0)
+        schedule_opts.setdefault("skip_first", 0)
+        schedule_opts.setdefault("skip_first_wait", 0)
+        opts["schedule"] = schedule(**schedule_opts)
+
+        if schedule_opts["repeat"] > 0:
+            prof_max_step = (schedule_opts["wait"] + schedule_opts["warmup"] + schedule_opts["active"]) * schedule_opts[
+                "repeat"
+            ]
+            prof_max_step += schedule_opts["skip_first"]
+            if schedule_opts["skip_first_wait"] != 0:
+                prof_max_step -= schedule_opts["wait"]
+            print(f"Profiling will stop automatically after {prof_max_step} steps.", file=log.v3)
+
+    prof = profile(**opts)
+    return _TorchProfiler(prof, prof_max_step)