triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl
This diff represents the content of a publicly available package version released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +82 -0
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +255 -0
- triton/_utils.py +126 -0
- triton/backends/__init__.py +47 -0
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +461 -0
- triton/backends/amd/driver.c +283 -0
- triton/backends/amd/driver.py +724 -0
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +90 -0
- triton/backends/driver.py +66 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +533 -0
- triton/backends/nvidia/driver.c +517 -0
- triton/backends/nvidia/driver.py +799 -0
- triton/backends/nvidia/include/cuda.h +26280 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +7 -0
- triton/compiler/code_generator.py +1614 -0
- triton/compiler/compiler.py +509 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +342 -0
- triton/language/core.py +3405 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +16 -0
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +5 -0
- triton/language/extra/hip/libdevice.py +491 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +790 -0
- triton/language/math.py +249 -0
- triton/language/random.py +218 -0
- triton/language/semantic.py +1939 -0
- triton/language/standard.py +534 -0
- triton/language/target_info.py +54 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/_allocation.py +44 -0
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +476 -0
- triton/runtime/build.py +168 -0
- triton/runtime/cache.py +317 -0
- triton/runtime/driver.py +38 -0
- triton/runtime/errors.py +36 -0
- triton/runtime/interpreter.py +1414 -0
- triton/runtime/jit.py +1107 -0
- triton/runtime/tcc/include/_mingw.h +168 -0
- triton/runtime/tcc/include/assert.h +62 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +75 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +116 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +497 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +14 -0
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +42 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +591 -0
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +123 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +2958 -0
- triton/runtime/tcc/include/winapi/wincon.h +309 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +5837 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +543 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.py +210 -0
- triton/tools/disasm.py +143 -0
- triton/tools/extra/cuda/compile.c +70 -0
- triton/tools/extra/cuda/compile.h +14 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/link.py +322 -0
- triton/tools/mxfp.py +301 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +405 -0
- triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
triton/testing.py
ADDED
@@ -0,0 +1,543 @@
+import functools
+import math
+import os
+import statistics
+import subprocess
+import sys
+from contextlib import contextmanager
+from typing import Any, Dict, List
+from . import language as tl
+from . import runtime
+
+
+def nvsmi(attrs):
+    attrs = ','.join(attrs)
+    cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
+    out = subprocess.check_output(cmd)
+    ret = out.decode(sys.stdout.encoding).split(',')
+    ret = [int(x) for x in ret]
+    return ret
+
+
+# pure Python implementation of np.quantile/torch.quantile
+# to avoid unnecessary runtime dependency on numpy/torch
+
+
+def _quantile(a, q):
+    n = len(a)
+    a = sorted(a)
+
+    def get_quantile(q):
+        if not (0 <= q <= 1):
+            raise ValueError("Quantiles must be in the range [0, 1]")
+        point = q * (n - 1)
+        lower = math.floor(point)
+        upper = math.ceil(point)
+        t = point - lower
+        return (1 - t) * a[lower] + t * a[upper]
+
+    return [get_quantile(q) for q in q]
+
+
+def _summarize_statistics(times, quantiles, return_mode):
+    if quantiles is not None:
+        ret = _quantile(times, quantiles)
+        if len(ret) == 1:
+            ret = ret[0]
+        return ret
+    if return_mode == "all":
+        return times
+    elif return_mode == "min":
+        return min(times)
+    elif return_mode == "max":
+        return max(times)
+    elif return_mode == "mean":
+        return statistics.mean(times)
+    elif return_mode == "median":
+        return statistics.median(times)
+
+
+def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
+    """
+    Benchmark the runtime of the provided function.
+
+    :param fn: Function to benchmark
+    :type fn: Callable
+    :param rep: Repetition time (in ms)
+    :type rep: int
+    :param grad_to_none: Reset the gradient of the provided tensor to None
+    :type grad_to_none: torch.tensor, optional
+    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
+    :type return_mode: str
+    """
+    import torch
+    assert return_mode in ["min", "max", "mean", "median", "all"]
+
+    with torch.cuda.stream(torch.cuda.Stream()):
+        # warmup
+        fn()
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.detach_()
+                x.requires_grad_(True)
+                x.grad = None
+        # step 1 - we estimate the amount of time the kernel call takes
+        # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point
+        # but it is probably good enough
+        # NOTE: we don't use a graph to estimate the runtime because creating a graph is expensive,
+        # ~300ms on A100, so we default to the same method used in `do_bench` (minus the L2
+        # cache flush).
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        for _ in range(5):
+            fn()
+        end_event.record()
+        torch.cuda.synchronize()
+        estimate_ms = start_event.elapsed_time(end_event) / 5
+        # Rewrite to avoid possible division by 0 issues with fast benchmarks
+        if estimate_ms == 0:
+            n_repeat = 1000
+        else:
+            n_repeat = max(1, int(rep / estimate_ms))
+        # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize
+        # host overhead
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            for _ in range(n_repeat):
+                if grad_to_none is not None:
+                    for x in grad_to_none:
+                        x.grad = None
+                fn()
+        torch.cuda.synchronize()
+        # measure time and return
+        ret = []
+        n_retries = 10
+        for _ in range(n_retries):
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            start_event.record()
+            g.replay()
+            end_event.record()
+            torch.cuda.synchronize()
+            ret += [start_event.elapsed_time(end_event) / n_repeat]
+        return _summarize_statistics(ret, quantiles, return_mode)
+
+
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
+    """
+    Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
+    the 20-th and 80-th performance percentile.
+
+    :param fn: Function to benchmark
+    :type fn: Callable
+    :param warmup: Warmup time (in ms)
+    :type warmup: int
+    :param rep: Repetition time (in ms)
+    :type rep: int
+    :param grad_to_none: Reset the gradient of the provided tensor to None
+    :type grad_to_none: torch.tensor, optional
+    :param quantiles: Performance percentile to return in addition to the median.
+    :type quantiles: list[float], optional
+    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
+    :type return_mode: str
+    """
+    assert return_mode in ["min", "max", "mean", "median", "all"]
+
+    di = runtime.driver.active.get_device_interface()
+
+    fn()
+    di.synchronize()
+
+    cache = runtime.driver.active.get_empty_cache_for_benchmark()
+
+    # Estimate the runtime of the function
+    start_event = di.Event(enable_timing=True)
+    end_event = di.Event(enable_timing=True)
+    start_event.record()
+    for _ in range(5):
+        runtime.driver.active.clear_cache(cache)
+        fn()
+    end_event.record()
+    di.synchronize()
+    estimate_ms = start_event.elapsed_time(end_event) / 5
+
+    # compute number of warmup and repeat
+    n_warmup = max(1, int(warmup / estimate_ms))
+    n_repeat = max(1, int(rep / estimate_ms))
+    start_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
+    end_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
+    # Warm-up
+    for _ in range(n_warmup):
+        fn()
+    # Benchmark
+    for i in range(n_repeat):
+        # we don't want `fn` to accumulate gradient values
+        # if it contains a backward pass. So we clear the
+        # provided gradients
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.grad = None
+        # we clear the L2 cache before each run
+        runtime.driver.active.clear_cache(cache)
+        # record time of `fn`
+        start_event[i].record()
+        fn()
+        end_event[i].record()
+    # Record clocks
+    di.synchronize()
+    times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
+    return _summarize_statistics(times, quantiles, return_mode)
+
+
+def assert_close(x, y, atol=None, rtol=None, err_msg=''):
+    """
+    Asserts that two inputs are close within a certain tolerance.
+
+    :param x: The first input.
+    :type x: scala, list, numpy.ndarray, or torch.Tensor
+    :param y: The second input.
+    :type y: scala, list, numpy.ndarray, or torch.Tensor
+    :param atol: The absolute tolerance. Default value is 1e-2.
+    :type atol: float, optional
+    :param rtol: The relative tolerance. Default value is 0.
+    :type rtol: float, optional
+    :param err_msg: The error message to use if the assertion fails.
+    :type err_msg: str
+    """
+    import numpy as np
+    import torch
+
+    # canonicalize arguments to be tensors
+    if not isinstance(x, torch.Tensor):
+        x = torch.tensor(x)
+    if not isinstance(y, torch.Tensor):
+        y = torch.tensor(y)
+    # absolute tolerance
+    if atol is None:
+        atol = 1e-2
+    atol = atol(x.dtype) if callable(atol) else atol
+    # relative tolerance hook
+    if rtol is None:
+        rtol = 0.
+    rtol = rtol(x.dtype) if callable(rtol) else rtol
+    # we use numpy instead of pytorch
+    # as it seems more memory efficient
+    # pytorch tends to oom on large tensors
+    if isinstance(x, torch.Tensor):
+        if x.dtype == torch.bfloat16:
+            x = x.float()
+        x = x.cpu().detach().numpy()
+    if isinstance(y, torch.Tensor):
+        if y.dtype == torch.bfloat16:
+            y = y.float()
+        y = y.cpu().detach().numpy()
+    # we handle size==1 case separately as we can
+    # provide better error message there
+    if x.size > 1 or y.size > 1:
+        np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True)
+        return
+    if not np.allclose(x, y, atol=atol, rtol=rtol):
+        raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})')
+
+
+class Benchmark:
+    """
+    This class is used by the :code:`perf_report` function to generate line plots with a concise API.
+    """
+
+    def __init__(
+        self,
+        x_names: List[str],
+        x_vals: List[Any],
+        line_arg: str,
+        line_vals: List[Any],
+        line_names: List[str],
+        plot_name: str,
+        args: Dict[str, Any],
+        xlabel: str = '',
+        ylabel: str = '',
+        x_log: bool = False,
+        y_log: bool = False,
+        styles=None,
+    ):
+        """
+        Constructor.
+        x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list
+        of scalars and there are multiple x_names, all arguments will have the same value.
+        If x_vals is a list of tuples/lists, each element should have the same length as
+        x_names.
+
+        :param x_names: Name of the arguments that should appear on the x axis of the plot.
+        :type x_names: List[str]
+        :param x_vals: List of values to use for the arguments in :code:`x_names`.
+        :type x_vals: List[Any]
+        :param line_arg: Argument name for which different values correspond to different lines in the plot.
+        :type line_arg: str
+        :param line_vals: List of values to use for the arguments in :code:`line_arg`.
+        :type line_vals: List[Any]
+        :param line_names: Label names for the different lines.
+        :type line_names: List[str]
+        :param plot_name: Name of the plot.
+        :type plot_name: str
+        :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark.
+        :type args: Dict[str, Any]
+        :param xlabel: Label for the x axis of the plot.
+        :type xlabel: str, optional
+        :param ylabel: Label for the y axis of the plot.
+        :type ylabel: str, optional
+        :param x_log: Whether the x axis should be log scale.
+        :type x_log: bool, optional
+        :param y_log: Whether the y axis should be log scale.
+        :type y_log: bool, optional
+        :param styles: A list of tuples, where each tuple contains two elements: a color and a linestyle.
+        :type styles: list[tuple[str, str]]
+        """
+        self.x_names = x_names
+        self.x_vals = x_vals
+        self.x_log = x_log
+        self.line_arg = line_arg
+        self.line_vals = line_vals
+        self.line_names = line_names
+        self.y_log = y_log
+        self.styles = styles
+        # plot info
+        self.xlabel = xlabel
+        self.ylabel = ylabel
+        self.plot_name = plot_name
+        self.args = args
+
+
+class Mark:
+
+    def __init__(self, fn, benchmarks):
+        self.fn = fn
+        self.benchmarks = benchmarks
+
+    def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False,
+             save_precision=6, **kwrags):
+        import os
+
+        import matplotlib.pyplot as plt
+        import pandas as pd
+        y_mean = bench.line_names
+        y_min = [f'{x}-min' for x in bench.line_names]
+        y_max = [f'{x}-max' for x in bench.line_names]
+        x_names = list(bench.x_names)
+        df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max)
+        for x in bench.x_vals:
+            # x can be a single value or a sequence of values.
+            if not isinstance(x, (list, tuple)):
+                x = [x for _ in x_names]
+
+            if len(x) != len(x_names):
+                raise ValueError(f"Expected {len(x_names)} values, got {x}")
+            x_args = dict(zip(x_names, x))
+
+            row_mean, row_min, row_max = [], [], []
+            for y in bench.line_vals:
+                ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags)
+                try:
+                    y_mean, y_min, y_max = ret
+                except TypeError:
+                    y_mean, y_min, y_max = ret, None, None
+                row_mean += [y_mean]
+                row_min += [y_min]
+                row_max += [y_max]
+            df.loc[len(df)] = list(x) + row_mean + row_min + row_max
+
+        if bench.plot_name:
+            plt.figure()
+            ax = plt.subplot()
+            # Plot first x value on x axis if there are multiple.
+            first_x = x_names[0]
+            for i, y in enumerate(bench.line_names):
+                y_min, y_max = df[y + '-min'], df[y + '-max']
+                col = bench.styles[i][0] if bench.styles else None
+                sty = bench.styles[i][1] if bench.styles else None
+                ax.plot(df[first_x], df[y], label=y, color=col, ls=sty)
+                if not y_min.isnull().all() and not y_max.isnull().all():
+                    y_min = y_min.astype(float)
+                    y_max = y_max.astype(float)
+                    ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col)
+            ax.legend()
+            ax.set_xlabel(bench.xlabel or first_x)
+            ax.set_ylabel(bench.ylabel)
+            # ax.set_title(bench.plot_name)
+            ax.set_xscale("log" if bench.x_log else "linear")
+            ax.set_yscale("log" if bench.y_log else "linear")
+            if show_plots:
+                plt.show()
+            if save_path:
+                plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
+        df = df[x_names + bench.line_names]
+        if diff_col and df.shape[1] == 2:
+            col0, col1 = df.columns.tolist()
+            df['Diff'] = df[col1] - df[col0]
+
+        if print_data:
+            print(bench.plot_name + ':')
+            print(df.to_string())
+        if save_path:
+            df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f",
+                      index=False)
+        return df
+
+    def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs):
+        has_single_bench = isinstance(self.benchmarks, Benchmark)
+        benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
+        result_dfs = []
+        try:
+            for bench in benchmarks:
+                result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
+        finally:
+            if save_path:
+                # Create directory if it doesn't exist
+                os.makedirs(save_path, exist_ok=True)
+                with open(os.path.join(save_path, "results.html"), "w") as html:
+                    html.write("<html><body>\n")
+                    for bench in benchmarks[:len(result_dfs)]:
+                        html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
+                    html.write("</body></html>\n")
+        if return_df:
+            if has_single_bench:
+                return result_dfs[0]
+            else:
+                return result_dfs
+        return None
+
+
+def perf_report(benchmarks):
+    """
+    Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.
+
+    :param benchmarks: Benchmarking configurations.
+    :type benchmarks: List of :class:`Benchmark`
+    """
+    wrapper = lambda fn: Mark(fn, benchmarks)
+    return wrapper
+
+
+def get_dram_gbps(device=None):
+    ''' return DRAM bandwidth in GB/s '''
+    import torch
+
+    from .runtime import driver
+    if not device:
+        device = torch.cuda.current_device()
+    mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"]  # in kHz
+    bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"]
+    bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8  # In GB/s
+    return bw_gbps
+
+
+def get_max_tensorcore_tflops(dtype, clock_rate, device=None):
+    import torch
+
+    from .runtime import driver
+    if not device:
+        device = torch.cuda.current_device()
+
+    num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4
+    capability = torch.cuda.get_device_capability(device)
+    if capability[0] < 8:
+        assert dtype == torch.float16
+        ops_per_sub_core = 256  # 2 4x4x4 Tensor Cores
+    else:
+        if dtype in [torch.float32, torch.int32]:
+            ops_per_sub_core = 256
+        elif dtype in [torch.float16, torch.bfloat16, torch.int16]:
+            ops_per_sub_core = 512
+        elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]:
+            ops_per_sub_core = 1024
+        else:
+            raise RuntimeError("dtype not supported")
+    tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
+    return tflops
+
+
+# create decorator that wraps test function into
+# a cuda-memcheck system call
+
+
+def cuda_memcheck(**target_kwargs):
+
+    def decorator(test_fn):
+
+        @functools.wraps(test_fn)
+        def wrapper(*args, **kwargs):
+            import psutil
+            ppid_name = psutil.Process(os.getppid()).name()
+            run_cuda_memcheck = target_kwargs.items() <= kwargs.items()
+            if run_cuda_memcheck and ppid_name != "cuda-memcheck":
+                path = os.path.realpath(test_fn.__globals__["__file__"])
+                # get path of current file
+                env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"}
+                assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture"
+                test_id = kwargs['request'].node.callspec.id
+                cmd = f"{path}::{test_fn.__name__}[{test_id}]"
+                out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env)
+                assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed"
+                assert "ERROR SUMMARY: 0 errors" in str(out.stdout)
+            else:
+                test_fn(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+@contextmanager
+def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
+    try:
+        subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
+        subprocess.check_output([
+            "nvidia-smi",
+            "-i",
+            "0",
+            f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
+        ])
+        subprocess.check_output([
+            "nvidia-smi",
+            "-i",
+            "0",
+            f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
+        ])
+        cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
+        cur_mem_clock = nvsmi(["clocks.current.memory"])[0]
+        assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"
+        assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz"
+        tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock
+        gbps = 640 * 2 * ref_mem_clock * 1e-3
+        yield tflops, gbps
+    finally:
+        subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"])
+        subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"])
+        subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"])
+
+
+def get_max_simd_tflops(dtype, clock_rate, device=None):
+    import torch
+
+    from .runtime import driver
+    if not device:
+        device = torch.cuda.current_device()
+
+    num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4
+    capability = torch.cuda.get_device_capability()
+    if capability[0] < 8:
+        if dtype == torch.float32:
+            ops_per_sub_core = 32  # 2*16
+        elif dtype == torch.float16:
+            ops_per_sub_core = 64
+        else:
+            raise RuntimeError("dtype not supported")
+    else:
+        if dtype == torch.float32:
+            ops_per_sub_core = 32
+        elif dtype in [torch.float16, torch.bfloat16]:
+            ops_per_sub_core = 64
+        else:
+            raise RuntimeError("dtype not supported")
+    tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
+    return tflops
triton/tools/__init__.py
ADDED
File without changes
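For context, here is a minimal usage sketch of the benchmarking helpers shipped in triton/testing.py above (do_bench, Benchmark, perf_report). It is not part of the package diff: the vector-add workload, the problem sizes, and the "provider" argument name are illustrative assumptions, and running it requires PyTorch, pandas, and matplotlib plus a CUDA-capable device.

# Usage sketch (not part of the package): drive the helpers defined in triton/testing.py.
# Assumes PyTorch, pandas, and matplotlib are installed and a CUDA device is available.
import torch

import triton.testing


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["n"],                          # argument plotted on the x axis
        x_vals=[2**i for i in range(12, 20)],   # illustrative problem sizes
        line_arg="provider",                    # one plotted line per value of this argument
        line_vals=["torch"],
        line_names=["torch.add"],
        plot_name="vector-add",
        args={},                                # keyword arguments held fixed across runs
    ))
def bench_add(n, provider):
    x = torch.rand(n, device="cuda")
    y = torch.rand(n, device="cuda")
    # do_bench returns the mean runtime in ms by default; with quantiles it returns the
    # requested percentiles, which Mark._run treats as (mean, min, max) for the shaded band.
    return triton.testing.do_bench(lambda: x + y, quantiles=[0.5, 0.2, 0.8])


if __name__ == "__main__":
    bench_add.run(print_data=True)

perf_report wraps the decorated function in a Mark, so calling .run() sweeps x_vals and line_vals, builds a pandas DataFrame, and optionally prints, plots, or saves the results as shown in Mark._run above.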