returnn 1.20251027.232712-py3-none-any.whl → 1.20260119.15400-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. returnn/PKG-INFO +2 -2
  2. returnn/__old_mod_loader__.py +26 -2
  3. returnn/_setup_info_generated.py +2 -2
  4. returnn/datasets/lm.py +130 -42
  5. returnn/datasets/meta.py +93 -43
  6. returnn/datasets/postprocessing.py +597 -108
  7. returnn/datasets/util/vocabulary.py +90 -0
  8. returnn/frontend/__init__.py +1 -0
  9. returnn/frontend/_backend.py +41 -0
  10. returnn/frontend/_native/__init__.py +22 -0
  11. returnn/frontend/_numpy_backend.py +7 -0
  12. returnn/frontend/_utils.py +1 -1
  13. returnn/frontend/array_.py +48 -2
  14. returnn/frontend/assert_.py +35 -0
  15. returnn/frontend/attention.py +54 -20
  16. returnn/frontend/conv.py +273 -54
  17. returnn/frontend/device.py +14 -1
  18. returnn/frontend/encoder/conformer.py +20 -0
  19. returnn/frontend/encoder/transformer.py +2 -0
  20. returnn/frontend/loss.py +222 -3
  21. returnn/frontend/math_.py +54 -14
  22. returnn/native_op.cpp +182 -172
  23. returnn/native_op.py +36 -31
  24. returnn/sprint/cache.py +12 -13
  25. returnn/tensor/_dim_extra.py +7 -7
  26. returnn/tensor/_tensor_extra.py +10 -10
  27. returnn/tensor/utils.py +8 -5
  28. returnn/tf/frontend_layers/_backend.py +7 -3
  29. returnn/tf/layers/basic.py +27 -40
  30. returnn/tf/native_op.py +27 -63
  31. returnn/tf/network.py +1 -1
  32. returnn/tf/util/basic.py +22 -197
  33. returnn/torch/engine.py +157 -6
  34. returnn/torch/frontend/_backend.py +280 -29
  35. returnn/torch/frontend/bridge.py +61 -0
  36. returnn/torch/frontend/compile_helper.py +106 -0
  37. returnn/torch/util/array_.py +30 -0
  38. returnn/torch/util/assert_.py +122 -0
  39. returnn/torch/util/exception_helper.py +7 -1
  40. returnn/torch/util/native_op.py +885 -0
  41. returnn/torch/util/native_op_code_compiler.py +308 -0
  42. returnn/util/basic.py +6 -7
  43. returnn/util/better_exchook.py +4 -0
  44. returnn/util/cuda_env.py +332 -0
  45. returnn/util/debug.py +12 -2
  46. returnn/util/file_cache.py +15 -1
  47. returnn/util/fsa.py +17 -13
  48. returnn/util/native_code_compiler.py +104 -47
  49. returnn/util/task_system.py +1 -1
  50. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +2 -2
  51. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +54 -48
  52. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
  53. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
  54. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
returnn/tf/util/basic.py CHANGED
@@ -17,6 +17,7 @@ from tensorflow.python.client import device_lib
 from tensorflow.python.ops import init_ops
 from returnn.util import basic as util
 from returnn.util.basic import NotSpecified, NativeCodeCompiler
+from returnn.util.cuda_env import CudaEnv as _CudaEnvBase
 from returnn.tensor import Tensor
 import returnn.tf.compat as tf_compat
 
@@ -2768,206 +2769,15 @@ def get_tf_gpp_path():
     return _tf_gpp_path
 
 
-class CudaEnv:
+class CudaEnv(_CudaEnvBase):
     """
-    Information about the Nvidia CUDA environment, and library.
-    Also path to ``nvcc``, the CUDA compiler.
+    Helper class to get CUDA environment for TF.
     """
 
-    _instance = None
-    verbose_find_cuda = False
-
-    def __init__(self):
-        from returnn.util.basic import to_bool
-
-        if to_bool(os.environ.get("DISABLE_CUDA", "0")):
-            self.cuda_path = None
-            if self.verbose_find_cuda:
-                print("CUDA disabled via env DISABLE_CUDA.")
-        else:
-            self.cuda_path = self._find_cuda_path()
-            if self.verbose_find_cuda:
-                print("CUDA path:", self.cuda_path)
-        self._max_compute_capability = None
-
-    @classmethod
-    def _find_nvcc_in_path(cls):
-        """
-        :return: yields full path to nvcc
-        :rtype: list[str]
-        """
-        for p in os.environ["PATH"].split(":"):
-            pp = "%s/nvcc" % p
-            if os.path.exists(pp):
-                yield pp
-
-    @classmethod
-    def _find_lib_in_ld_path(cls):
-        """
-        :return: yields full path to libcudart.so
-        :rtype: list[str]
-        """
-        from returnn.util.basic import get_ld_paths
-
-        for p in get_ld_paths():
-            pp = "%s/libcudart.so" % p
-            if os.path.exists(pp):
-                yield pp
-
-    @classmethod
-    def _get_lib_dir_name(cls, base_path):
-        """
-        :return: dir name in base path
-        :rtype: str
-        """
-        from returnn.util.basic import is_64bit_platform, get_ld_paths
-
-        for ld_path in get_ld_paths():
-            # We also want to allow "lib/x86_64-linux-gnu" for "/usr".
-            # However, this logic should not be triggered for incorrect cases.
-            # E.g. base_path="/usr" would be the prefix for most LD paths.
-            if ld_path.startswith(base_path + "/lib") and os.path.exists("%s/libcudart.so" % ld_path):
-                return ld_path[len(base_path) + 1 :]
-        if is_64bit_platform():
-            return "lib64"
-        return "lib"
-
-    @classmethod
-    def _cuda_path_candidate_via_proc_map_libcudart(cls):
-        from returnn.util.basic import find_libcudart_from_runtime
-
-        fn = find_libcudart_from_runtime()
-        if cls.verbose_find_cuda:
-            print("libcudart.so found from /proc/maps:", fn)
-        if not fn:
-            return None
-        # fn is e.g. '/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudart.so.8.0.61',
-        # or maybe '/usr/local/cuda-8.0/lib64/libcudart.so'
-        p = os.path.dirname(os.path.dirname(fn))
-        while not cls._check_valid_cuda_path(p):
-            p = os.path.dirname(p)
-            if p in ["", "/"]:
-                if cls.verbose_find_cuda:
-                    print(f"Loaded lib {fn} does not seem to be in valid CUDA path.")
-                return None
-        assert cls._check_valid_cuda_path(p)
-        return p
-
-    @classmethod
-    def _cuda_path_candidates(cls):
-        p = cls._cuda_path_candidate_via_proc_map_libcudart()
-        if p:
-            yield p
-        for p in cls._find_nvcc_in_path():
-            # Expect p == "/usr/local/cuda-8.0/bin/nvcc" or so.
-            postfix = "/bin/nvcc"
-            if cls.verbose_find_cuda:
-                print("found cuda nvcc (wanted postfix: %r): %s" % (postfix, p))
-            if not p.endswith(postfix):
-                continue
-            yield p[: -len(postfix)] or "/"
-        for p in cls._find_lib_in_ld_path():
-            # Expect p == "/usr/local/cuda-8.0/lib64/libcudart.so" or so.
-            d = "/".join(p.split("/")[:-2]) or "/"  # Get "/usr/local/cuda-8.0".
-            if cls.verbose_find_cuda:
-                print("found cuda lib: %s (path %s)" % (p, d))
-            yield d
-
-    @classmethod
-    def _check_valid_cuda_path(cls, p):
-        """
-        :param str p: path to CUDA, e.g. "/usr/local/cuda-8.0"
-        :return: whether this is a valid CUDA path, i.e. we find all what we need
-        :rtype: bool
-        """
-        if cls.verbose_find_cuda:
-            print("check valid CUDA path: %s" % p)
-        if not os.path.exists("%s/bin/nvcc" % p):
-            return False
-        if not os.path.exists("%s/include/cuda.h" % p):
-            return False
-        if not os.path.exists("%s/%s/libcudart.so" % (p, cls._get_lib_dir_name(p))):
-            return False
-        return True
-
-    @classmethod
-    def _find_cuda_path(cls):
-        """
-        :return: base CUDA path if we find one, otherwise None
-        :rtype: str|None
-        """
-        for p in cls._cuda_path_candidates():
-            if cls._check_valid_cuda_path(p):
-                return p
-        return None
-
-    def is_available(self):
-        """
-        :rtype: bool
-        """
-        return bool(self.cuda_path)
-
-    def get_max_compute_capability(self):
-        """
-        :return: the highest compute capability supported by nvcc, or float("inf") if not known
-        :rtype: float
-        """
-        if self._max_compute_capability is None:
-            cuda_occupancy_path = "%s/include/cuda_occupancy.h" % self.cuda_path
-            if os.path.exists(cuda_occupancy_path):
-                import re
-
-                major, minor = None, 0
-                for line in open(cuda_occupancy_path).read().splitlines():
-                    m = re.match("^#define\\s+__CUDA_OCC_(MAJOR|MINOR)__\\s+([0-9]+)$", line)
-                    if m:
-                        s, v = m.groups()
-                        v = int(v)
-                        if s == "MAJOR":
-                            major = v
-                        else:
-                            minor = v
-                if major:
-                    self._max_compute_capability = float(major) + float(minor) * 0.1
-            if self._max_compute_capability is None:
-                self._max_compute_capability = float("inf")
-        return self._max_compute_capability
-
-    def get_compiler_opts(self):
-        """
-        :rtype: list[str]
-        """
-        return [
-            "-ccbin",
-            get_tf_gcc_path(),
-            "-I",
-            "%s/targets/x86_64-linux/include" % self.cuda_path,
-            "-I",
-            "%s/include" % self.cuda_path,
-            "-L",
-            "%s/%s" % (self.cuda_path, self._get_lib_dir_name(self.cuda_path)),
-            "-x",
-            "cu",
-            "-v",
-        ]
-
-    def get_compiler_bin(self):
-        """
-        :return: path
-        :rtype: str
-        """
-        assert self.cuda_path
-        return "%s/bin/nvcc" % self.cuda_path
-
-    @classmethod
-    def get_instance(cls):
-        """
-        :rtype: CudaEnv
-        """
-        if cls._instance is not None:
-            return cls._instance
-        cls._instance = cls()
-        return cls._instance
+    @staticmethod
+    def get_cc_bin():
+        """compiler"""
+        return get_tf_gcc_path()
 
 
 class OpCodeCompiler(NativeCodeCompiler):
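Note: the CUDA discovery logic removed above is not dropped; per the file list, it moved into the new backend-independent module returnn/util/cuda_env.py (+332 lines), and the TF-specific class now only supplies the host compiler via get_cc_bin(). A minimal usage sketch, assuming the shared base class keeps the previous public interface (get_instance, is_available, get_compiler_bin):

    from returnn.tf.util.basic import CudaEnv

    env = CudaEnv.get_instance()  # singleton, as before
    if env.is_available():  # i.e. a valid CUDA path was found
        print("nvcc:", env.get_compiler_bin())  # e.g. /usr/local/cuda-12.4/bin/nvcc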
@@ -3020,6 +2830,21 @@ class OpCodeCompiler(NativeCodeCompiler):
             ld_flags += tf.sysconfig.get_link_flags()
         elif have_min_tf_version((1, 4)):
             ld_flags += ["-L%s" % tf.sysconfig.get_lib(), "-ltensorflow_framework"]
+        if have_min_tf_version((2, 20)):
+            # TF 2.20 removed TF_MAJOR_VERSION and co from version.h,
+            # and one is supposed to define these macros externally.
+            # Also, release_version.h was added to define TF_VERSION_STRING based on this (if needed).
+            # https://github.com/tensorflow/tensorflow/commit/c8f0e0620e5678d0f165a07e64114024a966ab7f
+            major, minor, patch = tf.__version__.split(".", 2)
+            patch, suffix = patch.split("-", 1) if "-" in patch else (patch, "")
+            c_macro_defines.update(
+                {
+                    "TF_MAJOR_VERSION": major,
+                    "TF_MINOR_VERSION": minor,
+                    "TF_PATCH_VERSION": patch,
+                    "TF_VERSION_SUFFIX": suffix,
+                }
+            )
         use_cxx11_abi = getattr(getattr(tf, "sysconfig", tf), "CXX11_ABI_FLAG", getattr(tf, "CXX11_ABI_FLAG", False))
         super(OpCodeCompiler, self).__init__(
             include_paths=include_paths,
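For illustration, the version split added above, applied to a hypothetical pre-release version string (not taken from the diff):

    major, minor, patch = "2.20.0-rc1".split(".", 2)  # -> "2", "20", "0-rc1"
    patch, suffix = patch.split("-", 1) if "-" in patch else (patch, "")  # -> "0", "rc1"
    # Resulting defines: TF_MAJOR_VERSION=2, TF_MINOR_VERSION=20,
    # TF_PATCH_VERSION=0, TF_VERSION_SUFFIX=rc1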
returnn/torch/engine.py CHANGED
@@ -3,9 +3,11 @@ Main engine for PyTorch
 """
 
 from __future__ import annotations
+
 from typing import Optional, Any, Union, Callable, Dict, Set
 from contextlib import nullcontext, ExitStack, contextmanager
 
+import sys
 import gc
 import os
 import time
@@ -20,6 +22,7 @@ from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader
 from torch import autocast
 from torch.cuda import amp
+from torch.profiler import record_function
 import numpy as np
 
 import returnn
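torch.profiler.record_function, newly imported here, labels a code region so that it shows up as a named block in profiler traces; the hunks below wrap data loading, backward, the optimizer step, and the model step in it. A self-contained sketch of the mechanism (standard PyTorch API, independent of RETURNN):

    import torch
    from torch.profiler import profile, record_function

    with profile() as prof:
        with record_function("my_region"):  # appears as "my_region" in the trace
            torch.randn(8, 8) @ torch.randn(8, 8)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))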
@@ -404,10 +407,14 @@ class Engine(EngineBase):
         total_data_size_packed = NumbersDict()
         total_data_size_padded = NumbersDict()
 
+        prof = _opt_torch_profiler_from_opts(self.config.opt_typed_value("torch_profile"))
+        if prof:
+            prof.__enter__()
+
         report_prefix = f"ep {self.epoch} train"
         try:
             while True:
-                with torch.no_grad():
+                with torch.no_grad(), record_function("data_loading"):
                     extern_data_raw = next(data_iter, None)
 
                 step_begin_time = time.monotonic()
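This reads the new torch_profile config option. A hypothetical RETURNN config snippet (the dict form is passed through to torch.profiler.profile, see _opt_torch_profiler_from_opts at the end of this file):

    torch_profile = True  # enable, using the defaults set in this diff
    # or, with explicit options:
    torch_profile = {
        "schedule": dict(skip_first=10, wait=5, warmup=3, active=3, repeat=1),
        "profile_memory": True,
    }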
@@ -485,7 +492,8 @@
                     with (
                         self._ddp_pt_model.no_sync()
                         if (self._ddp_pt_model is not None and not perform_update_step)
-                        else nullcontext()
+                        else nullcontext(),
+                        record_function("backward"),
                     ):
                         if self._grad_scaler is not None:
                             self._grad_scaler.scale(total_loss.raw_tensor).backward()
@@ -500,7 +508,8 @@
 
                 # only update the weights when every gradient accumulation loop ends
                 if perform_update_step:
-                    self._updater.step(grad_scaler=self._grad_scaler)
+                    with record_function("optimizer_step"):
+                        self._updater.step(grad_scaler=self._grad_scaler)
                 zero_grad_next_step = perform_update_step
 
                 if self._torch_distributed_ctx:
@@ -532,7 +541,7 @@
                     for key, val in eval_info.items():
                         self._tensorboard_writer.add_scalar(f"train/{key}", val, global_step=self.global_train_step)
                     self._tensorboard_writer.add_scalar(
-                        f"train/learning_rate",
+                        "train/learning_rate",
                         self._updater.get_effective_learning_rate(),
                         global_step=self.global_train_step,
                     )
@@ -582,10 +591,19 @@
                 self._updater.set_current_train_step(
                     global_train_step=self.global_train_step, epoch=self.epoch, epoch_continuous=epoch_continuous
                 )
+
+                if prof:
+                    prof.step()
+
         except Exception as exc:
+            if prof:
+                prof.__exit__(type(exc), exc, exc.__traceback__)
             help_on_torch_exception(exc, step_idx=step_idx, model=self._orig_model, extern_data=extern_data)
             raise
 
+        if prof:
+            prof.__exit__(None, None, None)
+
         elapsed = time.monotonic() - epoch_start_time
         elapsed_computation_percentage = elapsed_computation_time / elapsed
         total_padding_ratio = NumbersDict.constant_like(1.0, total_data_size_packed) - (
@@ -885,6 +903,7 @@ class Engine(EngineBase):
             if self._default_float_dtype:
                 stack.enter_context(rf.set_default_float_dtype_ctx(str(self._default_float_dtype).split(".")[-1]))
                 stack.enter_context(_set_torch_default_dtype_ctx_mgr(self._default_float_dtype))
+            stack.enter_context(record_function("model_step"))
             yield
 
     def _run_step(
@@ -930,7 +949,7 @@
         if not os.path.exists(filename) and os.path.exists(model_epoch_filename):
             filename = model_epoch_filename
         print("Load model %s" % (filename,), file=log.v4)
-        checkpoint_state = torch.load(filename, map_location=self._device)
+        checkpoint_state = _torch_load(filename, device=self._device)
         if epoch is None:
             epoch = checkpoint_state.get("epoch", self._start_epoch or 1)
         step = checkpoint_state.get("step", 1)
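Checkpoint loading now goes through the new _torch_load helper (defined at the end of this file), which dispatches on the file extension: *.safetensors files are loaded via safetensors.torch.load_file, everything else via torch.load as before. Assuming the safetensors package is installed, a config can thus point directly at a safetensors checkpoint, e.g. (hypothetical value):

    preload_from_files = {
        "base": {"filename": "/path/to/model.safetensors", "prefix": ""},
    }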
@@ -1030,7 +1049,7 @@
                     print("(No relevant parameters matching.)", file=log.v3)
                     continue
                 print(f"Pre-load weights for key '{preload_key}' from {opts['filename']}", file=log.v3)
-                preload_model_state = torch.load(opts["filename"], map_location=self._device)
+                preload_model_state = _torch_load(opts["filename"], device=self._device)
                 if opts.get("checkpoint_key", "model") is not None:
                     # This can be used if an external checkpoint saves a checkpoint a different structure that just the
                     # model state dict. E.g., if a checkpoint is created using
@@ -1063,6 +1082,28 @@
                 preload_model_state_keys = set(preload_model_state.keys())
                 loaded_state_keys.update(preload_model_state.keys())
                 missing_keys.difference_update(preload_model_state.keys())
+
+                custom_missing_load_func = opts.get("custom_missing_load_func")
+                if custom_missing_load_func:
+                    custom_missing_vars_map = {}
+                    for var_name in missing_keys_preload:
+                        var_shape = self._pt_model.state_dict()[var_name].shape
+                        var_val = custom_missing_load_func(
+                            name=var_name,
+                            shape=var_shape,
+                            preload_model_state=preload_model_state,
+                            **util.get_fwd_compat_kwargs(),
+                        )
+                        if var_val is not None:
+                            assert var_val.shape == var_shape
+                            custom_missing_vars_map[var_name] = var_val
+                    preload_model_state.update(custom_missing_vars_map)
+                    missing_keys_preload, unexpected_keys_preload = self._pt_model.load_state_dict(
+                        preload_model_state, strict=False
+                    )
+                    loaded_state_keys.update(preload_model_state.keys())
+                    missing_keys.difference_update(preload_model_state.keys())
+
                 del preload_model_state
                 gc.collect()
 
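The new custom_missing_load_func entry in the preload_from_files opts lets a config synthesize values for parameters that the preload checkpoint does not contain. Matching the call site above, a sketch of such a function (the zero-init policy is only an example):

    import torch

    def my_missing_load_func(*, name, shape, preload_model_state, **_fwd_compat):
        # Called per still-missing parameter; return a tensor, or None to leave it missing.
        if name.endswith(".bias"):
            return torch.zeros(shape)
        return None

    preload_from_files = {
        "base": {"filename": "...", "custom_missing_load_func": my_missing_load_func},
    }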
@@ -1700,3 +1741,113 @@ def _get_total_grad_norm(model: torch.nn.Module, p: float) -> float:
             p=p,
         ).item()
     )
+
+
+def _torch_load(filename: Union[str, os.PathLike], *, device: str) -> Dict[str, Any]:
+    # Might resolve PtCheckpoint or Sisyphus Path objects or so.
+    filename = os.fspath(filename)
+
+    if filename.endswith(".safetensors"):
+        from safetensors.torch import load_file as safetensors_load
+
+        return safetensors_load(filename, device=device)
+
+    return torch.load(filename, map_location=device)
+
+
+class _TorchProfiler:
+    def __init__(self, profiler: torch.profiler.profile, max_step: Optional[int]):
+        self.profiler = profiler
+        self.max_step = max_step
+        self.entered = False
+
+    def __enter__(self):
+        self.profiler.__enter__()
+        self.entered = True
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not self.entered:
+            return
+        self.entered = False
+        self.profiler.__exit__(exc_type, exc_val, exc_tb)
+
+        if exc_type is None:
+            print(
+                "Torch profiling finished, exporting Chrome trace to torch_profile.json,"
+                " memory timeline to torch_memory_profile.html...",
+                file=log.v2,
+            )
+            self.profiler.export_chrome_trace("torch_profile.json")
+            self.profiler.export_memory_timeline("torch_memory_profile.html")
+
+            print("Exiting program after Torch profiling.", file=log.v2)
+            sys.exit(0)
+
+    def step(self):
+        """step"""
+        self.profiler.step()
+        if self.max_step is not None and self.profiler.step_num > self.max_step:
+            print(f"Reached max profiling step {self.max_step}, stopping Torch profiler.", file=log.v2)
+            self.profiler.stop()
+            self.__exit__(None, None, None)
+
+
+def _opt_torch_profiler_from_opts(
+    opts: Union[None, int, bool, str, Dict[str, Any]],
+) -> Optional[_TorchProfiler]:
+    if isinstance(opts, str):
+        from returnn.util.basic import to_bool
+
+        opts = to_bool(opts)
+
+    if opts is None:
+        return None
+    elif isinstance(opts, (bool, int)):
+        if not opts:
+            return None
+        opts = {}
+    elif isinstance(opts, dict):
+        opts = opts.copy()
+    else:
+        raise TypeError(f"Invalid type for torch_profile {opts!r}: {type(opts)}")
+
+    from torch.profiler import profile, ProfilerActivity, schedule
+
+    print("Using Torch profiler...", file=log.v2)
+
+    prof_max_step = None
+
+    if "activities" not in opts:
+        activities = [ProfilerActivity.CPU]
+        if torch.cuda.is_available():
+            activities += [ProfilerActivity.CUDA]
+        elif torch.xpu.is_available():
+            activities += [ProfilerActivity.XPU]
+        opts["activities"] = activities
+
+    opts.setdefault("profile_memory", True)
+    opts.setdefault("record_shapes", True)
+    opts.setdefault("with_stack", True)
+    opts.setdefault("with_flops", True)
+    # Note: active*repeat are the steps we actually profile.
+    opts.setdefault("schedule", dict(skip_first=10, wait=5, warmup=3, active=3, repeat=1))
+
+    if isinstance(opts["schedule"], dict):
+        schedule_opts: Dict[str, Any] = opts["schedule"]
+        schedule_opts = schedule_opts.copy()
+        schedule_opts.setdefault("repeat", 0)
+        schedule_opts.setdefault("skip_first", 0)
+        schedule_opts.setdefault("skip_first_wait", 0)
+        opts["schedule"] = schedule(**schedule_opts)
+
+        if schedule_opts["repeat"] > 0:
+            prof_max_step = (schedule_opts["wait"] + schedule_opts["warmup"] + schedule_opts["active"]) * schedule_opts[
+                "repeat"
+            ]
+            prof_max_step += schedule_opts["skip_first"]
+            if schedule_opts["skip_first_wait"] != 0:
+                prof_max_step -= schedule_opts["wait"]
+            print(f"Profiling will stop automatically after {prof_max_step} steps.", file=log.v3)
+
+    prof = profile(**opts)
+    return _TorchProfiler(prof, prof_max_step)
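For the default schedule installed above, the automatic stop point works out as (wait + warmup + active) * repeat + skip_first, since skip_first_wait defaults to 0 and nothing is subtracted:

    # dict(skip_first=10, wait=5, warmup=3, active=3, repeat=1)
    prof_max_step = (5 + 3 + 3) * 1 + 10  # = 21 steps

So by default the profiler runs for 21 train steps, then exports torch_profile.json (Chrome trace) and torch_memory_profile.html, and exits the process via sys.exit(0).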