triton-windows 3.1.0.post17__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.

Files changed (248)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +73 -0
  3. triton/backends/__init__.py +50 -0
  4. triton/backends/amd/compiler.py +262 -0
  5. triton/backends/amd/driver.c +211 -0
  6. triton/backends/amd/driver.py +497 -0
  7. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  25. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  26. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  27. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  28. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  31. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  32. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  40. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  41. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  42. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  43. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  44. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  45. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  46. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  48. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  49. triton/backends/amd/include/hip/device_functions.h +38 -0
  50. triton/backends/amd/include/hip/driver_types.h +468 -0
  51. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  52. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  53. triton/backends/amd/include/hip/hip_common.h +100 -0
  54. triton/backends/amd/include/hip/hip_complex.h +38 -0
  55. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  56. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  57. triton/backends/amd/include/hip/hip_ext.h +159 -0
  58. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  59. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  60. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  61. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  62. triton/backends/amd/include/hip/hip_profile.h +27 -0
  63. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  64. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  65. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  66. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  67. triton/backends/amd/include/hip/hip_version.h +17 -0
  68. triton/backends/amd/include/hip/hiprtc.h +421 -0
  69. triton/backends/amd/include/hip/library_types.h +78 -0
  70. triton/backends/amd/include/hip/math_functions.h +42 -0
  71. triton/backends/amd/include/hip/surface_types.h +63 -0
  72. triton/backends/amd/include/hip/texture_types.h +194 -0
  73. triton/backends/amd/include/hsa/Brig.h +1131 -0
  74. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  75. triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
  76. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  77. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  78. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  79. triton/backends/amd/include/hsa/hsa.h +5729 -0
  80. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  81. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  82. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  83. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  84. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  85. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  87. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  88. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  89. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  90. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  91. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  92. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  93. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  94. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  95. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  96. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  97. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  98. triton/backends/amd/include/roctracer/roctx.h +229 -0
  99. triton/backends/amd/lib/ockl.bc +0 -0
  100. triton/backends/amd/lib/ocml.bc +0 -0
  101. triton/backends/compiler.py +76 -0
  102. triton/backends/driver.py +34 -0
  103. triton/backends/nvidia/__init__.py +0 -0
  104. triton/backends/nvidia/bin/ptxas.exe +0 -0
  105. triton/backends/nvidia/compiler.py +347 -0
  106. triton/backends/nvidia/driver.c +451 -0
  107. triton/backends/nvidia/driver.py +430 -0
  108. triton/backends/nvidia/include/cuda.h +24359 -0
  109. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  110. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  111. triton/compiler/__init__.py +4 -0
  112. triton/compiler/code_generator.py +1302 -0
  113. triton/compiler/compiler.py +416 -0
  114. triton/compiler/errors.py +51 -0
  115. triton/compiler/make_launcher.py +0 -0
  116. triton/errors.py +5 -0
  117. triton/language/__init__.py +284 -0
  118. triton/language/core.py +2621 -0
  119. triton/language/extra/__init__.py +4 -0
  120. triton/language/extra/cuda/__init__.py +8 -0
  121. triton/language/extra/cuda/libdevice.py +1629 -0
  122. triton/language/extra/cuda/utils.py +109 -0
  123. triton/language/extra/hip/__init__.py +3 -0
  124. triton/language/extra/hip/libdevice.py +468 -0
  125. triton/language/extra/libdevice.py +1213 -0
  126. triton/language/math.py +250 -0
  127. triton/language/random.py +207 -0
  128. triton/language/semantic.py +1621 -0
  129. triton/language/standard.py +441 -0
  130. triton/ops/__init__.py +7 -0
  131. triton/ops/blocksparse/__init__.py +7 -0
  132. triton/ops/blocksparse/matmul.py +432 -0
  133. triton/ops/blocksparse/softmax.py +228 -0
  134. triton/ops/cross_entropy.py +96 -0
  135. triton/ops/flash_attention.py +466 -0
  136. triton/ops/matmul.py +219 -0
  137. triton/ops/matmul_perf_model.py +171 -0
  138. triton/runtime/__init__.py +23 -0
  139. triton/runtime/autotuner.py +361 -0
  140. triton/runtime/build.py +129 -0
  141. triton/runtime/cache.py +289 -0
  142. triton/runtime/driver.py +60 -0
  143. triton/runtime/errors.py +26 -0
  144. triton/runtime/interpreter.py +1127 -0
  145. triton/runtime/jit.py +956 -0
  146. triton/runtime/tcc/include/_mingw.h +170 -0
  147. triton/runtime/tcc/include/assert.h +57 -0
  148. triton/runtime/tcc/include/conio.h +409 -0
  149. triton/runtime/tcc/include/ctype.h +281 -0
  150. triton/runtime/tcc/include/dir.h +31 -0
  151. triton/runtime/tcc/include/direct.h +68 -0
  152. triton/runtime/tcc/include/dirent.h +135 -0
  153. triton/runtime/tcc/include/dos.h +55 -0
  154. triton/runtime/tcc/include/errno.h +75 -0
  155. triton/runtime/tcc/include/excpt.h +123 -0
  156. triton/runtime/tcc/include/fcntl.h +52 -0
  157. triton/runtime/tcc/include/fenv.h +108 -0
  158. triton/runtime/tcc/include/float.h +57 -0
  159. triton/runtime/tcc/include/inttypes.h +297 -0
  160. triton/runtime/tcc/include/io.h +418 -0
  161. triton/runtime/tcc/include/limits.h +111 -0
  162. triton/runtime/tcc/include/locale.h +91 -0
  163. triton/runtime/tcc/include/malloc.h +181 -0
  164. triton/runtime/tcc/include/math.h +737 -0
  165. triton/runtime/tcc/include/mem.h +13 -0
  166. triton/runtime/tcc/include/memory.h +40 -0
  167. triton/runtime/tcc/include/process.h +176 -0
  168. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  169. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  170. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  171. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  172. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  173. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  174. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  175. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  176. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  177. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  178. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  179. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  180. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  181. triton/runtime/tcc/include/setjmp.h +160 -0
  182. triton/runtime/tcc/include/share.h +28 -0
  183. triton/runtime/tcc/include/signal.h +63 -0
  184. triton/runtime/tcc/include/stdarg.h +79 -0
  185. triton/runtime/tcc/include/stdbool.h +11 -0
  186. triton/runtime/tcc/include/stddef.h +54 -0
  187. triton/runtime/tcc/include/stdint.h +212 -0
  188. triton/runtime/tcc/include/stdio.h +429 -0
  189. triton/runtime/tcc/include/stdlib.h +580 -0
  190. triton/runtime/tcc/include/string.h +164 -0
  191. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  192. triton/runtime/tcc/include/sys/file.h +14 -0
  193. triton/runtime/tcc/include/sys/locking.h +30 -0
  194. triton/runtime/tcc/include/sys/stat.h +290 -0
  195. triton/runtime/tcc/include/sys/time.h +69 -0
  196. triton/runtime/tcc/include/sys/timeb.h +133 -0
  197. triton/runtime/tcc/include/sys/types.h +118 -0
  198. triton/runtime/tcc/include/sys/unistd.h +14 -0
  199. triton/runtime/tcc/include/sys/utime.h +146 -0
  200. triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
  201. triton/runtime/tcc/include/tcclib.h +80 -0
  202. triton/runtime/tcc/include/tchar.h +1102 -0
  203. triton/runtime/tcc/include/time.h +287 -0
  204. triton/runtime/tcc/include/vadefs.h +11 -0
  205. triton/runtime/tcc/include/values.h +4 -0
  206. triton/runtime/tcc/include/varargs.h +12 -0
  207. triton/runtime/tcc/include/wchar.h +873 -0
  208. triton/runtime/tcc/include/wctype.h +172 -0
  209. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  210. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  211. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  212. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  213. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  214. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  215. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  216. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  217. triton/runtime/tcc/include/winapi/winbase.h +2951 -0
  218. triton/runtime/tcc/include/winapi/wincon.h +301 -0
  219. triton/runtime/tcc/include/winapi/windef.h +293 -0
  220. triton/runtime/tcc/include/winapi/windows.h +127 -0
  221. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  222. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  223. triton/runtime/tcc/include/winapi/winnt.h +5835 -0
  224. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  225. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  226. triton/runtime/tcc/include/winapi/winver.h +160 -0
  227. triton/runtime/tcc/lib/cuda.def +697 -0
  228. triton/runtime/tcc/lib/gdi32.def +337 -0
  229. triton/runtime/tcc/lib/kernel32.def +770 -0
  230. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  231. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  232. triton/runtime/tcc/lib/python3.def +810 -0
  233. triton/runtime/tcc/lib/user32.def +658 -0
  234. triton/runtime/tcc/libtcc.dll +0 -0
  235. triton/runtime/tcc/tcc.exe +0 -0
  236. triton/testing.py +496 -0
  237. triton/tools/__init__.py +0 -0
  238. triton/tools/build_extern.py +365 -0
  239. triton/tools/compile.c +67 -0
  240. triton/tools/compile.h +14 -0
  241. triton/tools/compile.py +145 -0
  242. triton/tools/disasm.py +142 -0
  243. triton/tools/link.py +322 -0
  244. triton/windows_utils.py +373 -0
  245. triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
  246. triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
  247. triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
  248. triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
triton/testing.py ADDED
@@ -0,0 +1,496 @@
+ import functools
+ import os
+ import subprocess
+ import sys
+ from contextlib import contextmanager
+ from typing import Any, Dict, List
+ from . import language as tl
+
+
+ def nvsmi(attrs):
+     attrs = ','.join(attrs)
+     cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
+     out = subprocess.check_output(cmd)
+     ret = out.decode(sys.stdout.encoding).split(',')
+     ret = [int(x) for x in ret]
+     return ret
+
+
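For reference, `nvsmi` simply shells out to `nvidia-smi` and parses its CSV output into integers; a minimal usage sketch (assuming an NVIDIA driver with `nvidia-smi` on the PATH, using the same query attributes that `set_gpu_clock` relies on below):

    from triton.testing import nvsmi

    # current SM and memory clocks of GPU 0, in MHz
    sm_clock, mem_clock = nvsmi(["clocks.current.sm", "clocks.current.memory"])
    print(f"SM: {sm_clock} MHz, memory: {mem_clock} MHz")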
+ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, return_mode="mean"):
+     """
+     Benchmark the runtime of the provided function.
+
+     :param fn: Function to benchmark
+     :type fn: Callable
+     :param rep: Repetition time (in ms)
+     :type rep: int
+     :param grad_to_none: Reset the gradient of the provided tensor to None
+     :type grad_to_none: torch.tensor, optional
+     """
+     import torch
+     assert return_mode in ["min", "max", "mean", "median"]
+
+     if torch.cuda.current_stream() == torch.cuda.default_stream():
+         raise RuntimeError("Cannot capture graph in default stream. Please use a side stream in benchmark code.")
+     # warmup
+     fn()
+     # step 1 - we estimate the amount of time the kernel call takes
+     # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point
+     # but it is probably good enough
+     if grad_to_none is not None:
+         for x in grad_to_none:
+             x.detach_()
+             x.requires_grad_(True)
+             x.grad = None
+     g = torch.cuda.CUDAGraph()
+     with torch.cuda.graph(g):
+         fn()
+     torch.cuda.synchronize()
+     start_event = torch.cuda.Event(enable_timing=True)
+     end_event = torch.cuda.Event(enable_timing=True)
+     start_event.record()
+     g.replay()
+     end_event.record()
+     torch.cuda.synchronize()
+     estimate_ms = start_event.elapsed_time(end_event)
+     n_repeat = max(1, int(rep / estimate_ms))
+     # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize
+     # host overhead
+     g = torch.cuda.CUDAGraph()
+     with torch.cuda.graph(g):
+         for i in range(n_repeat):
+             if grad_to_none is not None:
+                 for x in grad_to_none:
+                     x.grad = None
+             fn()
+     torch.cuda.synchronize()
+     # measure time and return
+     ret = []
+     n_retries = 10
+     for i in range(n_retries):
+         start_event = torch.cuda.Event(enable_timing=True)
+         end_event = torch.cuda.Event(enable_timing=True)
+         start_event.record()
+         g.replay()
+         end_event.record()
+         torch.cuda.synchronize()
+         ret += [start_event.elapsed_time(end_event) / n_repeat]
+     times = torch.tensor(ret)
+     return getattr(torch, return_mode)(times).item()
+
+
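Note the guard at the top of `do_bench_cudagraph`: graph capture is refused on the default stream, so the whole call must run on a side stream. A minimal conforming sketch (assuming torch with a CUDA device; sizes are arbitrary):

    import torch
    from triton.testing import do_bench_cudagraph

    a = torch.randn(4096, 4096, device="cuda")
    b = torch.randn(4096, 4096, device="cuda")
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):  # capture must not run on the default stream
        ms = do_bench_cudagraph(lambda: a @ b)
    print(f"{ms:.3f} ms per call")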
+ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
+              device_type="cuda"):
+     """
+     Benchmark the runtime of the provided function. By default, return the mean runtime of :code:`fn`; if
+     :code:`quantiles` is provided, return those performance percentiles instead.
+
+     :param fn: Function to benchmark
+     :type fn: Callable
+     :param warmup: Warmup time (in ms)
+     :type warmup: int
+     :param rep: Repetition time (in ms)
+     :type rep: int
+     :param grad_to_none: Reset the gradient of the provided tensor to None
+     :type grad_to_none: torch.tensor, optional
+     :param quantiles: Performance percentiles to return instead of the :code:`return_mode` statistic.
+     :type quantiles: list[float], optional
+     :param fast_flush: Use faster kernel to flush L2 cache between measurements
+     :type fast_flush: bool
+     """
+     assert return_mode in ["min", "max", "mean", "median"]
+     import torch
+
+     di = torch._dynamo.device_interface.get_interface_for_device(device_type)
+
+     fn()
+     di.synchronize()
+
+     # We maintain a buffer of 256 MB that we clear
+     # before each kernel call to make sure that the L2 cache
+     # doesn't contain any input data before the run
+     if fast_flush:
+         cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=device_type)
+     else:
+         cache = torch.empty(int(256e6), dtype=torch.int8, device=device_type)
+
+     # Estimate the runtime of the function
+     start_event = di.Event(enable_timing=True)
+     end_event = di.Event(enable_timing=True)
+     start_event.record()
+     for _ in range(5):
+         cache.zero_()
+         fn()
+     end_event.record()
+     di.synchronize()
+     estimate_ms = start_event.elapsed_time(end_event) / 5
+
+     # compute number of warmup and repeat iterations
+     n_warmup = max(1, int(warmup / estimate_ms))
+     n_repeat = max(1, int(rep / estimate_ms))
+     start_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
+     end_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
+     # Warm-up
+     for _ in range(n_warmup):
+         fn()
+     # Benchmark
+     for i in range(n_repeat):
+         # we don't want `fn` to accumulate gradient values
+         # if it contains a backward pass. So we clear the
+         # provided gradients
+         if grad_to_none is not None:
+             for x in grad_to_none:
+                 x.grad = None
+         # we clear the L2 cache before each run
+         cache.zero_()
+         # record time of `fn`
+         start_event[i].record()
+         fn()
+         end_event[i].record()
+     # Record clocks
+     di.synchronize()
+     times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
+     if quantiles is not None:
+         ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
+         if len(ret) == 1:
+             ret = ret[0]
+         return ret
+     return getattr(torch, return_mode)(times).item()
+
+
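The `quantiles` argument is how the Triton tutorials obtain a median with an uncertainty band instead of a single statistic; a typical call looks like:

    import torch
    from triton.testing import do_bench

    x = torch.randn(1 << 20, device="cuda")
    # median runtime in ms, plus the 20th and 80th percentiles
    med_ms, p20_ms, p80_ms = do_bench(lambda: x * 2, quantiles=[0.5, 0.2, 0.8])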
+ def assert_close(x, y, atol=None, rtol=None, err_msg=''):
+     import numpy as np
+     import torch
+
+     # canonicalize arguments to be tensors
+     if not isinstance(x, torch.Tensor):
+         x = torch.tensor(x)
+     if not isinstance(y, torch.Tensor):
+         y = torch.tensor(y)
+     # absolute tolerance
+     if atol is None:
+         atol = 1e-2
+     atol = atol(x.dtype) if callable(atol) else atol
+     # relative tolerance hook
+     if rtol is None:
+         rtol = 0.
+     rtol = rtol(x.dtype) if callable(rtol) else rtol
+     # we use numpy instead of pytorch
+     # as it seems more memory efficient
+     # pytorch tends to oom on large tensors
+     if isinstance(x, torch.Tensor):
+         if x.dtype == torch.bfloat16:
+             x = x.float()
+         x = x.cpu().detach().numpy()
+     if isinstance(y, torch.Tensor):
+         if y.dtype == torch.bfloat16:
+             y = y.float()
+         y = y.cpu().detach().numpy()
+     # we handle size==1 case separately as we can
+     # provide better error message there
+     if x.size > 1 or y.size > 1:
+         np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True)
+         return
+     if not np.allclose(x, y, atol=atol, rtol=rtol):
+         raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})')
+
+
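Less obviously, `atol` and `rtol` also accept callables keyed on the dtype of `x`, which lets a test loosen tolerances for low-precision outputs only. A sketch (tensor names are hypothetical):

    import torch
    from triton.testing import assert_close

    out_ref = torch.ones(128, dtype=torch.float16)
    out_tri = out_ref + 1e-3
    # fp16 results get a looser absolute tolerance than fp32 ones
    assert_close(out_tri, out_ref, atol=lambda dtype: 1e-2 if dtype == torch.float16 else 1e-5)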
+ class Benchmark:
+     """
+     This class is used by the :code:`perf_report` function to generate line plots with a concise API.
+     """
+
+     def __init__(
+         self,
+         x_names: List[str],
+         x_vals: List[Any],
+         line_arg: str,
+         line_vals: List[Any],
+         line_names: List[str],
+         plot_name: str,
+         args: Dict[str, Any],
+         xlabel: str = '',
+         ylabel: str = '',
+         x_log: bool = False,
+         y_log: bool = False,
+         color=None,
+         styles=None,
+     ):
+         """
+         Constructor.
+         x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list
+         of scalars and there are multiple x_names, all arguments will have the same value.
+         If x_vals is a list of tuples/lists, each element should have the same length as
+         x_names.
+
+         :param x_names: Name of the arguments that should appear on the x axis of the plot.
+         :type x_names: List[str]
+         :param x_vals: List of values to use for the arguments in :code:`x_names`.
+         :type x_vals: List[Any]
+         :param line_arg: Argument name for which different values correspond to different lines in the plot.
+         :type line_arg: str
+         :param line_vals: List of values to use for the arguments in :code:`line_arg`.
+         :type line_vals: List[Any]
+         :param line_names: Label names for the different lines.
+         :type line_names: List[str]
+         :param plot_name: Name of the plot.
+         :type plot_name: str
+         :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark.
+         :type args: Dict[str, Any]
+         :param xlabel: Label for the x axis of the plot.
+         :type xlabel: str, optional
+         :param ylabel: Label for the y axis of the plot.
+         :type ylabel: str, optional
+         :param x_log: Whether the x axis should be log scale.
+         :type x_log: bool, optional
+         :param y_log: Whether the y axis should be log scale.
+         :type y_log: bool, optional
+         """
+         self.x_names = x_names
+         self.x_vals = x_vals
+         self.x_log = x_log
+         self.line_arg = line_arg
+         self.line_vals = line_vals
+         self.line_names = line_names
+         self.y_log = y_log
+         self.styles = styles
+         # plot info
+         self.xlabel = xlabel
+         self.ylabel = ylabel
+         self.plot_name = plot_name
+         self.args = args
+
+
+ class Mark:
+
+     def __init__(self, fn, benchmarks):
+         self.fn = fn
+         self.benchmarks = benchmarks
+
+     def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False,
+              save_precision=6, **kwargs):
+         import os
+
+         import matplotlib.pyplot as plt
+         import pandas as pd
+         y_mean = bench.line_names
+         y_min = [f'{x}-min' for x in bench.line_names]
+         y_max = [f'{x}-max' for x in bench.line_names]
+         x_names = list(bench.x_names)
+         df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max)
+         for x in bench.x_vals:
+             # x can be a single value or a sequence of values.
+             if not isinstance(x, (list, tuple)):
+                 x = [x for _ in x_names]
+
+             if len(x) != len(x_names):
+                 raise ValueError(f"Expected {len(x_names)} values, got {x}")
+             x_args = dict(zip(x_names, x))
+
+             row_mean, row_min, row_max = [], [], []
+             for y in bench.line_vals:
+                 ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwargs)
+                 try:
+                     y_mean, y_min, y_max = ret
+                 except TypeError:
+                     y_mean, y_min, y_max = ret, None, None
+                 row_mean += [y_mean]
+                 row_min += [y_min]
+                 row_max += [y_max]
+             df.loc[len(df)] = list(x) + row_mean + row_min + row_max
+
+         if bench.plot_name:
+             plt.figure()
+             ax = plt.subplot()
+             # Plot the first x value on the x axis if there are multiple.
+             first_x = x_names[0]
+             for i, y in enumerate(bench.line_names):
+                 y_min, y_max = df[y + '-min'], df[y + '-max']
+                 col = bench.styles[i][0] if bench.styles else None
+                 sty = bench.styles[i][1] if bench.styles else None
+                 ax.plot(df[first_x], df[y], label=y, color=col, ls=sty)
+                 if not y_min.isnull().all() and not y_max.isnull().all():
+                     y_min = y_min.astype(float)
+                     y_max = y_max.astype(float)
+                     ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col)
+             ax.legend()
+             ax.set_xlabel(bench.xlabel or first_x)
+             ax.set_ylabel(bench.ylabel)
+             # ax.set_title(bench.plot_name)
+             ax.set_xscale("log" if bench.x_log else "linear")
+             ax.set_yscale("log" if bench.y_log else "linear")
+             if show_plots:
+                 plt.show()
+             if save_path:
+                 plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
+         df = df[x_names + bench.line_names]
+         if diff_col and df.shape[1] == 2:
+             col0, col1 = df.columns.tolist()
+             df['Diff'] = df[col1] - df[col0]
+
+         if print_data:
+             print(bench.plot_name + ':')
+             print(df.to_string())
+         if save_path:
+             df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f",
+                       index=False)
+         return df
+
+     def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs):
+         has_single_bench = isinstance(self.benchmarks, Benchmark)
+         benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
+         result_dfs = []
+         if save_path:
+             # Create the output directory if it doesn't exist
+             os.makedirs(save_path, exist_ok=True)
+             html = open(os.path.join(save_path, "results.html"), "w")
+             html.write("<html><body>\n")
+         for bench in benchmarks:
+             result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
+             if save_path:
+                 html.write(f"<img src=\"{bench.plot_name}.png\"/>\n")
+         if save_path:
+             html.write("</body></html>\n")
+             html.close()
+         if return_df:
+             if has_single_bench:
+                 return result_dfs[0]
+             else:
+                 return result_dfs
+         return None
+
+
+ def perf_report(benchmarks):
+     """
+     Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.
+
+     :param benchmarks: Benchmarking configurations.
+     :type benchmarks: List of :class:`Benchmark`
+     """
+     wrapper = lambda fn: Mark(fn, benchmarks)
+     return wrapper
+
+
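`Benchmark`, `Mark`, and `perf_report` together form the decorator API used throughout the Triton tutorials. A condensed sketch of the intended usage (a plain torch copy stands in for a Triton kernel; all names and sizes are illustrative, and matplotlib/pandas must be installed):

    import torch
    import triton.testing as tt

    @tt.perf_report(
        tt.Benchmark(
            x_names=["size"], x_vals=[2**i for i in range(12, 24)],
            line_arg="provider", line_vals=["torch"], line_names=["Torch"],
            plot_name="copy-bandwidth", args={}, ylabel="GB/s", x_log=True,
        ))
    def bench_copy(size, provider):
        x = torch.randn(size, device="cuda")
        ms, min_ms, max_ms = tt.do_bench(lambda: x.clone(), quantiles=[0.5, 0.2, 0.8])
        gbps = lambda t: 2 * x.numel() * x.element_size() * 1e-9 / (t * 1e-3)
        return gbps(ms), gbps(max_ms), gbps(min_ms)  # mean line, lower band, upper band

    bench_copy.run(print_data=True)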
+ def get_dram_gbps(device=None):
+     ''' return DRAM bandwidth in GB/s '''
+     import torch
+
+     from .runtime import driver
+     if not device:
+         device = torch.cuda.current_device()
+     mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"]  # in kHz
+     bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"]
+     bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8  # In GB/s
+     return bw_gbps
+
+
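The factor of 2 accounts for double-data-rate transfers, and the division by 8 converts the bus width from bits to bytes. As a sanity check against a published spec (approximate A100-SXM numbers, not queried from a device):

    mem_clock_khz = 1_215_000  # ~1215 MHz HBM2e clock
    bus_width = 5120           # bits
    print(mem_clock_khz * bus_width * 2 / 1e6 / 8)  # ~1555.2, i.e. the A100's ~1.5 TB/s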
+ def get_max_tensorcore_tflops(dtype, clock_rate, device=None):
+     import torch
+
+     from .runtime import driver
+     if not device:
+         device = torch.cuda.current_device()
+
+     num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4
+     capability = torch.cuda.get_device_capability(device)
+     if capability[0] < 8:
+         assert dtype == torch.float16
+         ops_per_sub_core = 256  # 2 4x4x4 Tensor Cores
+     else:
+         if dtype in [torch.float32, torch.int32]:
+             ops_per_sub_core = 256
+         elif dtype in [torch.float16, torch.bfloat16, torch.int16]:
+             ops_per_sub_core = 512
+         elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]:
+             ops_per_sub_core = 1024
+         else:
+             raise RuntimeError("dtype not supported")
+     tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
+     return tflops
+
+
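Here `clock_rate` is expected in kHz, the unit CUDA device properties report. A back-of-the-envelope check with approximate A100 numbers (108 SMs, fp16, ~1410 MHz boost clock) recovers the advertised figure:

    num_subcores = 108 * 4  # 108 SMs x 4 sub-cores
    clock_rate = 1_410_000  # kHz
    print(num_subcores * clock_rate * 512 * 1e-9)  # ~311.9, close to the A100's 312 fp16 TFLOPS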
+ # create a decorator that wraps a test function into
+ # a cuda-memcheck system call
+
+
+ def cuda_memcheck(**target_kwargs):
+
+     def decorator(test_fn):
+
+         @functools.wraps(test_fn)
+         def wrapper(*args, **kwargs):
+             import psutil
+             ppid_name = psutil.Process(os.getppid()).name()
+             run_cuda_memcheck = target_kwargs.items() <= kwargs.items()
+             if run_cuda_memcheck and ppid_name != "cuda-memcheck":
+                 # get the path of the file containing the test
+                 path = os.path.realpath(test_fn.__globals__["__file__"])
+                 env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"}
+                 assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture"
+                 test_id = kwargs['request'].node.callspec.id
+                 cmd = f"{path}::{test_fn.__name__}[{test_id}]"
+                 out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env)
+                 assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed"
+                 assert "ERROR SUMMARY: 0 errors" in str(out.stdout)
+             else:
+                 test_fn(*args, **kwargs)
+
+         return wrapper
+
+     return decorator
+
+
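When the keyword filter matches, the decorator re-runs the test in a subprocess under `cuda-memcheck pytest ...`, which is why the wrapped test must accept pytest's `request` fixture (used to rebuild the parametrized test id). A hypothetical test using it:

    import pytest
    from triton.testing import cuda_memcheck

    @cuda_memcheck(debug=True)  # only the debug=True parametrizations are memcheck'ed
    @pytest.mark.parametrize("debug", [True, False])
    def test_kernel_bounds(debug, request):
        ...  # launch the kernel under test here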
+ @contextmanager
+ def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
+     try:
+         subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
+         subprocess.check_output([
+             "nvidia-smi",
+             "-i",
+             "0",
+             f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
+         ])
+         subprocess.check_output([
+             "nvidia-smi",
+             "-i",
+             "0",
+             f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
+         ])
+         cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
+         cur_mem_clock = nvsmi(["clocks.current.memory"])[0]
+         assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"
+         assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU memory must run at {ref_mem_clock} MHz"
+         tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock
+         gbps = 640 * 2 * ref_mem_clock * 1e-3
+         yield tflops, gbps
+     finally:
+         subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"])
+         subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"])
+         subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"])
+
+
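The yielded roofline numbers hardcode an A100-like part: 108 SMs x 4 sub-cores x 512 fp16 FLOPs per cycle for the TFLOPS figure, and a 5120-bit (640-byte) DDR bus for the GB/s figure, so they are only meaningful on that hardware. A usage sketch (clock locking via `nvidia-smi` typically requires administrator rights):

    import torch
    from triton.testing import do_bench, set_gpu_clock

    a = torch.randn(8192, 8192, device="cuda", dtype=torch.float16)
    b = torch.randn(8192, 8192, device="cuda", dtype=torch.float16)
    with set_gpu_clock() as (max_tflops, max_gbps):  # locks clocks, restores them on exit
        ms = do_bench(lambda: a @ b)
        tflops = 2 * 8192**3 * 1e-12 / (ms * 1e-3)
        print(f"{tflops:.0f} of {max_tflops:.0f} TFLOPS at locked clocks")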
+ def get_max_simd_tflops(dtype, clock_rate, device=None):
+     import torch
+
+     from .runtime import driver
+     if not device:
+         device = torch.cuda.current_device()
+
+     num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4
+     capability = torch.cuda.get_device_capability()
+     if capability[0] < 8:
+         if dtype == torch.float32:
+             ops_per_sub_core = 32  # 2*16
+         elif dtype == torch.float16:
+             ops_per_sub_core = 64
+         else:
+             raise RuntimeError("dtype not supported")
+     else:
+         if dtype == torch.float32:
+             ops_per_sub_core = 32
+         elif dtype in [torch.float16, torch.bfloat16]:
+             ops_per_sub_core = 64
+         else:
+             raise RuntimeError("dtype not supported")
+     tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
+     return tflops
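Same kHz convention as `get_max_tensorcore_tflops`; for instance, a V100 (80 SMs, ~1530 MHz boost, approximate numbers) gives the familiar fp32 figure:

    print(80 * 4 * 1_530_000 * 32 * 1e-9)  # ~15.7, the V100's published fp32 TFLOPS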