triton-windows 3.2.0.post12__cp312-cp312-win_amd64.whl → 3.3.0a0.post12__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +3 -3
- triton/_internal_testing.py +59 -4
- triton/_utils.py +35 -0
- triton/backends/amd/compiler.py +121 -74
- triton/backends/amd/driver.py +77 -43
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
- triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
- triton/backends/amd/include/hip/hip_ext.h +4 -2
- triton/backends/amd/include/hip/hip_fp8.h +33 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
- triton/backends/amd/include/hip/hip_version.h +3 -3
- triton/backends/amd/include/hip/hiprtc.h +25 -25
- triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
- triton/backends/amd/include/hsa/hsa.h +11 -2
- triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/compiler.py +25 -225
- triton/backends/driver.py +7 -2
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +135 -90
- triton/backends/nvidia/driver.c +0 -1
- triton/backends/nvidia/driver.py +135 -49
- triton/backends/nvidia/include/cuda.h +2162 -241
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +2 -2
- triton/compiler/code_generator.py +334 -231
- triton/compiler/compiler.py +77 -66
- triton/language/__init__.py +22 -5
- triton/language/core.py +448 -74
- triton/language/extra/cuda/_experimental_tma.py +3 -5
- triton/language/math.py +1 -1
- triton/language/random.py +2 -1
- triton/language/semantic.py +206 -52
- triton/language/standard.py +35 -18
- triton/runtime/_allocation.py +32 -0
- triton/runtime/autotuner.py +27 -32
- triton/runtime/build.py +1 -48
- triton/runtime/cache.py +6 -6
- triton/runtime/errors.py +10 -0
- triton/runtime/interpreter.py +179 -45
- triton/runtime/jit.py +149 -190
- triton/testing.py +39 -11
- triton/tools/compile.py +27 -20
- triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
- triton/tools/mxfp.py +301 -0
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/METADATA +5 -2
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/RECORD +68 -59
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/top_level.txt +2 -0
- /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/WHEEL +0 -0
triton/backends/nvidia/compiler.py CHANGED

@@ -1,5 +1,6 @@
 from triton.backends.compiler import BaseBackend, GPUTarget
 from triton._C.libtriton import ir, passes, llvm, nvidia
+from triton.runtime.errors import PTXASError

 from dataclasses import dataclass
 import functools
@@ -12,16 +13,26 @@ import signal
 import os
 import subprocess
 from pathlib import Path
+import sysconfig


 def min_dot_size(target: GPUTarget):
-
+
+    def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]:  # [m, n, k]
+        lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
+        rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
+        assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
+        if lhs_bitwidth == 8:
+            return (16, 16, 32)
+        else:
+            return (16, 16, 16)
+
+    return check_dot_compatibility


 @functools.lru_cache()
 def _path_to_binary(binary: str):
-
-    binary += ".exe"
+    binary += sysconfig.get_config_var("EXE")
     paths = [
         os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
         os.path.join(os.path.dirname(__file__), "bin", binary),
@@ -32,19 +43,31 @@ def _path_to_binary(binary: str):
     if cuda_bin_path:
         paths += [os.path.join(cuda_bin_path, binary)]

-    for
-        if os.path.exists(
-            result = subprocess.check_output([
+    for path in paths:
+        if os.path.exists(path) and os.path.isfile(path):
+            result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
             if result is not None:
                 version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
                 if version is not None:
-                    return
+                    return path, version.group(1)
     raise RuntimeError(f"Cannot find {binary}")


 @functools.lru_cache()
-def
-
+def get_ptxas(arch: int):
+    if os.name == "nt":
+        name = "ptxas"
+    else:
+        name = "ptxas-blackwell" if arch >= 100 else "ptxas"
+    return _path_to_binary(name)
+
+
+@functools.lru_cache()
+def get_ptxas_version(arch: int):
+    mock_ver = os.environ.get('TRITON_MOCK_PTX_VERSION')
+    if mock_ver is not None:
+        return mock_ver  # This is not really a version of ptxas, but it is good enough for testing
+    version = subprocess.check_output([get_ptxas(arch)[0], "--version"]).decode("utf-8")
     return version


@@ -59,7 +82,7 @@ def ptx_get_version(cuda_version) -> int:
         if minor < 6:
             return 80 + minor
         else:
-            return
+            return 80 + minor - 1
     if major == 11:
         return 70 + minor
     if major == 10:
@@ -67,24 +90,24 @@ def ptx_get_version(cuda_version) -> int:
     raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version)


-def get_ptx_version_from_options(options):
+def get_ptx_version_from_options(options, arch: int):
     ptx_version = options.ptx_version
     if ptx_version is None:
-        _, cuda_version =
+        _, cuda_version = get_ptxas(arch)
         ptx_version = ptx_get_version(cuda_version)
     return ptx_version


 @functools.lru_cache()
-def get_features(options):
-    ptx_version = get_ptx_version_from_options(options)
+def get_features(options, arch: int):
+    ptx_version = get_ptx_version_from_options(options, arch)

-    # PTX 8.
+    # PTX 8.6 is the max version supported by llvm c1188642.
     #
     # To check if a newer PTX version is supported, increase this value
     # and run a test. If it's not supported, LLVM will print a warning
     # like "+ptx8.4 is not a recognized feature for this target".
-    llvm_ptx_version = min(
+    llvm_ptx_version = min(86, ptx_version)
     features = f'+ptx{llvm_ptx_version}'
     return features

@@ -95,6 +118,12 @@ def file_hash(path):
         return hashlib.sha256(f.read()).hexdigest()


+def sm_arch_from_capability(capability: int):
+    # TODO: Handle non-"a" sms
+    suffix = "a" if capability >= 90 else ""
+    return f"sm_{capability}{suffix}"
+
+
 # The file may be accessed in parallel
 def try_remove(path):
     if os.path.exists(path):
@@ -110,16 +139,13 @@ class CUDAOptions:
     num_warps: int = 4
     num_ctas: int = 1
     num_stages: int = 3
-    num_buffers_warp_spec: int = 0
-    num_consumer_groups: int = 0
-    reg_dec_producer: int = 0
-    reg_inc_consumer: int = 0
     # maxnreg corresponds to the ptx parameter .maxnreg, which controls the
     # maximum number of 32-bit registers used by one thread.
     maxnreg: Optional[int] = None
     cluster_dims: tuple = (1, 1, 1)
     ptx_version: int = None
     enable_fp_fusion: bool = True
+    launch_cooperative_grid: bool = False
     supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15")
     deprecated_fp8_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "tf32"
@@ -129,6 +155,7 @@ class CUDAOptions:
     debug: bool = False
     backend_name: str = 'cuda'
     sanitize_overflow: bool = True
+    arch: str = None

     def __post_init__(self):
         default_libdir = Path(__file__).parent / 'lib'
@@ -152,27 +179,37 @@ class CUDABackend(BaseBackend):
     def supports_target(target: GPUTarget):
         return target.backend == 'cuda'

+    def _parse_arch(self, arch):
+        pattern = r"^sm(\d+)$"
+        match = re.fullmatch(pattern, arch)
+        if not match:
+            raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}")
+        return int(match.group(1))
+
     def __init__(self, target: GPUTarget) -> None:
         super().__init__(target)
-        self.capability = target.arch
-        assert isinstance(self.capability, int)
         self.binary_ext = "cubin"

     def parse_options(self, opts) -> Any:
-        args = {
+        args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", f"sm{self.target.arch}")}
+        args.update({k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None})
+        capability = int(self._parse_arch(args["arch"]))
+
         if "supported_fp8_dtypes" not in args:
             supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
-            if
+            if capability >= 89:
                 supported_fp8_dtypes.add("fp8e4nv")
             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))

         if "deprecated_fp8_dtypes" not in args:
-            if
+            if capability >= 90:
                 args["deprecated_fp8_dtypes"] = ("fp8e4b15", )

         if "enable_fp_fusion" not in args:
             args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
-
+
+        args["max_num_imprecise_acc_default"] = 2**30 if capability == 90 else 0
+
         return CUDAOptions(**args)

     def pack_metadata(self, metadata):
@@ -185,12 +222,13 @@ class CUDABackend(BaseBackend):
             metadata.cluster_dims[2],
         )

-    def get_codegen_implementation(self):
+    def get_codegen_implementation(self, options):
         import triton.language.extra.cuda as cuda
+        capability = int(self._parse_arch(options.arch))
         codegen_fns = {
             "convert_custom_types":
-            cuda.convert_custom_float8_sm80 if
-
+            cuda.convert_custom_float8_sm80 if capability >= 80 else cuda.convert_custom_float8_sm70, "min_dot_size":
+            min_dot_size(self.target)
         }
         return codegen_fns

@@ -207,11 +245,10 @@ class CUDABackend(BaseBackend):
         pm.enable_debug()
         passes.common.add_inliner(pm)
         passes.ttir.add_rewrite_tensor_pointer(pm)
-        passes.ttir.add_combine(pm)
         passes.common.add_canonicalizer(pm)
+        passes.ttir.add_combine(pm)
         passes.ttir.add_reorder_broadcast(pm)
         passes.common.add_cse(pm)
-        passes.common.add_licm(pm)
         passes.common.add_symbol_dce(pm)
         passes.ttir.add_loop_unroll(pm)
         pm.run(mod)
@@ -224,14 +261,8 @@ class CUDABackend(BaseBackend):
         cluster_info.clusterDimX = opt.cluster_dims[0]
         cluster_info.clusterDimY = opt.cluster_dims[1]
         cluster_info.clusterDimZ = opt.cluster_dims[2]
-        # Set up Diagnostic
-        if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1":
-            srcMgr = llvm.source_mgr()
-            diag = ir.source_mgr_diag(srcMgr, mod.context)
-            mod.context.printOpOnDiagnostic(True)
-        # TTIR -> TTGIR
         pm = ir.pass_manager(mod.context)
-        pm.enable_debug()
+        dump_enabled = pm.enable_debug()
         passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
         # optimize TTGIR
         passes.ttgpuir.add_coalesce(pm)
@@ -245,18 +276,29 @@ class CUDABackend(BaseBackend):
         passes.ttgpuir.add_remove_layout_conversions(pm)
         passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
         passes.common.add_cse(pm)
-        if capability // 10
+        if capability // 10 in [8, 9]:
+            passes.ttgpuir.add_fuse_nested_loops(pm)
+            passes.common.add_canonicalizer(pm)
+            passes.common.add_licm(pm)
             passes.ttgpuir.add_optimize_accumulator_init(pm)
+            passes.common.add_canonicalizer(pm)
             passes.ttgpuir.add_combine_tensor_select_and_if(pm)
-            passes.ttgpuir.
-
-            passes.ttgpuir.
-            passes.
-
-            passes.ttgpuir.
-            passes.ttgpuir.
+            passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
+        elif capability // 10 >= 10:
+            passes.ttgpuir.add_fuse_nested_loops(pm)
+            passes.common.add_canonicalizer(pm)
+            passes.common.add_licm(pm)
+            passes.ttgpuir.add_optimize_accumulator_init(pm)
+            passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
+            passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+            nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
+            nvidia.passes.ttnvgpuir.add_keep_acc_in_tmem(pm)
+            passes.common.add_canonicalizer(pm)
+        else:
+            passes.common.add_licm(pm)
         passes.ttgpuir.add_prefetch(pm)
         passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
+        passes.ttgpuir.add_coalesce_async_copy(pm)
         passes.ttgpuir.add_remove_layout_conversions(pm)
         passes.ttgpuir.add_reduce_data_duplication(pm)
         passes.ttgpuir.add_reorder_instructions(pm)
@@ -270,31 +312,26 @@ class CUDABackend(BaseBackend):
         metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
         return mod

-
-
-        ptx_version = get_ptx_version_from_options(options)
+    def make_llir(self, src, metadata, options, capability):
+        ptx_version = get_ptx_version_from_options(options, self.target.arch)

-        # warp-specialization mutates num_warps
-        num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta")
-        if num_warp_groups is not None:
-            metadata["num_warps"] *= num_warp_groups
         mod = src
         # TritonGPU -> LLVM-IR (MLIR)
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
-
-
-        srcMgr = llvm.source_mgr()
-        diag = ir.source_mgr_diag(srcMgr, mod.context)
-        mod.context.printOpOnDiagnostic(True)
-        nvidia.passes.ttgpuir.add_decompose_unsupported_conversions(pm)
+
+        nvidia.passes.ttnvgpuir.add_lower_mma(pm)
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+        passes.ttgpuir.add_allocate_warp_groups(pm)
         passes.convert.add_scf_to_cf(pm)
-        passes.convert.add_index_to_llvmir(pm)
         passes.ttgpuir.add_allocate_shared_memory(pm)
+        nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
+        passes.ttgpuir.add_allocate_global_scratch_memory(pm)
         nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
+        passes.common.add_canonicalizer(pm)
+        passes.common.add_cse(pm)
         nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
-        passes.
+        nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
@@ -304,10 +341,12 @@ class CUDABackend(BaseBackend):
         # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
         llvm.init_targets()
         context = llvm.context()
-
+        if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+            raise RuntimeError(
+                "Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
         llvm_mod = llvm.to_module(mod, context)
-        proc =
-        features = get_features(options)
+        proc = sm_arch_from_capability(capability)
+        features = get_features(options, self.target.arch)
         triple = 'nvptx64-nvidia-cuda'
         llvm.attach_datalayout(llvm_mod, triple, proc, features)
         nvidia.set_nvvm_reflect_ftz(llvm_mod)
@@ -325,19 +364,25 @@ class CUDABackend(BaseBackend):
         llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)

         # Get some metadata
-
+        # warp-specialization mutates num_warps
+        total_num_warps = src.get_int_attr("ttg.total-num-warps")
+        if total_num_warps is not None:
+            metadata["num_warps"] = total_num_warps
+        metadata["shared"] = src.get_int_attr("ttg.shared")
+        metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
+        metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
+        metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
         ret = str(llvm_mod)
         del llvm_mod
         del context
         return ret

-
-
-        ptx_version = get_ptx_version_from_options(opt)
+    def make_ptx(self, src, metadata, opt, capability):
+        ptx_version = get_ptx_version_from_options(opt, self.target.arch)

         triple = 'nvptx64-nvidia-cuda'
-        proc =
-        features = get_features(opt)
+        proc = sm_arch_from_capability(capability)
+        features = get_features(opt, self.target.arch)
         ret = llvm.translate_to_asm(src, triple, proc, features, ['nvptx-short-ptr'], opt.enable_fp_fusion, False)
         # Find kernel names (there should only be one)
         names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
@@ -346,6 +391,7 @@ class CUDABackend(BaseBackend):
         # post-process
         ptx_version = f'{ptx_version//10}.{ptx_version%10}'
         ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
+        ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
         # Remove the debug flag that prevents ptxas from optimizing the code
         ret = re.sub(r",\s*debug|debug,\s*", "", ret)
         if os.environ.get("NVPTX_ENABLE_DUMP", "0") == "1":
@@ -353,9 +399,8 @@ class CUDABackend(BaseBackend):
             print(ret)
         return ret

-
-
-        ptxas, _ = _path_to_binary("ptxas")
+    def make_cubin(self, src, metadata, opt, capability):
+        ptxas, _ = get_ptxas(self.target.arch)
         # On Windows, we need to set delete=False, close the temp file before reading it, and manually remove it
         with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
             tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
@@ -363,23 +408,19 @@ class CUDABackend(BaseBackend):
             fsrc.close()
             fbin = fsrc.name + '.o'

-            line_info = [] if os.environ.get(
+            line_info = ["-lineinfo", "-suppress-debug-info"] if os.environ.get("TRITON_DISABLE_LINE_INFO",
+                                                                                "0") == "1" else ["-lineinfo"]
             fmad = [] if opt.enable_fp_fusion else ['--fmad=false']
-
+            arch = sm_arch_from_capability(capability)
             opt_level = ['--opt-level', '0'] if os.environ.get("DISABLE_PTXAS_OPT", "0") == "1" else []
-            ptxas_cmd = [
-                ptxas, *line_info, *fmad, '-v', *opt_level, f'--gpu-name=sm_{capability}{suffix}', fsrc.name, '-o', fbin
-            ]
+            ptxas_cmd = [ptxas, *line_info, *fmad, '-v', *opt_level, f'--gpu-name={arch}', fsrc.name, '-o', fbin]
             try:
                 subprocess.run(ptxas_cmd, check=True, close_fds=False, stdout=flog, stderr=flog)
-                try_remove(fsrc.name)
                 flog.close()
-                try_remove(flog.name)
             except subprocess.CalledProcessError as e:
                 flog.close()
                 with open(flog.name) as log_file:
                     log = log_file.read()
-                try_remove(flog.name)

                 if e.returncode == 255:
                     error = 'Internal Triton PTX codegen error'
@@ -388,9 +429,12 @@ class CUDABackend(BaseBackend):
                 else:
                     error = f'`ptxas` failed with error code {e.returncode}'

-                raise
-
-
+                raise PTXASError(f"{error}\n"
+                                 f"`ptxas` stderr:\n{log}\n"
+                                 f'Repro command: {" ".join(ptxas_cmd)}\n')
+            finally:
+                try_remove(fsrc.name)
+                try_remove(flog.name)

             with open(fbin, 'rb') as f:
                 cubin = f.read()
@@ -398,13 +442,14 @@ class CUDABackend(BaseBackend):
             return cubin

     def add_stages(self, stages, options):
+        capability = self._parse_arch(options.arch)
         stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
-        stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options,
-        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options,
-        stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.
-        stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.
+        stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
+        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
+        stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.target.arch)
+        stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch)

     @functools.lru_cache()
     def hash(self):
-        version = get_ptxas_version()
-        return f'{version}-{self.
+        version = get_ptxas_version(self.target.arch)
+        return f'{version}-{self.target.arch}'

triton/backends/nvidia/driver.c CHANGED