triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic.

Files changed (217)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
triton/testing.py ADDED
@@ -0,0 +1,543 @@
+import functools
+import math
+import os
+import statistics
+import subprocess
+import sys
+from contextlib import contextmanager
+from typing import Any, Dict, List
+from . import language as tl
+from . import runtime
+
+
+def nvsmi(attrs):
+    attrs = ','.join(attrs)
+    cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
+    out = subprocess.check_output(cmd)
+    ret = out.decode(sys.stdout.encoding).split(',')
+    ret = [int(x) for x in ret]
+    return ret
+
+
+# pure Python implementation of np.quantile/torch.quantile
+# to avoid unnecessary runtime dependency on numpy/torch
+
+
+def _quantile(a, q):
+    n = len(a)
+    a = sorted(a)
+
+    def get_quantile(q):
+        if not (0 <= q <= 1):
+            raise ValueError("Quantiles must be in the range [0, 1]")
+        point = q * (n - 1)
+        lower = math.floor(point)
+        upper = math.ceil(point)
+        t = point - lower
+        return (1 - t) * a[lower] + t * a[upper]
+
+    return [get_quantile(q) for q in q]
+
+
+def _summarize_statistics(times, quantiles, return_mode):
+    if quantiles is not None:
+        ret = _quantile(times, quantiles)
+        if len(ret) == 1:
+            ret = ret[0]
+        return ret
+    if return_mode == "all":
+        return times
+    elif return_mode == "min":
+        return min(times)
+    elif return_mode == "max":
+        return max(times)
+    elif return_mode == "mean":
+        return statistics.mean(times)
+    elif return_mode == "median":
+        return statistics.median(times)
+
+
+def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
+    """
+    Benchmark the runtime of the provided function.
+
+    :param fn: Function to benchmark
+    :type fn: Callable
+    :param rep: Repetition time (in ms)
+    :type rep: int
+    :param grad_to_none: Reset the gradient of the provided tensor to None
+    :type grad_to_none: torch.tensor, optional
+    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
+    :type return_mode: str
+    """
+    import torch
+    assert return_mode in ["min", "max", "mean", "median", "all"]
+
+    with torch.cuda.stream(torch.cuda.Stream()):
+        # warmup
+        fn()
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.detach_()
+                x.requires_grad_(True)
+                x.grad = None
+        # step 1 - we estimate the amount of time the kernel call takes
+        # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point
+        # but it is probably good enough
+        # NOTE: we don't use a graph to estimate the runtime because creating a graph is expensive,
+        # ~300ms on A100, so we default to the same method used in `do_bench` (minus the L2
+        # cache flush).
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        for _ in range(5):
+            fn()
+        end_event.record()
+        torch.cuda.synchronize()
+        estimate_ms = start_event.elapsed_time(end_event) / 5
+        # Avoid possible division-by-zero issues with very fast benchmarks
+        if estimate_ms == 0:
+            n_repeat = 1000
+        else:
+            n_repeat = max(1, int(rep / estimate_ms))
+        # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize
+        # host overhead
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            for _ in range(n_repeat):
+                if grad_to_none is not None:
+                    for x in grad_to_none:
+                        x.grad = None
+                fn()
+        torch.cuda.synchronize()
+        # measure time and return
+        ret = []
+        n_retries = 10
+        for _ in range(n_retries):
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            start_event.record()
+            g.replay()
+            end_event.record()
+            torch.cuda.synchronize()
+            ret += [start_event.elapsed_time(end_event) / n_repeat]
+        return _summarize_statistics(ret, quantiles, return_mode)
+
+
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
+    """
+    Benchmark the runtime of the provided function. By default, return the mean runtime of :code:`fn`;
+    other statistics can be requested via :code:`quantiles` or :code:`return_mode`.
+
+    :param fn: Function to benchmark
+    :type fn: Callable
+    :param warmup: Warmup time (in ms)
+    :type warmup: int
+    :param rep: Repetition time (in ms)
+    :type rep: int
+    :param grad_to_none: Reset the gradient of the provided tensor to None
+    :type grad_to_none: torch.tensor, optional
+    :param quantiles: Performance percentiles to return; if provided, they are returned instead of the :code:`return_mode` statistic.
+    :type quantiles: list[float], optional
+    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
+    :type return_mode: str
+    """
+    assert return_mode in ["min", "max", "mean", "median", "all"]
+
+    di = runtime.driver.active.get_device_interface()
+
+    fn()
+    di.synchronize()
+
+    cache = runtime.driver.active.get_empty_cache_for_benchmark()
+
+    # Estimate the runtime of the function
+    start_event = di.Event(enable_timing=True)
+    end_event = di.Event(enable_timing=True)
+    start_event.record()
+    for _ in range(5):
+        runtime.driver.active.clear_cache(cache)
+        fn()
+    end_event.record()
+    di.synchronize()
+    estimate_ms = start_event.elapsed_time(end_event) / 5
+
+    # compute number of warmup and repeat
+    n_warmup = max(1, int(warmup / estimate_ms))
+    n_repeat = max(1, int(rep / estimate_ms))
+    start_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
+    end_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
+    # Warm-up
+    for _ in range(n_warmup):
+        fn()
+    # Benchmark
+    for i in range(n_repeat):
+        # we don't want `fn` to accumulate gradient values
+        # if it contains a backward pass. So we clear the
+        # provided gradients
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.grad = None
+        # we clear the L2 cache before each run
+        runtime.driver.active.clear_cache(cache)
+        # record time of `fn`
+        start_event[i].record()
+        fn()
+        end_event[i].record()
+    # Record clocks
+    di.synchronize()
+    times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
+    return _summarize_statistics(times, quantiles, return_mode)
+
+
+def assert_close(x, y, atol=None, rtol=None, err_msg=''):
+    """
+    Asserts that two inputs are close within a certain tolerance.
+
+    :param x: The first input.
+    :type x: scalar, list, numpy.ndarray, or torch.Tensor
+    :param y: The second input.
+    :type y: scalar, list, numpy.ndarray, or torch.Tensor
+    :param atol: The absolute tolerance. Default value is 1e-2.
+    :type atol: float, optional
+    :param rtol: The relative tolerance. Default value is 0.
+    :type rtol: float, optional
+    :param err_msg: The error message to use if the assertion fails.
+    :type err_msg: str
+    """
+    import numpy as np
+    import torch
+
+    # canonicalize arguments to be tensors
+    if not isinstance(x, torch.Tensor):
+        x = torch.tensor(x)
+    if not isinstance(y, torch.Tensor):
+        y = torch.tensor(y)
+    # absolute tolerance
+    if atol is None:
+        atol = 1e-2
+    atol = atol(x.dtype) if callable(atol) else atol
+    # relative tolerance hook
+    if rtol is None:
+        rtol = 0.
+    rtol = rtol(x.dtype) if callable(rtol) else rtol
+    # we use numpy instead of pytorch
+    # as it seems more memory efficient
+    # pytorch tends to oom on large tensors
+    if isinstance(x, torch.Tensor):
+        if x.dtype == torch.bfloat16:
+            x = x.float()
+        x = x.cpu().detach().numpy()
+    if isinstance(y, torch.Tensor):
+        if y.dtype == torch.bfloat16:
+            y = y.float()
+        y = y.cpu().detach().numpy()
+    # we handle size==1 case separately as we can
+    # provide better error message there
+    if x.size > 1 or y.size > 1:
+        np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True)
+        return
+    if not np.allclose(x, y, atol=atol, rtol=rtol):
+        raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})')
+
+
+class Benchmark:
+    """
+    This class is used by the :code:`perf_report` function to generate line plots with a concise API.
+    """
+
+    def __init__(
+        self,
+        x_names: List[str],
+        x_vals: List[Any],
+        line_arg: str,
+        line_vals: List[Any],
+        line_names: List[str],
+        plot_name: str,
+        args: Dict[str, Any],
+        xlabel: str = '',
+        ylabel: str = '',
+        x_log: bool = False,
+        y_log: bool = False,
+        styles=None,
+    ):
+        """
+        Constructor.
+        x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list
+        of scalars and there are multiple x_names, all arguments will have the same value.
+        If x_vals is a list of tuples/lists, each element should have the same length as
+        x_names.
+
+        :param x_names: Name of the arguments that should appear on the x axis of the plot.
+        :type x_names: List[str]
+        :param x_vals: List of values to use for the arguments in :code:`x_names`.
+        :type x_vals: List[Any]
+        :param line_arg: Argument name for which different values correspond to different lines in the plot.
+        :type line_arg: str
+        :param line_vals: List of values to use for the arguments in :code:`line_arg`.
+        :type line_vals: List[Any]
+        :param line_names: Label names for the different lines.
+        :type line_names: List[str]
+        :param plot_name: Name of the plot.
+        :type plot_name: str
+        :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark.
+        :type args: Dict[str, Any]
+        :param xlabel: Label for the x axis of the plot.
+        :type xlabel: str, optional
+        :param ylabel: Label for the y axis of the plot.
+        :type ylabel: str, optional
+        :param x_log: Whether the x axis should be log scale.
+        :type x_log: bool, optional
+        :param y_log: Whether the y axis should be log scale.
+        :type y_log: bool, optional
+        :param styles: A list of tuples, where each tuple contains two elements: a color and a linestyle.
+        :type styles: list[tuple[str, str]]
+        """
+        self.x_names = x_names
+        self.x_vals = x_vals
+        self.x_log = x_log
+        self.line_arg = line_arg
+        self.line_vals = line_vals
+        self.line_names = line_names
+        self.y_log = y_log
+        self.styles = styles
+        # plot info
+        self.xlabel = xlabel
+        self.ylabel = ylabel
+        self.plot_name = plot_name
+        self.args = args
+
+
+class Mark:
+
+    def __init__(self, fn, benchmarks):
+        self.fn = fn
+        self.benchmarks = benchmarks
+
+    def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False,
+             save_precision=6, **kwargs):
+        import os
+
+        import matplotlib.pyplot as plt
+        import pandas as pd
+        y_mean = bench.line_names
+        y_min = [f'{x}-min' for x in bench.line_names]
+        y_max = [f'{x}-max' for x in bench.line_names]
+        x_names = list(bench.x_names)
+        df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max)
+        for x in bench.x_vals:
+            # x can be a single value or a sequence of values.
+            if not isinstance(x, (list, tuple)):
+                x = [x for _ in x_names]
+
+            if len(x) != len(x_names):
+                raise ValueError(f"Expected {len(x_names)} values, got {x}")
+            x_args = dict(zip(x_names, x))
+
+            row_mean, row_min, row_max = [], [], []
+            for y in bench.line_vals:
+                ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwargs)
+                try:
+                    y_mean, y_min, y_max = ret
+                except TypeError:
+                    y_mean, y_min, y_max = ret, None, None
+                row_mean += [y_mean]
+                row_min += [y_min]
+                row_max += [y_max]
+            df.loc[len(df)] = list(x) + row_mean + row_min + row_max
+
+        if bench.plot_name:
+            plt.figure()
+            ax = plt.subplot()
+            # Plot first x value on x axis if there are multiple.
+            first_x = x_names[0]
+            for i, y in enumerate(bench.line_names):
+                y_min, y_max = df[y + '-min'], df[y + '-max']
+                col = bench.styles[i][0] if bench.styles else None
+                sty = bench.styles[i][1] if bench.styles else None
+                ax.plot(df[first_x], df[y], label=y, color=col, ls=sty)
+                if not y_min.isnull().all() and not y_max.isnull().all():
+                    y_min = y_min.astype(float)
+                    y_max = y_max.astype(float)
+                    ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col)
+            ax.legend()
+            ax.set_xlabel(bench.xlabel or first_x)
+            ax.set_ylabel(bench.ylabel)
+            # ax.set_title(bench.plot_name)
+            ax.set_xscale("log" if bench.x_log else "linear")
+            ax.set_yscale("log" if bench.y_log else "linear")
+            if show_plots:
+                plt.show()
+            if save_path:
+                plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
+        df = df[x_names + bench.line_names]
+        if diff_col and df.shape[1] == 2:
+            col0, col1 = df.columns.tolist()
+            df['Diff'] = df[col1] - df[col0]
+
+        if print_data:
+            print(bench.plot_name + ':')
+            print(df.to_string())
+        if save_path:
+            df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f",
+                      index=False)
+        return df
+
+    def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs):
+        has_single_bench = isinstance(self.benchmarks, Benchmark)
+        benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
+        result_dfs = []
+        try:
+            for bench in benchmarks:
+                result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
+        finally:
+            if save_path:
+                # Create directory if it doesn't exist
+                os.makedirs(save_path, exist_ok=True)
+                with open(os.path.join(save_path, "results.html"), "w") as html:
+                    html.write("<html><body>\n")
+                    for bench in benchmarks[:len(result_dfs)]:
+                        html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
+                    html.write("</body></html>\n")
+        if return_df:
+            if has_single_bench:
+                return result_dfs[0]
+            else:
+                return result_dfs
+        return None
+
+
+def perf_report(benchmarks):
+    """
+    Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.
+
+    :param benchmarks: Benchmarking configurations.
+    :type benchmarks: List of :class:`Benchmark`
+    """
+    wrapper = lambda fn: Mark(fn, benchmarks)
+    return wrapper
+
+
+def get_dram_gbps(device=None):
+    ''' return DRAM bandwidth in GB/s '''
+    import torch
+
+    from .runtime import driver
+    if not device:
+        device = torch.cuda.current_device()
+    mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"]  # in kHz
+    bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"]
+    bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8  # In GB/s
+    return bw_gbps
+
+
+def get_max_tensorcore_tflops(dtype, clock_rate, device=None):
+    import torch
+
+    from .runtime import driver
+    if not device:
+        device = torch.cuda.current_device()
+
+    num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4
+    capability = torch.cuda.get_device_capability(device)
+    if capability[0] < 8:
+        assert dtype == torch.float16
+        ops_per_sub_core = 256  # 2 4x4x4 Tensor Cores
+    else:
+        if dtype in [torch.float32, torch.int32]:
+            ops_per_sub_core = 256
+        elif dtype in [torch.float16, torch.bfloat16, torch.int16]:
+            ops_per_sub_core = 512
+        elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]:
+            ops_per_sub_core = 1024
+        else:
+            raise RuntimeError("dtype not supported")
+    tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
+    return tflops
+
+
+# create decorator that wraps test function into
+# a cuda-memcheck system call
+
+
+def cuda_memcheck(**target_kwargs):
+
+    def decorator(test_fn):
+
+        @functools.wraps(test_fn)
+        def wrapper(*args, **kwargs):
+            import psutil
+            ppid_name = psutil.Process(os.getppid()).name()
+            run_cuda_memcheck = target_kwargs.items() <= kwargs.items()
+            if run_cuda_memcheck and ppid_name != "cuda-memcheck":
+                path = os.path.realpath(test_fn.__globals__["__file__"])
+                # get path of current file
+                env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"}
+                assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture"
+                test_id = kwargs['request'].node.callspec.id
+                cmd = f"{path}::{test_fn.__name__}[{test_id}]"
+                out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env)
+                assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed"
+                assert "ERROR SUMMARY: 0 errors" in str(out.stdout)
+            else:
+                test_fn(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+@contextmanager
+def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
+    try:
+        subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
+        subprocess.check_output([
+            "nvidia-smi",
+            "-i",
+            "0",
+            f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
+        ])
+        subprocess.check_output([
+            "nvidia-smi",
+            "-i",
+            "0",
+            f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
+        ])
+        cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
+        cur_mem_clock = nvsmi(["clocks.current.memory"])[0]
+        assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"
+        assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU memory must run at {ref_mem_clock} MHz"
+        tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock
+        gbps = 640 * 2 * ref_mem_clock * 1e-3
+        yield tflops, gbps
+    finally:
+        subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"])
+        subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"])
+        subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"])
+
+
+def get_max_simd_tflops(dtype, clock_rate, device=None):
+    import torch
+
+    from .runtime import driver
+    if not device:
+        device = torch.cuda.current_device()
+
+    num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4
+    capability = torch.cuda.get_device_capability()
+    if capability[0] < 8:
+        if dtype == torch.float32:
+            ops_per_sub_core = 32  # 2*16
+        elif dtype == torch.float16:
+            ops_per_sub_core = 64
+        else:
+            raise RuntimeError("dtype not supported")
+    else:
+        if dtype == torch.float32:
+            ops_per_sub_core = 32
+        elif dtype in [torch.float16, torch.bfloat16]:
+            ops_per_sub_core = 64
+        else:
+            raise RuntimeError("dtype not supported")
+    tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
+    return tflops
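For orientation, here is a minimal usage sketch of the benchmarking helpers above. It is illustrative only and not part of the wheel's contents; the vector-add kernel, tensor sizes, and the CUDA device are assumptions.

# Sketch: verifying and timing a simple Triton kernel with assert_close / do_bench.
# The kernel, sizes, and device below are assumptions, not package contents.
import torch
import triton
import triton.language as tl
import triton.testing


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


x = torch.rand(1 << 20, device="cuda")
y = torch.rand(1 << 20, device="cuda")
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)

# Correctness check first (defaults: atol=1e-2, rtol=0).
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)
triton.testing.assert_close(out, x + y)

# Mean runtime in ms (default return_mode="mean").
ms = triton.testing.do_bench(lambda: add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024))

# Median plus 20th/80th percentiles: passing quantiles overrides return_mode.
med, p20, p80 = triton.testing.do_bench(
    lambda: add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024),
    quantiles=[0.5, 0.2, 0.8],
)

# Lower-overhead variant: the callable must be capturable in a CUDA graph.
ms_graph = triton.testing.do_bench_cudagraph(
    lambda: add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024))
print(f"mean {ms:.3f} ms, median {med:.3f} ms, graph {ms_graph:.3f} ms")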
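The Benchmark/Mark/perf_report trio above turns a benchmarking function into tables and line plots. A minimal sketch follows; the matmul workload, dtypes, and sizes are assumptions chosen only to exercise the API, and matplotlib plus pandas are needed for the plotting path.

# Sketch of perf_report/Benchmark usage; the workload below is an assumption.
import torch
import triton.testing


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["M", "N", "K"],
        x_vals=[128 * i for i in range(2, 17)],  # scalars are broadcast to all x_names
        line_arg="dtype",
        line_vals=[torch.float16, torch.float32],
        line_names=["fp16", "fp32"],
        styles=[("green", "-"), ("blue", "-")],
        xlabel="matrix size",
        ylabel="TFLOPS",
        plot_name="matmul-tflops",
        args={},
    ))
def bench_matmul(M, N, K, dtype):
    a = torch.randn((M, K), device="cuda", dtype=dtype)
    b = torch.randn((K, N), device="cuda", dtype=dtype)
    ms, p20, p80 = triton.testing.do_bench(lambda: a @ b, quantiles=[0.5, 0.2, 0.8])
    tflops = lambda t: 2 * M * N * K / t * 1e-9
    # The returned tuple is interpreted as (value, min, max) for the plotted metric.
    return tflops(ms), tflops(p80), tflops(p20)


# Prints a DataFrame per configuration; pass save_path to also write PNG/CSV/HTML.
bench_matmul.run(print_data=True)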
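Finally, the hardware helpers (get_dram_gbps, set_gpu_clock) support roofline-style comparisons. A sketch, assuming an NVIDIA GPU with nvidia-smi on PATH and permission to lock clocks; the clock values are simply the function defaults:

# Sketch of the hardware/clock helpers; requires nvidia-smi and admin rights.
import triton.testing as tt

print(f"DRAM bandwidth: {tt.get_dram_gbps():.0f} GB/s")

# Lock SM/memory clocks for stable measurements; clocks are reset on exit.
# The yielded (tflops, gbps) pair is hard-coded for an A100-class device.
with tt.set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215) as (tflops, gbps):
    print(f"reference roofline at locked clocks: {tflops:.0f} TFLOPS, {gbps:.0f} GB/s")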