returnn 1.20260105.192646__py3-none-any.whl → 1.20260119.15400__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. returnn/PKG-INFO +1 -1
  2. returnn/__old_mod_loader__.py +26 -2
  3. returnn/_setup_info_generated.py +2 -2
  4. returnn/datasets/lm.py +110 -42
  5. returnn/frontend/__init__.py +1 -0
  6. returnn/frontend/_backend.py +41 -0
  7. returnn/frontend/_native/__init__.py +22 -0
  8. returnn/frontend/_numpy_backend.py +7 -0
  9. returnn/frontend/_utils.py +1 -1
  10. returnn/frontend/array_.py +6 -5
  11. returnn/frontend/assert_.py +35 -0
  12. returnn/frontend/device.py +14 -1
  13. returnn/frontend/encoder/conformer.py +19 -0
  14. returnn/frontend/loss.py +183 -3
  15. returnn/frontend/math_.py +54 -14
  16. returnn/native_op.cpp +104 -174
  17. returnn/native_op.py +36 -31
  18. returnn/tensor/_dim_extra.py +7 -7
  19. returnn/tensor/_tensor_extra.py +10 -10
  20. returnn/tensor/utils.py +1 -1
  21. returnn/tf/frontend_layers/_backend.py +3 -1
  22. returnn/tf/layers/basic.py +13 -2
  23. returnn/tf/native_op.py +16 -5
  24. returnn/tf/util/basic.py +7 -201
  25. returnn/torch/engine.py +120 -3
  26. returnn/torch/frontend/_backend.py +166 -22
  27. returnn/torch/frontend/bridge.py +61 -0
  28. returnn/torch/frontend/compile_helper.py +106 -0
  29. returnn/torch/util/array_.py +30 -0
  30. returnn/torch/util/assert_.py +122 -0
  31. returnn/torch/util/native_op.py +885 -0
  32. returnn/torch/util/native_op_code_compiler.py +308 -0
  33. returnn/util/basic.py +3 -1
  34. returnn/util/cuda_env.py +332 -0
  35. returnn/util/debug.py +1 -0
  36. returnn/util/fsa.py +17 -13
  37. returnn/util/native_code_compiler.py +104 -47
  38. {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +1 -1
  39. {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +42 -36
  40. {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
  41. {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
  42. {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
@@ -2741,7 +2741,7 @@ class BooleanMaskLayer(LayerBase):
         tensor = self.sources[0].output
         remaining_dims = [d for d in tensor.dims if d not in dims]
         tensor_templ = tensor.copy_template_new_dim_tags(tuple(dims) + tuple(remaining_dims))
-        tensor = tensor.copy_compatible_to(tensor_templ, add_dims=False)
+        tensor = tensor.copy_compatible_to(tensor_templ, unbroadcast=True)
         mask_templ = mask.output.copy_template_new_dim_tags(new_dim_tags=tuple(dims))
         mask_ = mask.output.copy_compatible_to(mask_templ, add_dims=False)
         self.output.raw_tensor = tf.boolean_mask(tensor.raw_tensor, mask=mask_.raw_tensor)
@@ -11538,13 +11538,23 @@ class CtcLossLayer(LayerBase):
     layer_class = "ctc_loss"
     recurrent = True  # order matters
 
-    def __init__(self, logits, targets, logits_normalized=False, blank_index=-1, max_approx=False, **kwargs):
+    def __init__(
+        self,
+        logits,
+        targets,
+        logits_normalized=False,
+        blank_index=-1,
+        max_approx=False,
+        label_loop: bool = True,
+        **kwargs,
+    ):
         """
         :param LayerBase logits: (before softmax). shape [B,T,D]
         :param LayerBase targets: sparse. shape [B,T]
         :param bool logits_normalized: whether the logits are already normalized (e.g. via log-softmax)
         :param int blank_index: vocab index of the blank symbol
         :param bool max_approx: if True, use max instead of sum over alignments (max approx, Viterbi)
+        :param label_loop:
         """
         from returnn.tf.native_op import ctc_loss, ctc_loss_viterbi
 
@@ -11567,6 +11577,7 @@ class CtcLossLayer(LayerBase):
             targets=targets.output.copy_as_batch_major().placeholder,
             targets_seq_lens=targets.output.get_sequence_lengths(),
             blank_index=blank_index,
+            label_loop=label_loop,
         )
 
     def get_dep_layers(self):
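
Note: the new label_loop option of the "ctc_loss" layer is simply forwarded to returnn.tf.native_op.ctc_loss / ctc_loss_viterbi (see the next file); the default True matches the previous behavior (the old ctc_merge_repeated=True). A hypothetical network-config sketch; the layer/data names and any config keys not visible in this diff are placeholders, not part of the package:

    network = {
        "encoder": {"class": "linear", "activation": "relu", "n_out": 512, "from": "data"},
        "ctc": {
            "class": "ctc_loss",
            "logits": "encoder",        # scores before softmax, shape [B,T,D]
            "targets": "data:classes",  # sparse targets, shape [B,T]
            "blank_index": -1,
            "label_loop": False,        # new option; True (default) keeps the previous behavior
            "loss": "as_is",
        },
    }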
returnn/tf/native_op.py CHANGED
@@ -1473,12 +1473,14 @@ def fast_baum_welch_staircase(am_scores, seq_lens, **opts):
 
 
 def ctc_loss(
+    *,
     logits,
     logits_seq_lens,
     logits_time_major,
     targets,
     targets_seq_lens,
-    ctc_merge_repeated=True,
+    label_loop: Optional[bool] = None,
+    ctc_merge_repeated: Optional[bool] = None,
     logits_normalize=True,
     grad_wrt_softmax_in=True,
    blank_index=-1,
@@ -1493,7 +1495,8 @@ def ctc_loss(
     :param bool logits_time_major:
     :param tf.Tensor targets: batch-major, [batch,time]
     :param tf.Tensor targets_seq_lens: (batch,)
-    :param bool ctc_merge_repeated:
+    :param label_loop:
+    :param ctc_merge_repeated: alias for label_loop
     :param bool logits_normalize: apply log_softmax on logits (default).
         if False, you might also set grad_wrt_softmax_in=False
     :param bool grad_wrt_softmax_in: assume ``p(s|x) = softmax(logits)``, and define the gradient w.r.t. logits.
@@ -1504,6 +1507,11 @@ def ctc_loss(
     :return: loss, shape (batch,)
     :rtype: tf.Tensor
     """
+    if ctc_merge_repeated is not None:
+        assert label_loop is None
+        label_loop = ctc_merge_repeated
+    if label_loop is None:
+        label_loop = True
     assert logits.get_shape().ndims == 3 and logits.get_shape().dims[-1].value
     dim = logits.get_shape().dims[-1].value
     if not logits_time_major:
@@ -1520,7 +1528,7 @@ def ctc_loss(
         blank_index += dim
     assert 0 <= blank_index < dim
     edges, weights, start_end_states = get_ctc_fsa_fast_bw(
-        targets=targets, seq_lens=targets_seq_lens, blank_idx=blank_index, label_loop=ctc_merge_repeated
+        targets=targets, seq_lens=targets_seq_lens, blank_idx=blank_index, label_loop=label_loop
     )
     fwdbwd, obs_scores = fast_baum_welch(
         am_scores=-log_sm, float_idx=seq_mask, edges=edges, weights=weights, start_end_states=start_end_states
@@ -1560,7 +1568,9 @@ def fast_viterbi(am_scores, am_seq_len, edges, weights, start_end_states):
     return alignment, scores
 
 
-def ctc_loss_viterbi(logits, logits_seq_lens, logits_time_major, targets, targets_seq_lens, blank_index=-1):
+def ctc_loss_viterbi(
+    *, logits, logits_seq_lens, logits_time_major, targets, targets_seq_lens, blank_index=-1, label_loop: bool = True
+):
     """
     Similar to :func:`ctc_loss`.
     However, instead of using the full sum, we use the best path (i.e. Viterbi instead of Baum-Welch).
@@ -1572,6 +1582,7 @@ def ctc_loss_viterbi(logits, logits_seq_lens, logits_time_major, targets, target
     :param tf.Tensor targets: batch-major, [batch,time]
     :param tf.Tensor targets_seq_lens: (batch,)
     :param int blank_index: vocab index of the blank symbol
+    :param label_loop:
     :return: loss, shape (batch,)
     :rtype: tf.Tensor
     """
@@ -1585,7 +1596,7 @@ def ctc_loss_viterbi(logits, logits_seq_lens, logits_time_major, targets, target
         blank_index += dim
     assert 0 <= blank_index < dim
     edges, weights, start_end_states = get_ctc_fsa_fast_bw(
-        targets=targets, seq_lens=targets_seq_lens, blank_idx=blank_index
+        targets=targets, seq_lens=targets_seq_lens, blank_idx=blank_index, label_loop=label_loop
     )
     alignment, scores = fast_viterbi(
         am_scores=log_sm, am_seq_len=logits_seq_lens, edges=edges, weights=weights, start_end_states=start_end_states
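
Note: ctc_loss and ctc_loss_viterbi now take keyword-only arguments, and label_loop supersedes ctc_merge_repeated, which remains accepted as an alias (passing both trips the assertion above). A minimal call sketch with placeholder tensors:

    from returnn.tf.native_op import ctc_loss

    loss = ctc_loss(                        # returns loss of shape (batch,)
        logits=logits,                      # [B,T,D], or [T,B,D] with logits_time_major=True
        logits_seq_lens=logits_seq_lens,    # (batch,)
        logits_time_major=False,
        targets=targets,                    # [B,T], sparse
        targets_seq_lens=targets_seq_lens,  # (batch,)
        blank_index=-1,
        label_loop=False,                   # ctc_merge_repeated=False would work as an alias
    )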
returnn/tf/util/basic.py CHANGED
@@ -17,6 +17,7 @@ from tensorflow.python.client import device_lib
 from tensorflow.python.ops import init_ops
 from returnn.util import basic as util
 from returnn.util.basic import NotSpecified, NativeCodeCompiler
+from returnn.util.cuda_env import CudaEnv as _CudaEnvBase
 from returnn.tensor import Tensor
 import returnn.tf.compat as tf_compat
 
@@ -2768,210 +2769,15 @@ def get_tf_gpp_path():
     return _tf_gpp_path
 
 
-class CudaEnv:
+class CudaEnv(_CudaEnvBase):
     """
-    Information about the Nvidia CUDA environment, and library.
-    Also path to ``nvcc``, the CUDA compiler.
+    Helper class to get CUDA environment for TF.
     """
 
-    _instance = None
-    verbose_find_cuda = False
-
-    def __init__(self):
-        from returnn.util.basic import to_bool
-
-        if to_bool(os.environ.get("DISABLE_CUDA", "0")):
-            self.cuda_path = None
-            if self.verbose_find_cuda:
-                print("CUDA disabled via env DISABLE_CUDA.")
-        elif os.environ.get("CUDA_VISIBLE_DEVICES", None) in ["", "-1"]:
-            self.cuda_path = None
-            if self.verbose_find_cuda:
-                print(f"CUDA disabled via env CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']!r}.")
-        else:
-            self.cuda_path = self._find_cuda_path()
-            if self.verbose_find_cuda:
-                print("CUDA path:", self.cuda_path)
-        self._max_compute_capability = None
-
-    @classmethod
-    def _find_nvcc_in_path(cls):
-        """
-        :return: yields full path to nvcc
-        :rtype: list[str]
-        """
-        for p in os.environ["PATH"].split(":"):
-            pp = "%s/nvcc" % p
-            if os.path.exists(pp):
-                yield pp
-
-    @classmethod
-    def _find_lib_in_ld_path(cls):
-        """
-        :return: yields full path to libcudart.so
-        :rtype: list[str]
-        """
-        from returnn.util.basic import get_ld_paths
-
-        for p in get_ld_paths():
-            pp = "%s/libcudart.so" % p
-            if os.path.exists(pp):
-                yield pp
-
-    @classmethod
-    def _get_lib_dir_name(cls, base_path):
-        """
-        :return: dir name in base path
-        :rtype: str
-        """
-        from returnn.util.basic import is_64bit_platform, get_ld_paths
-
-        for ld_path in get_ld_paths():
-            # We also want to allow "lib/x86_64-linux-gnu" for "/usr".
-            # However, this logic should not be triggered for incorrect cases.
-            # E.g. base_path="/usr" would be the prefix for most LD paths.
-            if ld_path.startswith(base_path + "/lib") and os.path.exists("%s/libcudart.so" % ld_path):
-                return ld_path[len(base_path) + 1 :]
-        if is_64bit_platform():
-            return "lib64"
-        return "lib"
-
-    @classmethod
-    def _cuda_path_candidate_via_proc_map_libcudart(cls):
-        from returnn.util.basic import find_libcudart_from_runtime
-
-        fn = find_libcudart_from_runtime()
-        if cls.verbose_find_cuda:
-            print("libcudart.so found from /proc/maps:", fn)
-        if not fn:
-            return None
-        # fn is e.g. '/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudart.so.8.0.61',
-        # or maybe '/usr/local/cuda-8.0/lib64/libcudart.so'
-        p = os.path.dirname(os.path.dirname(fn))
-        while not cls._check_valid_cuda_path(p):
-            p = os.path.dirname(p)
-            if p in ["", "/"]:
-                if cls.verbose_find_cuda:
-                    print(f"Loaded lib {fn} does not seem to be in valid CUDA path.")
-                return None
-        assert cls._check_valid_cuda_path(p)
-        return p
-
-    @classmethod
-    def _cuda_path_candidates(cls):
-        p = cls._cuda_path_candidate_via_proc_map_libcudart()
-        if p:
-            yield p
-        for p in cls._find_nvcc_in_path():
-            # Expect p == "/usr/local/cuda-8.0/bin/nvcc" or so.
-            postfix = "/bin/nvcc"
-            if cls.verbose_find_cuda:
-                print("found cuda nvcc (wanted postfix: %r): %s" % (postfix, p))
-            if not p.endswith(postfix):
-                continue
-            yield p[: -len(postfix)] or "/"
-        for p in cls._find_lib_in_ld_path():
-            # Expect p == "/usr/local/cuda-8.0/lib64/libcudart.so" or so.
-            d = "/".join(p.split("/")[:-2]) or "/"  # Get "/usr/local/cuda-8.0".
-            if cls.verbose_find_cuda:
-                print("found cuda lib: %s (path %s)" % (p, d))
-            yield d
-
-    @classmethod
-    def _check_valid_cuda_path(cls, p):
-        """
-        :param str p: path to CUDA, e.g. "/usr/local/cuda-8.0"
-        :return: whether this is a valid CUDA path, i.e. we find all what we need
-        :rtype: bool
-        """
-        if cls.verbose_find_cuda:
-            print("check valid CUDA path: %s" % p)
-        if not os.path.exists("%s/bin/nvcc" % p):
-            return False
-        if not os.path.exists("%s/include/cuda.h" % p):
-            return False
-        if not os.path.exists("%s/%s/libcudart.so" % (p, cls._get_lib_dir_name(p))):
-            return False
-        return True
-
-    @classmethod
-    def _find_cuda_path(cls):
-        """
-        :return: base CUDA path if we find one, otherwise None
-        :rtype: str|None
-        """
-        for p in cls._cuda_path_candidates():
-            if cls._check_valid_cuda_path(p):
-                return p
-        return None
-
-    def is_available(self):
-        """
-        :rtype: bool
-        """
-        return bool(self.cuda_path)
-
-    def get_max_compute_capability(self):
-        """
-        :return: the highest compute capability supported by nvcc, or float("inf") if not known
-        :rtype: float
-        """
-        if self._max_compute_capability is None:
-            cuda_occupancy_path = "%s/include/cuda_occupancy.h" % self.cuda_path
-            if os.path.exists(cuda_occupancy_path):
-                import re
-
-                major, minor = None, 0
-                for line in open(cuda_occupancy_path).read().splitlines():
-                    m = re.match("^#define\\s+__CUDA_OCC_(MAJOR|MINOR)__\\s+([0-9]+)$", line)
-                    if m:
-                        s, v = m.groups()
-                        v = int(v)
-                        if s == "MAJOR":
-                            major = v
-                        else:
-                            minor = v
-                if major:
-                    self._max_compute_capability = float(major) + float(minor) * 0.1
-            if self._max_compute_capability is None:
-                self._max_compute_capability = float("inf")
-        return self._max_compute_capability
-
-    def get_compiler_opts(self):
-        """
-        :rtype: list[str]
-        """
-        return [
-            "-ccbin",
-            get_tf_gcc_path(),
-            "-I",
-            "%s/targets/x86_64-linux/include" % self.cuda_path,
-            "-I",
-            "%s/include" % self.cuda_path,
-            "-L",
-            "%s/%s" % (self.cuda_path, self._get_lib_dir_name(self.cuda_path)),
-            "-x",
-            "cu",
-            "-v",
-        ]
-
-    def get_compiler_bin(self):
-        """
-        :return: path
-        :rtype: str
-        """
-        assert self.cuda_path
-        return "%s/bin/nvcc" % self.cuda_path
-
-    @classmethod
-    def get_instance(cls):
-        """
-        :rtype: CudaEnv
-        """
-        if cls._instance is not None:
-            return cls._instance
-        cls._instance = cls()
-        return cls._instance
+    @staticmethod
+    def get_cc_bin():
+        """compiler"""
+        return get_tf_gcc_path()
 
 
 class OpCodeCompiler(NativeCodeCompiler):
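
Note: the CUDA discovery logic removed above now lives in the new returnn/util/cuda_env.py (+332 lines, see the file list); the TF-specific class only overrides the host compiler via get_cc_bin(). A usage sketch, assuming the shared base class keeps the old public interface (get_instance(), is_available(), get_compiler_bin(), get_compiler_opts()) and the verbose_find_cuda class attribute; this is not verified against the new module:

    from returnn.util.cuda_env import CudaEnv

    CudaEnv.verbose_find_cuda = True  # assumed to have moved along with the rest of the class
    env = CudaEnv.get_instance()
    if env.is_available():
        print("nvcc:", env.get_compiler_bin())
        print("opts:", env.get_compiler_opts())  # presumably builds the -ccbin argument from get_cc_bin()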
returnn/torch/engine.py CHANGED
@@ -3,9 +3,11 @@ Main engine for PyTorch
 """
 
 from __future__ import annotations
+
 from typing import Optional, Any, Union, Callable, Dict, Set
 from contextlib import nullcontext, ExitStack, contextmanager
 
+import sys
 import gc
 import os
 import time
@@ -20,6 +22,7 @@ from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader
 from torch import autocast
 from torch.cuda import amp
+from torch.profiler import record_function
 import numpy as np
 
 import returnn
@@ -404,10 +407,14 @@ class Engine(EngineBase):
         total_data_size_packed = NumbersDict()
         total_data_size_padded = NumbersDict()
 
+        prof = _opt_torch_profiler_from_opts(self.config.opt_typed_value("torch_profile"))
+        if prof:
+            prof.__enter__()
+
         report_prefix = f"ep {self.epoch} train"
         try:
             while True:
-                with torch.no_grad():
+                with torch.no_grad(), record_function("data_loading"):
                     extern_data_raw = next(data_iter, None)
 
                 step_begin_time = time.monotonic()
@@ -485,7 +492,8 @@ class Engine(EngineBase):
                 with (
                     self._ddp_pt_model.no_sync()
                     if (self._ddp_pt_model is not None and not perform_update_step)
-                    else nullcontext()
+                    else nullcontext(),
+                    record_function("backward"),
                 ):
                     if self._grad_scaler is not None:
                         self._grad_scaler.scale(total_loss.raw_tensor).backward()
@@ -500,7 +508,8 @@ class Engine(EngineBase):
 
                 # only update the weights when every gradient accumulation loop ends
                 if perform_update_step:
-                    self._updater.step(grad_scaler=self._grad_scaler)
+                    with record_function("optimizer_step"):
+                        self._updater.step(grad_scaler=self._grad_scaler)
                 zero_grad_next_step = perform_update_step
 
                 if self._torch_distributed_ctx:
@@ -582,10 +591,19 @@ class Engine(EngineBase):
                 self._updater.set_current_train_step(
                     global_train_step=self.global_train_step, epoch=self.epoch, epoch_continuous=epoch_continuous
                 )
+
+                if prof:
+                    prof.step()
+
         except Exception as exc:
+            if prof:
+                prof.__exit__(type(exc), exc, exc.__traceback__)
             help_on_torch_exception(exc, step_idx=step_idx, model=self._orig_model, extern_data=extern_data)
             raise
 
+        if prof:
+            prof.__exit__(None, None, None)
+
         elapsed = time.monotonic() - epoch_start_time
         elapsed_computation_percentage = elapsed_computation_time / elapsed
         total_padding_ratio = NumbersDict.constant_like(1.0, total_data_size_packed) - (
@@ -885,6 +903,7 @@ class Engine(EngineBase):
         if self._default_float_dtype:
             stack.enter_context(rf.set_default_float_dtype_ctx(str(self._default_float_dtype).split(".")[-1]))
             stack.enter_context(_set_torch_default_dtype_ctx_mgr(self._default_float_dtype))
+        stack.enter_context(record_function("model_step"))
         yield
 
     def _run_step(
@@ -1734,3 +1753,101 @@ def _torch_load(filename: Union[str, os.PathLike], *, device: str) -> Dict[str,
         return safetensors_load(filename, device=device)
 
     return torch.load(filename, map_location=device)
+
+
+class _TorchProfiler:
+    def __init__(self, profiler: torch.profiler.profile, max_step: Optional[int]):
+        self.profiler = profiler
+        self.max_step = max_step
+        self.entered = False
+
+    def __enter__(self):
+        self.profiler.__enter__()
+        self.entered = True
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not self.entered:
+            return
+        self.entered = False
+        self.profiler.__exit__(exc_type, exc_val, exc_tb)
+
+        if exc_type is None:
+            print(
+                "Torch profiling finished, exporting Chrome trace to torch_profile.json,"
+                " memory timeline to torch_memory_profile.html...",
+                file=log.v2,
+            )
+            self.profiler.export_chrome_trace("torch_profile.json")
+            self.profiler.export_memory_timeline("torch_memory_profile.html")
+
+            print("Exiting program after Torch profiling.", file=log.v2)
+            sys.exit(0)
+
+    def step(self):
+        """step"""
+        self.profiler.step()
+        if self.max_step is not None and self.profiler.step_num > self.max_step:
+            print(f"Reached max profiling step {self.max_step}, stopping Torch profiler.", file=log.v2)
+            self.profiler.stop()
+            self.__exit__(None, None, None)
+
+
+def _opt_torch_profiler_from_opts(
+    opts: Union[None, int, bool, str, Dict[str, Any]],
+) -> Optional[_TorchProfiler]:
+    if isinstance(opts, str):
+        from returnn.util.basic import to_bool
+
+        opts = to_bool(opts)
+
+    if opts is None:
+        return None
+    elif isinstance(opts, (bool, int)):
+        if not opts:
+            return None
+        opts = {}
+    elif isinstance(opts, dict):
+        opts = opts.copy()
+    else:
+        raise TypeError(f"Invalid type for torch_profile {opts!r}: {type(opts)}")
+
+    from torch.profiler import profile, ProfilerActivity, schedule
+
+    print("Using Torch profiler...", file=log.v2)
+
+    prof_max_step = None
+
+    if "activities" not in opts:
+        activities = [ProfilerActivity.CPU]
+        if torch.cuda.is_available():
+            activities += [ProfilerActivity.CUDA]
+        elif torch.xpu.is_available():
+            activities += [ProfilerActivity.XPU]
+        opts["activities"] = activities
+
+    opts.setdefault("profile_memory", True)
+    opts.setdefault("record_shapes", True)
+    opts.setdefault("with_stack", True)
+    opts.setdefault("with_flops", True)
+    # Note: active*repeat are the steps we actually profile.
+    opts.setdefault("schedule", dict(skip_first=10, wait=5, warmup=3, active=3, repeat=1))
+
+    if isinstance(opts["schedule"], dict):
+        schedule_opts: Dict[str, Any] = opts["schedule"]
+        schedule_opts = schedule_opts.copy()
+        schedule_opts.setdefault("repeat", 0)
+        schedule_opts.setdefault("skip_first", 0)
+        schedule_opts.setdefault("skip_first_wait", 0)
+        opts["schedule"] = schedule(**schedule_opts)
+
+        if schedule_opts["repeat"] > 0:
+            prof_max_step = (schedule_opts["wait"] + schedule_opts["warmup"] + schedule_opts["active"]) * schedule_opts[
+                "repeat"
+            ]
+            prof_max_step += schedule_opts["skip_first"]
+            if schedule_opts["skip_first_wait"] != 0:
+                prof_max_step -= schedule_opts["wait"]
+            print(f"Profiling will stop automatically after {prof_max_step} steps.", file=log.v3)
+
+    prof = profile(**opts)
+    return _TorchProfiler(prof, prof_max_step)
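
Note: the profiler is driven by a new torch_profile config option, read via self.config.opt_typed_value("torch_profile") above. A config sketch based on the parsing code; everything except the option name and the defaults visible in the diff is illustrative:

    # Enable with defaults: CPU plus CUDA/XPU activities, memory/shape/stack/FLOP recording,
    # schedule skip_first=10, wait=5, warmup=3, active=3, repeat=1. With repeat > 0 the engine
    # stops after the computed step count, writes torch_profile.json and
    # torch_memory_profile.html, and exits.
    torch_profile = True

    # Or pass torch.profiler.profile() kwargs explicitly:
    torch_profile = {
        "schedule": dict(skip_first=2, wait=1, warmup=1, active=5, repeat=1),
        "with_stack": False,
    }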