triton-windows 3.2.0.post11__cp39-cp39-win_amd64.whl

This diff shows the content of publicly available package versions as they appear in their respective public registries; it is provided for informational purposes only.

Potentially problematic release: this version of triton-windows might be problematic.

Files changed (154)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +85 -0
  3. triton/_internal_testing.py +123 -0
  4. triton/backends/__init__.py +50 -0
  5. triton/backends/amd/compiler.py +368 -0
  6. triton/backends/amd/driver.c +211 -0
  7. triton/backends/amd/driver.py +512 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  25. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  26. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  27. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  28. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  31. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  32. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  40. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  41. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  42. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  43. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  44. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  45. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  46. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  48. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  49. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  50. triton/backends/amd/include/hip/device_functions.h +38 -0
  51. triton/backends/amd/include/hip/driver_types.h +468 -0
  52. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  53. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  54. triton/backends/amd/include/hip/hip_common.h +100 -0
  55. triton/backends/amd/include/hip/hip_complex.h +38 -0
  56. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  57. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  58. triton/backends/amd/include/hip/hip_ext.h +159 -0
  59. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  60. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  61. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  62. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  63. triton/backends/amd/include/hip/hip_profile.h +27 -0
  64. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  65. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  66. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  67. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  68. triton/backends/amd/include/hip/hip_version.h +17 -0
  69. triton/backends/amd/include/hip/hiprtc.h +421 -0
  70. triton/backends/amd/include/hip/library_types.h +78 -0
  71. triton/backends/amd/include/hip/math_functions.h +42 -0
  72. triton/backends/amd/include/hip/surface_types.h +63 -0
  73. triton/backends/amd/include/hip/texture_types.h +194 -0
  74. triton/backends/amd/include/hsa/Brig.h +1131 -0
  75. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  76. triton/backends/amd/include/hsa/amd_hsa_elf.h +436 -0
  77. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  78. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  79. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  80. triton/backends/amd/include/hsa/hsa.h +5729 -0
  81. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  82. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  83. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  84. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  85. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  87. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  88. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  89. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  90. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  91. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  92. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  93. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  94. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  95. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  96. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  97. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  98. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  99. triton/backends/amd/include/roctracer/roctx.h +229 -0
  100. triton/backends/amd/lib/ockl.bc +0 -0
  101. triton/backends/amd/lib/ocml.bc +0 -0
  102. triton/backends/compiler.py +304 -0
  103. triton/backends/driver.py +48 -0
  104. triton/backends/nvidia/__init__.py +0 -0
  105. triton/backends/nvidia/bin/ptxas.exe +0 -0
  106. triton/backends/nvidia/compiler.py +410 -0
  107. triton/backends/nvidia/driver.c +451 -0
  108. triton/backends/nvidia/driver.py +524 -0
  109. triton/backends/nvidia/include/cuda.h +24359 -0
  110. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  111. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  112. triton/compiler/__init__.py +4 -0
  113. triton/compiler/code_generator.py +1303 -0
  114. triton/compiler/compiler.py +430 -0
  115. triton/compiler/errors.py +51 -0
  116. triton/compiler/make_launcher.py +0 -0
  117. triton/errors.py +5 -0
  118. triton/language/__init__.py +294 -0
  119. triton/language/_utils.py +21 -0
  120. triton/language/core.py +2694 -0
  121. triton/language/extra/__init__.py +26 -0
  122. triton/language/extra/cuda/__init__.py +13 -0
  123. triton/language/extra/cuda/_experimental_tma.py +108 -0
  124. triton/language/extra/cuda/libdevice.py +1629 -0
  125. triton/language/extra/cuda/utils.py +109 -0
  126. triton/language/extra/hip/__init__.py +3 -0
  127. triton/language/extra/hip/libdevice.py +475 -0
  128. triton/language/extra/libdevice.py +786 -0
  129. triton/language/math.py +250 -0
  130. triton/language/random.py +207 -0
  131. triton/language/semantic.py +1796 -0
  132. triton/language/standard.py +452 -0
  133. triton/runtime/__init__.py +23 -0
  134. triton/runtime/autotuner.py +408 -0
  135. triton/runtime/build.py +111 -0
  136. triton/runtime/cache.py +295 -0
  137. triton/runtime/driver.py +60 -0
  138. triton/runtime/errors.py +26 -0
  139. triton/runtime/interpreter.py +1235 -0
  140. triton/runtime/jit.py +951 -0
  141. triton/testing.py +511 -0
  142. triton/tools/__init__.py +0 -0
  143. triton/tools/build_extern.py +365 -0
  144. triton/tools/compile.c +67 -0
  145. triton/tools/compile.h +14 -0
  146. triton/tools/compile.py +155 -0
  147. triton/tools/disasm.py +144 -0
  148. triton/tools/experimental_descriptor.py +32 -0
  149. triton/tools/link.py +322 -0
  150. triton/windows_utils.py +375 -0
  151. triton_windows-3.2.0.post11.dist-info/METADATA +39 -0
  152. triton_windows-3.2.0.post11.dist-info/RECORD +154 -0
  153. triton_windows-3.2.0.post11.dist-info/WHEEL +5 -0
  154. triton_windows-3.2.0.post11.dist-info/top_level.txt +12 -0
triton/testing.py ADDED
@@ -0,0 +1,511 @@
+import functools
+import os
+import subprocess
+import sys
+from contextlib import contextmanager
+from typing import Any, Dict, List
+from . import language as tl
+from . import runtime
+
+
+def nvsmi(attrs):
+    attrs = ','.join(attrs)
+    cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
+    out = subprocess.check_output(cmd)
+    ret = out.decode(sys.stdout.encoding).split(',')
+    ret = [int(x) for x in ret]
+    return ret
+
+
+def _summarize_statistics(times, quantiles, return_mode):
+    import torch
+    if quantiles is not None:
+        ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
+        if len(ret) == 1:
+            ret = ret[0]
+        return ret
+    if return_mode == "all":
+        return times.tolist()
+    return getattr(torch, return_mode)(times).item()
+
+
+def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
+    """
+    Benchmark the runtime of the provided function.
+
+    :param fn: Function to benchmark
+    :type fn: Callable
+    :param rep: Repetition time (in ms)
+    :type rep: int
+    :param grad_to_none: Reset the gradient of the provided tensor to None
+    :type grad_to_none: torch.tensor, optional
+    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
+    :type return_mode: str
+    """
+    import torch
+    assert return_mode in ["min", "max", "mean", "median", "all"]
+
+    with torch.cuda.stream(torch.cuda.Stream()):
+        # warmup
+        fn()
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.detach_()
+                x.requires_grad_(True)
+                x.grad = None
+        # step 1 - we estimate the amount of time the kernel call takes
+        # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point
+        #       but it is probably good enough
+        # NOTE: we don't use a graph to estimate the runtime because creating a graph is expensive,
+        #       ~300ms on A100, so we default to the same method used in `do_bench` (minus the L2
+        #       cache flush).
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        for _ in range(5):
+            fn()
+        end_event.record()
+        torch.cuda.synchronize()
+        estimate_ms = start_event.elapsed_time(end_event) / 5
+        n_repeat = max(1, int(rep / estimate_ms))
+        # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize
+        # host overhead
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            for _ in range(n_repeat):
+                if grad_to_none is not None:
+                    for x in grad_to_none:
+                        x.grad = None
+                fn()
+        torch.cuda.synchronize()
+        # measure time and return
+        ret = []
+        n_retries = 10
+        for _ in range(n_retries):
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            start_event.record()
+            g.replay()
+            end_event.record()
+            torch.cuda.synchronize()
+            ret += [start_event.elapsed_time(end_event) / n_repeat]
+        return _summarize_statistics(torch.tensor(ret), quantiles, return_mode)
+
+
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
+    """
+    Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
+    the 20-th and 80-th performance percentile.
+
+    :param fn: Function to benchmark
+    :type fn: Callable
+    :param warmup: Warmup time (in ms)
+    :type warmup: int
+    :param rep: Repetition time (in ms)
+    :type rep: int
+    :param grad_to_none: Reset the gradient of the provided tensor to None
+    :type grad_to_none: torch.tensor, optional
+    :param quantiles: Performance percentile to return in addition to the median.
+    :type quantiles: list[float], optional
+    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean". :type return_mode: str
+    """
+    assert return_mode in ["min", "max", "mean", "median", "all"]
+    import torch
+
+    di = runtime.driver.active.get_device_interface()
+
+    fn()
+    di.synchronize()
+
+    cache = runtime.driver.active.get_empty_cache_for_benchmark()
+
+    # Estimate the runtime of the function
+    start_event = di.Event(enable_timing=True)
+    end_event = di.Event(enable_timing=True)
+    start_event.record()
+    for _ in range(5):
+        cache.zero_()
+        fn()
+    end_event.record()
+    di.synchronize()
+    estimate_ms = start_event.elapsed_time(end_event) / 5
+
+    # compute number of warmup and repeat
+    n_warmup = max(1, int(warmup / estimate_ms))
+    n_repeat = max(1, int(rep / estimate_ms))
+    start_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
+    end_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
+    # Warm-up
+    for _ in range(n_warmup):
+        fn()
+    # Benchmark
+    for i in range(n_repeat):
+        # we don't want `fn` to accumulate gradient values
+        # if it contains a backward pass. So we clear the
+        # provided gradients
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.grad = None
+        # we clear the L2 cache before each run
+        cache.zero_()
+        # record time of `fn`
+        start_event[i].record()
+        fn()
+        end_event[i].record()
+    # Record clocks
+    di.synchronize()
+    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
+    return _summarize_statistics(times, quantiles, return_mode)
+
+
+def assert_close(x, y, atol=None, rtol=None, err_msg=''):
+    """
+    Asserts that two inputs are close within a certain tolerance.
+
+    :param x: The first input.
+    :type x: scalar, list, numpy.ndarray, or torch.Tensor
+    :param y: The second input.
+    :type y: scalar, list, numpy.ndarray, or torch.Tensor
+    :param atol: The absolute tolerance. Default value is 1e-2.
+    :type atol: float, optional
+    :param rtol: The relative tolerance. Default value is 0.
+    :type rtol: float, optional
+    :param err_msg: The error message to use if the assertion fails.
+    :type err_msg: str
+    """
+    import numpy as np
+    import torch
+
+    # canonicalize arguments to be tensors
+    if not isinstance(x, torch.Tensor):
+        x = torch.tensor(x)
+    if not isinstance(y, torch.Tensor):
+        y = torch.tensor(y)
+    # absolute tolerance
+    if atol is None:
+        atol = 1e-2
+    atol = atol(x.dtype) if callable(atol) else atol
+    # relative tolerance hook
+    if rtol is None:
+        rtol = 0.
+    rtol = rtol(x.dtype) if callable(rtol) else rtol
+    # we use numpy instead of pytorch
+    # as it seems more memory efficient
+    # pytorch tends to oom on large tensors
+    if isinstance(x, torch.Tensor):
+        if x.dtype == torch.bfloat16:
+            x = x.float()
+        x = x.cpu().detach().numpy()
+    if isinstance(y, torch.Tensor):
+        if y.dtype == torch.bfloat16:
+            y = y.float()
+        y = y.cpu().detach().numpy()
+    # we handle size==1 case separately as we can
+    # provide better error message there
+    if x.size > 1 or y.size > 1:
+        np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True)
+        return
+    if not np.allclose(x, y, atol=atol, rtol=rtol):
+        raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})')
+
+
+class Benchmark:
+    """
+    This class is used by the :code:`perf_report` function to generate line plots with a concise API.
+    """
+
+    def __init__(
+        self,
+        x_names: List[str],
+        x_vals: List[Any],
+        line_arg: str,
+        line_vals: List[Any],
+        line_names: List[str],
+        plot_name: str,
+        args: Dict[str, Any],
+        xlabel: str = '',
+        ylabel: str = '',
+        x_log: bool = False,
+        y_log: bool = False,
+        styles=None,
+    ):
+        """
+        Constructor.
+        x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list
+        of scalars and there are multiple x_names, all arguments will have the same value.
+        If x_vals is a list of tuples/lists, each element should have the same length as
+        x_names.
+
+        :param x_names: Name of the arguments that should appear on the x axis of the plot.
+        :type x_names: List[str]
+        :param x_vals: List of values to use for the arguments in :code:`x_names`.
+        :type x_vals: List[Any]
+        :param line_arg: Argument name for which different values correspond to different lines in the plot.
+        :type line_arg: str
+        :param line_vals: List of values to use for the arguments in :code:`line_arg`.
+        :type line_vals: List[Any]
+        :param line_names: Label names for the different lines.
+        :type line_names: List[str]
+        :param plot_name: Name of the plot.
+        :type plot_name: str
+        :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark.
+        :type args: Dict[str, Any]
+        :param xlabel: Label for the x axis of the plot.
+        :type xlabel: str, optional
+        :param ylabel: Label for the y axis of the plot.
+        :type ylabel: str, optional
+        :param x_log: Whether the x axis should be log scale.
+        :type x_log: bool, optional
+        :param y_log: Whether the y axis should be log scale.
+        :type y_log: bool, optional
+        :param styles: A list of tuples, where each tuple contains two elements: a color and a linestyle.
+        :type styles: list[tuple[str, str]]
+        """
+        self.x_names = x_names
+        self.x_vals = x_vals
+        self.x_log = x_log
+        self.line_arg = line_arg
+        self.line_vals = line_vals
+        self.line_names = line_names
+        self.y_log = y_log
+        self.styles = styles
+        # plot info
+        self.xlabel = xlabel
+        self.ylabel = ylabel
+        self.plot_name = plot_name
+        self.args = args
+
+
+class Mark:
+
+    def __init__(self, fn, benchmarks):
+        self.fn = fn
+        self.benchmarks = benchmarks
+
+    def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False,
+             save_precision=6, **kwargs):
+        import os
+
+        import matplotlib.pyplot as plt
+        import pandas as pd
+        y_mean = bench.line_names
+        y_min = [f'{x}-min' for x in bench.line_names]
+        y_max = [f'{x}-max' for x in bench.line_names]
+        x_names = list(bench.x_names)
+        df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max)
+        for x in bench.x_vals:
+            # x can be a single value or a sequence of values.
+            if not isinstance(x, (list, tuple)):
+                x = [x for _ in x_names]
+
+            if len(x) != len(x_names):
+                raise ValueError(f"Expected {len(x_names)} values, got {x}")
+            x_args = dict(zip(x_names, x))
+
+            row_mean, row_min, row_max = [], [], []
+            for y in bench.line_vals:
+                ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwargs)
+                try:
+                    y_mean, y_min, y_max = ret
+                except TypeError:
+                    y_mean, y_min, y_max = ret, None, None
+                row_mean += [y_mean]
+                row_min += [y_min]
+                row_max += [y_max]
+            df.loc[len(df)] = list(x) + row_mean + row_min + row_max
+
+        if bench.plot_name:
+            plt.figure()
+            ax = plt.subplot()
+            # Plot first x value on x axis if there are multiple.
+            first_x = x_names[0]
+            for i, y in enumerate(bench.line_names):
+                y_min, y_max = df[y + '-min'], df[y + '-max']
+                col = bench.styles[i][0] if bench.styles else None
+                sty = bench.styles[i][1] if bench.styles else None
+                ax.plot(df[first_x], df[y], label=y, color=col, ls=sty)
+                if not y_min.isnull().all() and not y_max.isnull().all():
+                    y_min = y_min.astype(float)
+                    y_max = y_max.astype(float)
+                    ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col)
+            ax.legend()
+            ax.set_xlabel(bench.xlabel or first_x)
+            ax.set_ylabel(bench.ylabel)
+            # ax.set_title(bench.plot_name)
+            ax.set_xscale("log" if bench.x_log else "linear")
+            ax.set_yscale("log" if bench.y_log else "linear")
+            if show_plots:
+                plt.show()
+            if save_path:
+                plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
+        df = df[x_names + bench.line_names]
+        if diff_col and df.shape[1] == 2:
+            col0, col1 = df.columns.tolist()
+            df['Diff'] = df[col1] - df[col0]
+
+        if print_data:
+            print(bench.plot_name + ':')
+            print(df.to_string())
+        if save_path:
+            df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f",
+                      index=False)
+        return df
+
+    def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs):
+        has_single_bench = isinstance(self.benchmarks, Benchmark)
+        benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
+        result_dfs = []
+        if save_path:
+            # Create directory if it doesn't exist
+            os.makedirs(save_path, exist_ok=True)
+            html = open(os.path.join(save_path, "results.html"), "w")
+            html.write("<html><body>\n")
+        for bench in benchmarks:
+            result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
+            if save_path:
+                html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
+        if save_path:
+            html.write("</body></html>\n")
+            html.close()
+        if return_df:
+            if has_single_bench:
+                return result_dfs[0]
+            else:
+                return result_dfs
+        return None
+
+
+def perf_report(benchmarks):
+    """
+    Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.
+
+    :param benchmarks: Benchmarking configurations.
+    :type benchmarks: List of :class:`Benchmark`
+    """
+    wrapper = lambda fn: Mark(fn, benchmarks)
+    return wrapper
+
+
+def get_dram_gbps(device=None):
+    ''' return DRAM bandwidth in GB/s '''
+    import torch
+
+    from .runtime import driver
+    if not device:
+        device = torch.cuda.current_device()
+    mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"]  # in kHz
+    bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"]
+    bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8  # In GB/s
+    return bw_gbps
+
+
+def get_max_tensorcore_tflops(dtype, clock_rate, device=None):
+    import torch
+
+    from .runtime import driver
+    if not device:
+        device = torch.cuda.current_device()
+
+    num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4
+    capability = torch.cuda.get_device_capability(device)
+    if capability[0] < 8:
+        assert dtype == torch.float16
+        ops_per_sub_core = 256  # 2 4x4x4 Tensor Cores
+    else:
+        if dtype in [torch.float32, torch.int32]:
+            ops_per_sub_core = 256
+        elif dtype in [torch.float16, torch.bfloat16, torch.int16]:
+            ops_per_sub_core = 512
+        elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]:
+            ops_per_sub_core = 1024
+        else:
+            raise RuntimeError("dtype not supported")
+    tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
+    return tflops
+
+
+# create decorator that wraps test function into
+# a cuda-memcheck system call
+
+
+def cuda_memcheck(**target_kwargs):
+
+    def decorator(test_fn):
+
+        @functools.wraps(test_fn)
+        def wrapper(*args, **kwargs):
+            import psutil
+            ppid_name = psutil.Process(os.getppid()).name()
+            run_cuda_memcheck = target_kwargs.items() <= kwargs.items()
+            if run_cuda_memcheck and ppid_name != "cuda-memcheck":
+                path = os.path.realpath(test_fn.__globals__["__file__"])
+                # get path of current file
+                env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"}
+                assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture"
+                test_id = kwargs['request'].node.callspec.id
+                cmd = f"{path}::{test_fn.__name__}[{test_id}]"
+                out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env)
+                assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed"
+                assert "ERROR SUMMARY: 0 errors" in str(out.stdout)
+            else:
+                test_fn(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+@contextmanager
+def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
+    try:
+        subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
+        subprocess.check_output([
+            "nvidia-smi",
+            "-i",
+            "0",
+            f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
+        ])
+        subprocess.check_output([
+            "nvidia-smi",
+            "-i",
+            "0",
+            f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
+        ])
+        cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
+        cur_mem_clock = nvsmi(["clocks.current.memory"])[0]
+        assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"
+        assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU memory must run at {ref_mem_clock} MHz"
+        tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock
+        gbps = 640 * 2 * ref_mem_clock * 1e-3
+        yield tflops, gbps
+    finally:
+        subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"])
+        subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"])
+        subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"])
+
+
+def get_max_simd_tflops(dtype, clock_rate, device=None):
+    import torch
+
+    from .runtime import driver
+    if not device:
+        device = torch.cuda.current_device()
+
+    num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4
+    capability = torch.cuda.get_device_capability()
+    if capability[0] < 8:
+        if dtype == torch.float32:
+            ops_per_sub_core = 32  # 2*16
+        elif dtype == torch.float16:
+            ops_per_sub_core = 64
+        else:
+            raise RuntimeError("dtype not supported")
+    else:
+        if dtype == torch.float32:
+            ops_per_sub_core = 32
+        elif dtype in [torch.float16, torch.bfloat16]:
+            ops_per_sub_core = 64
+        else:
+            raise RuntimeError("dtype not supported")
+    tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
+    return tflops
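
For reference, a minimal usage sketch of the benchmarking helpers defined in triton/testing.py above (do_bench, Benchmark, perf_report). It is illustrative only and not part of the wheel: the element-wise workload and sizes are hypothetical, and the plotting path assumes matplotlib and pandas are installed alongside torch.

import torch
import triton.testing as tt

# Hypothetical element-wise workload, used purely for illustration.
x = torch.randn(1 << 20, device="cuda")
y = torch.randn(1 << 20, device="cuda")

# do_bench: with quantiles=[0.5, 0.2, 0.8] it returns the median and the
# 20th/80th percentiles of the measured runtimes, in milliseconds.
med, p20, p80 = tt.do_bench(lambda: x + y, quantiles=[0.5, 0.2, 0.8])
print(f"median={med:.4f} ms  p20={p20:.4f} ms  p80={p80:.4f} ms")

# Benchmark/perf_report: sweep a size argument and collect one line per provider.
@tt.perf_report(
    tt.Benchmark(
        x_names=["n"], x_vals=[2**i for i in range(12, 21)],
        line_arg="provider", line_vals=["torch"], line_names=["torch"],
        plot_name="add-performance", args={}, ylabel="ms",
    ))
def bench_add(n, provider):
    a = torch.randn(n, device="cuda")
    b = torch.randn(n, device="cuda")
    return tt.do_bench(lambda: a + b, quantiles=[0.5, 0.2, 0.8])

bench_add.run(print_data=True, show_plots=False)

Returning the 3-element quantile result from the benchmarked function matches what Mark._run expects: it unpacks each call into (mean, min, max) and uses the extra values to draw a shaded band around each line in the plot.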