returnn-1.20260105.192646-py3-none-any.whl → returnn-1.20260119.15400-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +1 -1
- returnn/__old_mod_loader__.py +26 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/lm.py +110 -42
- returnn/frontend/__init__.py +1 -0
- returnn/frontend/_backend.py +41 -0
- returnn/frontend/_native/__init__.py +22 -0
- returnn/frontend/_numpy_backend.py +7 -0
- returnn/frontend/_utils.py +1 -1
- returnn/frontend/array_.py +6 -5
- returnn/frontend/assert_.py +35 -0
- returnn/frontend/device.py +14 -1
- returnn/frontend/encoder/conformer.py +19 -0
- returnn/frontend/loss.py +183 -3
- returnn/frontend/math_.py +54 -14
- returnn/native_op.cpp +104 -174
- returnn/native_op.py +36 -31
- returnn/tensor/_dim_extra.py +7 -7
- returnn/tensor/_tensor_extra.py +10 -10
- returnn/tensor/utils.py +1 -1
- returnn/tf/frontend_layers/_backend.py +3 -1
- returnn/tf/layers/basic.py +13 -2
- returnn/tf/native_op.py +16 -5
- returnn/tf/util/basic.py +7 -201
- returnn/torch/engine.py +120 -3
- returnn/torch/frontend/_backend.py +166 -22
- returnn/torch/frontend/bridge.py +61 -0
- returnn/torch/frontend/compile_helper.py +106 -0
- returnn/torch/util/array_.py +30 -0
- returnn/torch/util/assert_.py +122 -0
- returnn/torch/util/native_op.py +885 -0
- returnn/torch/util/native_op_code_compiler.py +308 -0
- returnn/util/basic.py +3 -1
- returnn/util/cuda_env.py +332 -0
- returnn/util/debug.py +1 -0
- returnn/util/fsa.py +17 -13
- returnn/util/native_code_compiler.py +104 -47
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +1 -1
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +42 -36
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
returnn/tf/layers/basic.py
CHANGED
@@ -2741,7 +2741,7 @@ class BooleanMaskLayer(LayerBase):
         tensor = self.sources[0].output
         remaining_dims = [d for d in tensor.dims if d not in dims]
         tensor_templ = tensor.copy_template_new_dim_tags(tuple(dims) + tuple(remaining_dims))
-        tensor = tensor.copy_compatible_to(tensor_templ,
+        tensor = tensor.copy_compatible_to(tensor_templ, unbroadcast=True)
         mask_templ = mask.output.copy_template_new_dim_tags(new_dim_tags=tuple(dims))
         mask_ = mask.output.copy_compatible_to(mask_templ, add_dims=False)
         self.output.raw_tensor = tf.boolean_mask(tensor.raw_tensor, mask=mask_.raw_tensor)
@@ -11538,13 +11538,23 @@ class CtcLossLayer(LayerBase):
     layer_class = "ctc_loss"
     recurrent = True  # order matters

-    def __init__(
+    def __init__(
+        self,
+        logits,
+        targets,
+        logits_normalized=False,
+        blank_index=-1,
+        max_approx=False,
+        label_loop: bool = True,
+        **kwargs,
+    ):
         """
         :param LayerBase logits: (before softmax). shape [B,T,D]
         :param LayerBase targets: sparse. shape [B,T]
         :param bool logits_normalized: whether the logits are already normalized (e.g. via log-softmax)
         :param int blank_index: vocab index of the blank symbol
         :param bool max_approx: if True, use max instead of sum over alignments (max approx, Viterbi)
+        :param label_loop:
         """
         from returnn.tf.native_op import ctc_loss, ctc_loss_viterbi

@@ -11567,6 +11577,7 @@ class CtcLossLayer(LayerBase):
             targets=targets.output.copy_as_batch_major().placeholder,
             targets_seq_lens=targets.output.get_sequence_lengths(),
             blank_index=blank_index,
+            label_loop=label_loop,
         )

     def get_dep_layers(self):
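As a usage sketch only: the "ctc_loss" layer now accepts a `label_loop` option. The layer names below ("encoder_logits", "targets") are placeholders, and how `logits`/`targets` are wired up depends on the layer's transform_config_dict, which is not part of this diff:

network = {
    # ... other layers ...
    "ctc": {
        "class": "ctc_loss",
        "logits": "encoder_logits",  # placeholder: layer before softmax, shape [B,T,D]
        "targets": "targets",        # placeholder: sparse target layer, shape [B,T]
        "blank_index": -1,
        "label_loop": False,         # new option; default True keeps the CTC label-loop transitions
    },
}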
returnn/tf/native_op.py
CHANGED
@@ -1473,12 +1473,14 @@ def fast_baum_welch_staircase(am_scores, seq_lens, **opts):


 def ctc_loss(
+    *,
     logits,
     logits_seq_lens,
     logits_time_major,
     targets,
     targets_seq_lens,
-
+    label_loop: Optional[bool] = None,
+    ctc_merge_repeated: Optional[bool] = None,
     logits_normalize=True,
     grad_wrt_softmax_in=True,
     blank_index=-1,
@@ -1493,7 +1495,8 @@ def ctc_loss(
     :param bool logits_time_major:
     :param tf.Tensor targets: batch-major, [batch,time]
     :param tf.Tensor targets_seq_lens: (batch,)
-    :param
+    :param label_loop:
+    :param ctc_merge_repeated: alias for label_loop
     :param bool logits_normalize: apply log_softmax on logits (default).
         if False, you might also set grad_wrt_softmax_in=False
     :param bool grad_wrt_softmax_in: assume ``p(s|x) = softmax(logits)``, and define the gradient w.r.t. logits.
@@ -1504,6 +1507,11 @@ def ctc_loss(
     :return: loss, shape (batch,)
     :rtype: tf.Tensor
     """
+    if ctc_merge_repeated is not None:
+        assert label_loop is None
+        label_loop = ctc_merge_repeated
+    if label_loop is None:
+        label_loop = True
     assert logits.get_shape().ndims == 3 and logits.get_shape().dims[-1].value
     dim = logits.get_shape().dims[-1].value
     if not logits_time_major:
@@ -1520,7 +1528,7 @@ def ctc_loss(
         blank_index += dim
     assert 0 <= blank_index < dim
     edges, weights, start_end_states = get_ctc_fsa_fast_bw(
-        targets=targets, seq_lens=targets_seq_lens, blank_idx=blank_index, label_loop=
+        targets=targets, seq_lens=targets_seq_lens, blank_idx=blank_index, label_loop=label_loop
     )
     fwdbwd, obs_scores = fast_baum_welch(
         am_scores=-log_sm, float_idx=seq_mask, edges=edges, weights=weights, start_end_states=start_end_states
@@ -1560,7 +1568,9 @@ def fast_viterbi(am_scores, am_seq_len, edges, weights, start_end_states):
     return alignment, scores


-def ctc_loss_viterbi(
+def ctc_loss_viterbi(
+    *, logits, logits_seq_lens, logits_time_major, targets, targets_seq_lens, blank_index=-1, label_loop: bool = True
+):
     """
     Similar to :func:`ctc_loss`.
     However, instead of using the full sum, we use the best path (i.e. Viterbi instead of Baum-Welch).
@@ -1572,6 +1582,7 @@ def ctc_loss_viterbi(logits, logits_seq_lens, logits_time_major, targets, target
     :param tf.Tensor targets: batch-major, [batch,time]
     :param tf.Tensor targets_seq_lens: (batch,)
     :param int blank_index: vocab index of the blank symbol
+    :param label_loop:
     :return: loss, shape (batch,)
     :rtype: tf.Tensor
     """
@@ -1585,7 +1596,7 @@ def ctc_loss_viterbi(logits, logits_seq_lens, logits_time_major, targets, target
         blank_index += dim
     assert 0 <= blank_index < dim
     edges, weights, start_end_states = get_ctc_fsa_fast_bw(
-        targets=targets, seq_lens=targets_seq_lens, blank_idx=blank_index
+        targets=targets, seq_lens=targets_seq_lens, blank_idx=blank_index, label_loop=label_loop
    )
     alignment, scores = fast_viterbi(
         am_scores=log_sm, am_seq_len=logits_seq_lens, edges=edges, weights=weights, start_end_states=start_end_states
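A minimal call sketch (the tensors are placeholders): with this change all arguments of ctc_loss are keyword-only, label_loop can be disabled, and ctc_merge_repeated is accepted as an alias; passing both label_loop and ctc_merge_repeated trips the assertion shown above:

from returnn.tf.native_op import ctc_loss

# logits: [T,B,D] (time-major here), logits_seq_lens/targets_seq_lens: [B], targets: [B,T] -- all placeholders
loss = ctc_loss(
    logits=logits,
    logits_seq_lens=logits_seq_lens,
    logits_time_major=True,
    targets=targets,
    targets_seq_lens=targets_seq_lens,
    blank_index=-1,
    ctc_merge_repeated=False,  # alias for label_loop=False
)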
returnn/tf/util/basic.py
CHANGED
@@ -17,6 +17,7 @@ from tensorflow.python.client import device_lib
 from tensorflow.python.ops import init_ops
 from returnn.util import basic as util
 from returnn.util.basic import NotSpecified, NativeCodeCompiler
+from returnn.util.cuda_env import CudaEnv as _CudaEnvBase
 from returnn.tensor import Tensor
 import returnn.tf.compat as tf_compat

@@ -2768,210 +2769,15 @@ def get_tf_gpp_path():
     return _tf_gpp_path


-class CudaEnv:
+class CudaEnv(_CudaEnvBase):
     """
-
-    Also path to ``nvcc``, the CUDA compiler.
+    Helper class to get CUDA environment for TF.
     """

-
-
-
-
-        from returnn.util.basic import to_bool
-
-        if to_bool(os.environ.get("DISABLE_CUDA", "0")):
-            self.cuda_path = None
-            if self.verbose_find_cuda:
-                print("CUDA disabled via env DISABLE_CUDA.")
-        elif os.environ.get("CUDA_VISIBLE_DEVICES", None) in ["", "-1"]:
-            self.cuda_path = None
-            if self.verbose_find_cuda:
-                print(f"CUDA disabled via env CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']!r}.")
-        else:
-            self.cuda_path = self._find_cuda_path()
-            if self.verbose_find_cuda:
-                print("CUDA path:", self.cuda_path)
-        self._max_compute_capability = None
-
-    @classmethod
-    def _find_nvcc_in_path(cls):
-        """
-        :return: yields full path to nvcc
-        :rtype: list[str]
-        """
-        for p in os.environ["PATH"].split(":"):
-            pp = "%s/nvcc" % p
-            if os.path.exists(pp):
-                yield pp
-
-    @classmethod
-    def _find_lib_in_ld_path(cls):
-        """
-        :return: yields full path to libcudart.so
-        :rtype: list[str]
-        """
-        from returnn.util.basic import get_ld_paths
-
-        for p in get_ld_paths():
-            pp = "%s/libcudart.so" % p
-            if os.path.exists(pp):
-                yield pp
-
-    @classmethod
-    def _get_lib_dir_name(cls, base_path):
-        """
-        :return: dir name in base path
-        :rtype: str
-        """
-        from returnn.util.basic import is_64bit_platform, get_ld_paths
-
-        for ld_path in get_ld_paths():
-            # We also want to allow "lib/x86_64-linux-gnu" for "/usr".
-            # However, this logic should not be triggered for incorrect cases.
-            # E.g. base_path="/usr" would be the prefix for most LD paths.
-            if ld_path.startswith(base_path + "/lib") and os.path.exists("%s/libcudart.so" % ld_path):
-                return ld_path[len(base_path) + 1 :]
-        if is_64bit_platform():
-            return "lib64"
-        return "lib"
-
-    @classmethod
-    def _cuda_path_candidate_via_proc_map_libcudart(cls):
-        from returnn.util.basic import find_libcudart_from_runtime
-
-        fn = find_libcudart_from_runtime()
-        if cls.verbose_find_cuda:
-            print("libcudart.so found from /proc/maps:", fn)
-        if not fn:
-            return None
-        # fn is e.g. '/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudart.so.8.0.61',
-        # or maybe '/usr/local/cuda-8.0/lib64/libcudart.so'
-        p = os.path.dirname(os.path.dirname(fn))
-        while not cls._check_valid_cuda_path(p):
-            p = os.path.dirname(p)
-            if p in ["", "/"]:
-                if cls.verbose_find_cuda:
-                    print(f"Loaded lib {fn} does not seem to be in valid CUDA path.")
-                return None
-        assert cls._check_valid_cuda_path(p)
-        return p
-
-    @classmethod
-    def _cuda_path_candidates(cls):
-        p = cls._cuda_path_candidate_via_proc_map_libcudart()
-        if p:
-            yield p
-        for p in cls._find_nvcc_in_path():
-            # Expect p == "/usr/local/cuda-8.0/bin/nvcc" or so.
-            postfix = "/bin/nvcc"
-            if cls.verbose_find_cuda:
-                print("found cuda nvcc (wanted postfix: %r): %s" % (postfix, p))
-            if not p.endswith(postfix):
-                continue
-            yield p[: -len(postfix)] or "/"
-        for p in cls._find_lib_in_ld_path():
-            # Expect p == "/usr/local/cuda-8.0/lib64/libcudart.so" or so.
-            d = "/".join(p.split("/")[:-2]) or "/"  # Get "/usr/local/cuda-8.0".
-            if cls.verbose_find_cuda:
-                print("found cuda lib: %s (path %s)" % (p, d))
-            yield d
-
-    @classmethod
-    def _check_valid_cuda_path(cls, p):
-        """
-        :param str p: path to CUDA, e.g. "/usr/local/cuda-8.0"
-        :return: whether this is a valid CUDA path, i.e. we find all what we need
-        :rtype: bool
-        """
-        if cls.verbose_find_cuda:
-            print("check valid CUDA path: %s" % p)
-        if not os.path.exists("%s/bin/nvcc" % p):
-            return False
-        if not os.path.exists("%s/include/cuda.h" % p):
-            return False
-        if not os.path.exists("%s/%s/libcudart.so" % (p, cls._get_lib_dir_name(p))):
-            return False
-        return True
-
-    @classmethod
-    def _find_cuda_path(cls):
-        """
-        :return: base CUDA path if we find one, otherwise None
-        :rtype: str|None
-        """
-        for p in cls._cuda_path_candidates():
-            if cls._check_valid_cuda_path(p):
-                return p
-        return None
-
-    def is_available(self):
-        """
-        :rtype: bool
-        """
-        return bool(self.cuda_path)
-
-    def get_max_compute_capability(self):
-        """
-        :return: the highest compute capability supported by nvcc, or float("inf") if not known
-        :rtype: float
-        """
-        if self._max_compute_capability is None:
-            cuda_occupancy_path = "%s/include/cuda_occupancy.h" % self.cuda_path
-            if os.path.exists(cuda_occupancy_path):
-                import re
-
-                major, minor = None, 0
-                for line in open(cuda_occupancy_path).read().splitlines():
-                    m = re.match("^#define\\s+__CUDA_OCC_(MAJOR|MINOR)__\\s+([0-9]+)$", line)
-                    if m:
-                        s, v = m.groups()
-                        v = int(v)
-                        if s == "MAJOR":
-                            major = v
-                        else:
-                            minor = v
-                if major:
-                    self._max_compute_capability = float(major) + float(minor) * 0.1
-        if self._max_compute_capability is None:
-            self._max_compute_capability = float("inf")
-        return self._max_compute_capability
-
-    def get_compiler_opts(self):
-        """
-        :rtype: list[str]
-        """
-        return [
-            "-ccbin",
-            get_tf_gcc_path(),
-            "-I",
-            "%s/targets/x86_64-linux/include" % self.cuda_path,
-            "-I",
-            "%s/include" % self.cuda_path,
-            "-L",
-            "%s/%s" % (self.cuda_path, self._get_lib_dir_name(self.cuda_path)),
-            "-x",
-            "cu",
-            "-v",
-        ]
-
-    def get_compiler_bin(self):
-        """
-        :return: path
-        :rtype: str
-        """
-        assert self.cuda_path
-        return "%s/bin/nvcc" % self.cuda_path
-
-    @classmethod
-    def get_instance(cls):
-        """
-        :rtype: CudaEnv
-        """
-        if cls._instance is not None:
-            return cls._instance
-        cls._instance = cls()
-        return cls._instance
+    @staticmethod
+    def get_cc_bin():
+        """compiler"""
+        return get_tf_gcc_path()


 class OpCodeCompiler(NativeCodeCompiler):
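The net effect of this hunk: the CUDA discovery logic moves into the new shared module returnn/util/cuda_env.py (added in this release, +332 lines), and the TF-specific subclass only picks the host compiler. A usage sketch, assuming the shared base class keeps the old get_instance()/is_available()/get_compiler_bin() API (that API is not shown in this diff):

from returnn.tf.util.basic import CudaEnv

env = CudaEnv.get_instance()  # assumed to still exist, now provided via the shared base class
if env.is_available():
    print("nvcc:", env.get_compiler_bin())          # e.g. <cuda_path>/bin/nvcc
    print("host compiler (ccbin):", env.get_cc_bin())  # new hook; for TF this returns get_tf_gcc_path()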
returnn/torch/engine.py
CHANGED
@@ -3,9 +3,11 @@ Main engine for PyTorch
 """

 from __future__ import annotations
+
 from typing import Optional, Any, Union, Callable, Dict, Set
 from contextlib import nullcontext, ExitStack, contextmanager

+import sys
 import gc
 import os
 import time
@@ -20,6 +22,7 @@ from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader
 from torch import autocast
 from torch.cuda import amp
+from torch.profiler import record_function
 import numpy as np

 import returnn
@@ -404,10 +407,14 @@ class Engine(EngineBase):
        total_data_size_packed = NumbersDict()
        total_data_size_padded = NumbersDict()

+        prof = _opt_torch_profiler_from_opts(self.config.opt_typed_value("torch_profile"))
+        if prof:
+            prof.__enter__()
+
        report_prefix = f"ep {self.epoch} train"
        try:
            while True:
-                with torch.no_grad():
+                with torch.no_grad(), record_function("data_loading"):
                    extern_data_raw = next(data_iter, None)

                step_begin_time = time.monotonic()
@@ -485,7 +492,8 @@ class Engine(EngineBase):
                with (
                    self._ddp_pt_model.no_sync()
                    if (self._ddp_pt_model is not None and not perform_update_step)
-                    else nullcontext()
+                    else nullcontext(),
+                    record_function("backward"),
                ):
                    if self._grad_scaler is not None:
                        self._grad_scaler.scale(total_loss.raw_tensor).backward()
@@ -500,7 +508,8 @@ class Engine(EngineBase):

                # only update the weights when every gradient accumulation loop ends
                if perform_update_step:
-
+                    with record_function("optimizer_step"):
+                        self._updater.step(grad_scaler=self._grad_scaler)
                zero_grad_next_step = perform_update_step

                if self._torch_distributed_ctx:
@@ -582,10 +591,19 @@ class Engine(EngineBase):
                self._updater.set_current_train_step(
                    global_train_step=self.global_train_step, epoch=self.epoch, epoch_continuous=epoch_continuous
                )
+
+                if prof:
+                    prof.step()
+
        except Exception as exc:
+            if prof:
+                prof.__exit__(type(exc), exc, exc.__traceback__)
            help_on_torch_exception(exc, step_idx=step_idx, model=self._orig_model, extern_data=extern_data)
            raise

+        if prof:
+            prof.__exit__(None, None, None)
+
        elapsed = time.monotonic() - epoch_start_time
        elapsed_computation_percentage = elapsed_computation_time / elapsed
        total_padding_ratio = NumbersDict.constant_like(1.0, total_data_size_packed) - (
@@ -885,6 +903,7 @@ class Engine(EngineBase):
        if self._default_float_dtype:
            stack.enter_context(rf.set_default_float_dtype_ctx(str(self._default_float_dtype).split(".")[-1]))
            stack.enter_context(_set_torch_default_dtype_ctx_mgr(self._default_float_dtype))
+        stack.enter_context(record_function("model_step"))
        yield

    def _run_step(
@@ -1734,3 +1753,101 @@ def _torch_load(filename: Union[str, os.PathLike], *, device: str) -> Dict[str,
        return safetensors_load(filename, device=device)

    return torch.load(filename, map_location=device)
+
+
+class _TorchProfiler:
+    def __init__(self, profiler: torch.profiler.profile, max_step: Optional[int]):
+        self.profiler = profiler
+        self.max_step = max_step
+        self.entered = False
+
+    def __enter__(self):
+        self.profiler.__enter__()
+        self.entered = True
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not self.entered:
+            return
+        self.entered = False
+        self.profiler.__exit__(exc_type, exc_val, exc_tb)
+
+        if exc_type is None:
+            print(
+                "Torch profiling finished, exporting Chrome trace to torch_profile.json,"
+                " memory timeline to torch_memory_profile.html...",
+                file=log.v2,
+            )
+            self.profiler.export_chrome_trace("torch_profile.json")
+            self.profiler.export_memory_timeline("torch_memory_profile.html")
+
+            print("Exiting program after Torch profiling.", file=log.v2)
+            sys.exit(0)
+
+    def step(self):
+        """step"""
+        self.profiler.step()
+        if self.max_step is not None and self.profiler.step_num > self.max_step:
+            print(f"Reached max profiling step {self.max_step}, stopping Torch profiler.", file=log.v2)
+            self.profiler.stop()
+            self.__exit__(None, None, None)
+
+
+def _opt_torch_profiler_from_opts(
+    opts: Union[None, int, bool, str, Dict[str, Any]],
+) -> Optional[_TorchProfiler]:
+    if isinstance(opts, str):
+        from returnn.util.basic import to_bool
+
+        opts = to_bool(opts)
+
+    if opts is None:
+        return None
+    elif isinstance(opts, (bool, int)):
+        if not opts:
+            return None
+        opts = {}
+    elif isinstance(opts, dict):
+        opts = opts.copy()
+    else:
+        raise TypeError(f"Invalid type for torch_profile {opts!r}: {type(opts)}")
+
+    from torch.profiler import profile, ProfilerActivity, schedule
+
+    print("Using Torch profiler...", file=log.v2)
+
+    prof_max_step = None
+
+    if "activities" not in opts:
+        activities = [ProfilerActivity.CPU]
+        if torch.cuda.is_available():
+            activities += [ProfilerActivity.CUDA]
+        elif torch.xpu.is_available():
+            activities += [ProfilerActivity.XPU]
+        opts["activities"] = activities
+
+    opts.setdefault("profile_memory", True)
+    opts.setdefault("record_shapes", True)
+    opts.setdefault("with_stack", True)
+    opts.setdefault("with_flops", True)
+    # Note: active*repeat are the steps we actually profile.
+    opts.setdefault("schedule", dict(skip_first=10, wait=5, warmup=3, active=3, repeat=1))
+
+    if isinstance(opts["schedule"], dict):
+        schedule_opts: Dict[str, Any] = opts["schedule"]
+        schedule_opts = schedule_opts.copy()
+        schedule_opts.setdefault("repeat", 0)
+        schedule_opts.setdefault("skip_first", 0)
+        schedule_opts.setdefault("skip_first_wait", 0)
+        opts["schedule"] = schedule(**schedule_opts)
+
+        if schedule_opts["repeat"] > 0:
+            prof_max_step = (schedule_opts["wait"] + schedule_opts["warmup"] + schedule_opts["active"]) * schedule_opts[
+                "repeat"
+            ]
+            prof_max_step += schedule_opts["skip_first"]
+            if schedule_opts["skip_first_wait"] != 0:
+                prof_max_step -= schedule_opts["wait"]
+            print(f"Profiling will stop automatically after {prof_max_step} steps.", file=log.v3)
+
+    prof = profile(**opts)
+    return _TorchProfiler(prof, prof_max_step)
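The profiler is driven by a new "torch_profile" config option, read via self.config.opt_typed_value("torch_profile"). A hedged config sketch based on the defaults visible above; keys other than "schedule" are passed through to torch.profiler.profile, and a dict "schedule" is converted via torch.profiler.schedule:

# RETURNN config (PyTorch backend)
torch_profile = True  # use the defaults shown above
# or with an explicit schedule:
torch_profile = {
    "schedule": dict(skip_first=10, wait=5, warmup=3, active=3, repeat=1),
    "profile_memory": True,
    "with_stack": True,
}
# With repeat > 0, training stops after the profiled steps, writes torch_profile.json
# (Chrome trace) and torch_memory_profile.html, and then exits the program.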