triton-windows 3.2.0.post11__cp312-cp312-win_amd64.whl → 3.3.0a0.post11__cp312-cp312-win_amd64.whl

This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.

Potentially problematic release.

Files changed (68)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +3 -3
  3. triton/_internal_testing.py +59 -4
  4. triton/_utils.py +35 -0
  5. triton/backends/amd/compiler.py +121 -74
  6. triton/backends/amd/driver.py +77 -43
  7. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
  8. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
  13. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
  15. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
  16. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
  17. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
  18. triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
  19. triton/backends/amd/include/hip/hip_ext.h +4 -2
  20. triton/backends/amd/include/hip/hip_fp8.h +33 -0
  21. triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
  22. triton/backends/amd/include/hip/hip_version.h +3 -3
  23. triton/backends/amd/include/hip/hiprtc.h +25 -25
  24. triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
  25. triton/backends/amd/include/hsa/hsa.h +11 -2
  26. triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
  27. triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
  28. triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
  29. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
  30. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
  31. triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
  32. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
  33. triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
  34. triton/backends/amd/lib/asanrtl.bc +0 -0
  35. triton/backends/compiler.py +25 -225
  36. triton/backends/driver.py +7 -2
  37. triton/backends/nvidia/bin/ptxas.exe +0 -0
  38. triton/backends/nvidia/compiler.py +135 -90
  39. triton/backends/nvidia/driver.c +0 -1
  40. triton/backends/nvidia/driver.py +135 -49
  41. triton/backends/nvidia/include/cuda.h +2162 -241
  42. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  43. triton/compiler/__init__.py +2 -2
  44. triton/compiler/code_generator.py +334 -231
  45. triton/compiler/compiler.py +77 -66
  46. triton/language/__init__.py +22 -5
  47. triton/language/core.py +448 -74
  48. triton/language/extra/cuda/_experimental_tma.py +3 -5
  49. triton/language/math.py +1 -1
  50. triton/language/random.py +2 -1
  51. triton/language/semantic.py +206 -52
  52. triton/language/standard.py +35 -18
  53. triton/runtime/_allocation.py +32 -0
  54. triton/runtime/autotuner.py +27 -32
  55. triton/runtime/build.py +1 -48
  56. triton/runtime/cache.py +6 -6
  57. triton/runtime/errors.py +10 -0
  58. triton/runtime/interpreter.py +179 -45
  59. triton/runtime/jit.py +149 -190
  60. triton/testing.py +39 -11
  61. triton/tools/compile.py +27 -20
  62. triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
  63. triton/tools/mxfp.py +301 -0
  64. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/METADATA +5 -2
  65. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/RECORD +68 -59
  66. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/top_level.txt +2 -0
  67. triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
  68. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/WHEEL +0 -0
triton/backends/nvidia/compiler.py
@@ -1,5 +1,6 @@
 from triton.backends.compiler import BaseBackend, GPUTarget
 from triton._C.libtriton import ir, passes, llvm, nvidia
+from triton.runtime.errors import PTXASError
 
 from dataclasses import dataclass
 import functools
@@ -12,16 +13,26 @@ import signal
 import os
 import subprocess
 from pathlib import Path
+import sysconfig
 
 
 def min_dot_size(target: GPUTarget):
-    return lambda lhsType, rhsType: (16, 32, 16) if lhsType.is_int8() else (16, 16, 16)
+
+    def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]:  # [m, n, k]
+        lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
+        rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
+        assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
+        if lhs_bitwidth == 8:
+            return (16, 16, 32)
+        else:
+            return (16, 16, 16)
+
+    return check_dot_compatibility
 
 
 @functools.lru_cache()
 def _path_to_binary(binary: str):
-    if os.name == "nt":
-        binary += ".exe"
+    binary += sysconfig.get_config_var("EXE")
     paths = [
         os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
         os.path.join(os.path.dirname(__file__), "bin", binary),
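For reference, the rewritten `min_dot_size` above now derives the minimum (m, n, k) dot shape from operand bitwidth rather than an `is_int8()` check. A self-contained sketch of the same logic, outside the diff, with hypothetical stand-in operand types (only the `scalar.primitive_bitwidth` attribute matters):

# Mirrors check_dot_compatibility from the hunk above; _Scalar/_Operand are
# hypothetical stand-ins exposing only the attribute the checker reads.
class _Scalar:
    def __init__(self, bits):
        self.primitive_bitwidth = bits

class _Operand:
    def __init__(self, bits):
        self.scalar = _Scalar(bits)

def check_dot_compatibility(lhs_type, rhs_type):  # returns minimum (m, n, k)
    lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
    rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
    assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
    return (16, 16, 32) if lhs_bitwidth == 8 else (16, 16, 16)

assert check_dot_compatibility(_Operand(8), _Operand(8)) == (16, 16, 32)    # 8-bit operands
assert check_dot_compatibility(_Operand(16), _Operand(16)) == (16, 16, 16)  # 16-bit and wider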
@@ -32,19 +43,31 @@ def _path_to_binary(binary: str):
     if cuda_bin_path:
         paths += [os.path.join(cuda_bin_path, binary)]
 
-    for bin in paths:
-        if os.path.exists(bin) and os.path.isfile(bin):
-            result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
+    for path in paths:
+        if os.path.exists(path) and os.path.isfile(path):
+            result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
             if result is not None:
                 version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
                 if version is not None:
-                    return bin, version.group(1)
+                    return path, version.group(1)
     raise RuntimeError(f"Cannot find {binary}")
 
 
 @functools.lru_cache()
-def get_ptxas_version():
-    version = subprocess.check_output([_path_to_binary("ptxas")[0], "--version"]).decode("utf-8")
+def get_ptxas(arch: int):
+    if os.name == "nt":
+        name = "ptxas"
+    else:
+        name = "ptxas-blackwell" if arch >= 100 else "ptxas"
+    return _path_to_binary(name)
+
+
+@functools.lru_cache()
+def get_ptxas_version(arch: int):
+    mock_ver = os.environ.get('TRITON_MOCK_PTX_VERSION')
+    if mock_ver is not None:
+        return mock_ver  # This is not really a version of ptxas, but it is good enough for testing
+    version = subprocess.check_output([get_ptxas(arch)[0], "--version"]).decode("utf-8")
     return version
 
 
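The new `get_ptxas` / `get_ptxas_version` split chooses the ptxas binary per target architecture, and `TRITON_MOCK_PTX_VERSION` can short-circuit the version probe in tests. A minimal sketch restating only the selection branch (the helper name is illustrative, not part of the wheel):

import os

def ptxas_binary_name(arch: int) -> str:
    # Restates get_ptxas(): the Windows wheel ships a single ptxas.exe, while
    # elsewhere Blackwell-class targets (arch >= 100, i.e. sm_100+) use ptxas-blackwell.
    if os.name == "nt":
        return "ptxas"
    return "ptxas-blackwell" if arch >= 100 else "ptxas"

print(ptxas_binary_name(90))   # "ptxas" on every platform
print(ptxas_binary_name(100))  # "ptxas-blackwell" on Linux, "ptxas" on Windows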
 
@@ -59,7 +82,7 @@ def ptx_get_version(cuda_version) -> int:
         if minor < 6:
             return 80 + minor
         else:
-            return 79 + minor
+            return 80 + minor - 1
     if major == 11:
         return 70 + minor
     if major == 10:
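The rewritten branch is numerically identical to the old `79 + minor`; only the expression changed. Worked examples of the CUDA-to-PTX-ISA mapping implemented by `ptx_get_version`, as I read the surrounding function:

# major == 12, minor < 6:  80 + minor
assert 80 + 4 == 84      # CUDA 12.4 -> PTX ISA 8.4
# major == 12, minor >= 6: 80 + minor - 1 (same value the old 79 + minor produced)
assert 80 + 6 - 1 == 85  # CUDA 12.6 -> PTX ISA 8.5
# major == 11:             70 + minor
assert 70 + 8 == 78      # CUDA 11.8 -> PTX ISA 7.8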
@@ -67,24 +90,24 @@ def ptx_get_version(cuda_version) -> int:
     raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version)
 
 
-def get_ptx_version_from_options(options):
+def get_ptx_version_from_options(options, arch: int):
     ptx_version = options.ptx_version
     if ptx_version is None:
-        _, cuda_version = _path_to_binary("ptxas")
+        _, cuda_version = get_ptxas(arch)
         ptx_version = ptx_get_version(cuda_version)
     return ptx_version
 
 
 @functools.lru_cache()
-def get_features(options):
-    ptx_version = get_ptx_version_from_options(options)
+def get_features(options, arch: int):
+    ptx_version = get_ptx_version_from_options(options, arch)
 
-    # PTX 8.3 is the max version supported by llvm 3a83162168.
+    # PTX 8.6 is the max version supported by llvm c1188642.
     #
     # To check if a newer PTX version is supported, increase this value
     # and run a test. If it's not supported, LLVM will print a warning
     # like "+ptx8.4 is not a recognized feature for this target".
-    llvm_ptx_version = min(83, ptx_version)
+    llvm_ptx_version = min(86, ptx_version)
     features = f'+ptx{llvm_ptx_version}'
     return features
 
@@ -95,6 +118,12 @@ def file_hash(path):
         return hashlib.sha256(f.read()).hexdigest()
 
 
+def sm_arch_from_capability(capability: int):
+    # TODO: Handle non-"a" sms
+    suffix = "a" if capability >= 90 else ""
+    return f"sm_{capability}{suffix}"
+
+
 # The file may be accessed in parallel
 def try_remove(path):
     if os.path.exists(path):
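The new `sm_arch_from_capability` helper centralizes the target-name string that `make_llir`, `make_ptx`, and `make_cubin` previously built inline. A standalone copy of the helper from the hunk above, with its expected outputs, so it can be checked in isolation:

def sm_arch_from_capability(capability: int):
    # Same body as the hunk above: capabilities >= 90 get the "a" feature suffix.
    suffix = "a" if capability >= 90 else ""
    return f"sm_{capability}{suffix}"

assert sm_arch_from_capability(80) == "sm_80"
assert sm_arch_from_capability(90) == "sm_90a"
assert sm_arch_from_capability(100) == "sm_100a"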
@@ -110,16 +139,13 @@ class CUDAOptions:
     num_warps: int = 4
     num_ctas: int = 1
     num_stages: int = 3
-    num_buffers_warp_spec: int = 0
-    num_consumer_groups: int = 0
-    reg_dec_producer: int = 0
-    reg_inc_consumer: int = 0
     # maxnreg corresponds to the ptx parameter .maxnreg, which controls the
     # maximum number of 32-bit registers used by one thread.
     maxnreg: Optional[int] = None
     cluster_dims: tuple = (1, 1, 1)
     ptx_version: int = None
     enable_fp_fusion: bool = True
+    launch_cooperative_grid: bool = False
     supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15")
     deprecated_fp8_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "tf32"
@@ -129,6 +155,7 @@ class CUDAOptions:
     debug: bool = False
     backend_name: str = 'cuda'
     sanitize_overflow: bool = True
+    arch: str = None
 
     def __post_init__(self):
         default_libdir = Path(__file__).parent / 'lib'
@@ -152,27 +179,37 @@ class CUDABackend(BaseBackend):
     def supports_target(target: GPUTarget):
         return target.backend == 'cuda'
 
+    def _parse_arch(self, arch):
+        pattern = r"^sm(\d+)$"
+        match = re.fullmatch(pattern, arch)
+        if not match:
+            raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}")
+        return int(match.group(1))
+
     def __init__(self, target: GPUTarget) -> None:
         super().__init__(target)
-        self.capability = target.arch
-        assert isinstance(self.capability, int)
         self.binary_ext = "cubin"
 
     def parse_options(self, opts) -> Any:
-        args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts}
+        args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", f"sm{self.target.arch}")}
+        args.update({k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None})
+        capability = int(self._parse_arch(args["arch"]))
+
         if "supported_fp8_dtypes" not in args:
             supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
-            if self.capability >= 89:
+            if capability >= 89:
                 supported_fp8_dtypes.add("fp8e4nv")
             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
 
         if "deprecated_fp8_dtypes" not in args:
-            if self.capability >= 90:
+            if capability >= 90:
                 args["deprecated_fp8_dtypes"] = ("fp8e4b15", )
 
         if "enable_fp_fusion" not in args:
             args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
-        args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0
+
+        args["max_num_imprecise_acc_default"] = 2**30 if capability == 90 else 0
+
         return CUDAOptions(**args)
 
     def pack_metadata(self, metadata):
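With the new `arch` option, the compute capability is now parsed from an `sm<N>` string, which the `TRITON_OVERRIDE_ARCH` environment variable can override before `parse_options` builds `CUDAOptions`. A standalone sketch of just the parsing step (mirroring `_parse_arch`; the function name here is illustrative):

import os
import re

def parse_arch(arch: str) -> int:
    # Same pattern as CUDABackend._parse_arch: "sm90" -> 90, "sm120" -> 120.
    match = re.fullmatch(r"^sm(\d+)$", arch)
    if not match:
        raise ValueError(r"TRITON_OVERRIDE_ARCH must have the form ^sm(\d+)$")
    return int(match.group(1))

# "sm90" stands in for the default f"sm{target.arch}" built in parse_options.
capability = parse_arch(os.getenv("TRITON_OVERRIDE_ARCH", "sm90"))
print(capability)  # 90 unless the environment variable overrides it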
@@ -185,12 +222,13 @@ class CUDABackend(BaseBackend):
             metadata.cluster_dims[2],
         )
 
-    def get_codegen_implementation(self):
+    def get_codegen_implementation(self, options):
         import triton.language.extra.cuda as cuda
+        capability = int(self._parse_arch(options.arch))
         codegen_fns = {
             "convert_custom_types":
-            cuda.convert_custom_float8_sm80 if self.capability >= 80 else cuda.convert_custom_float8_sm70,
-            "min_dot_size": min_dot_size(self.target)
+            cuda.convert_custom_float8_sm80 if capability >= 80 else cuda.convert_custom_float8_sm70, "min_dot_size":
+            min_dot_size(self.target)
         }
         return codegen_fns
 
@@ -207,11 +245,10 @@ class CUDABackend(BaseBackend):
         pm.enable_debug()
         passes.common.add_inliner(pm)
         passes.ttir.add_rewrite_tensor_pointer(pm)
-        passes.ttir.add_combine(pm)
         passes.common.add_canonicalizer(pm)
+        passes.ttir.add_combine(pm)
         passes.ttir.add_reorder_broadcast(pm)
         passes.common.add_cse(pm)
-        passes.common.add_licm(pm)
         passes.common.add_symbol_dce(pm)
         passes.ttir.add_loop_unroll(pm)
         pm.run(mod)
@@ -224,14 +261,8 @@ class CUDABackend(BaseBackend):
         cluster_info.clusterDimX = opt.cluster_dims[0]
         cluster_info.clusterDimY = opt.cluster_dims[1]
         cluster_info.clusterDimZ = opt.cluster_dims[2]
-        # Set up Diagnostic
-        if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1":
-            srcMgr = llvm.source_mgr()
-            diag = ir.source_mgr_diag(srcMgr, mod.context)
-            mod.context.printOpOnDiagnostic(True)
-        # TTIR -> TTGIR
         pm = ir.pass_manager(mod.context)
-        pm.enable_debug()
+        dump_enabled = pm.enable_debug()
         passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
         # optimize TTGIR
         passes.ttgpuir.add_coalesce(pm)
@@ -245,18 +276,29 @@ class CUDABackend(BaseBackend):
         passes.ttgpuir.add_remove_layout_conversions(pm)
         passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
         passes.common.add_cse(pm)
-        if capability // 10 >= 8:
+        if capability // 10 in [8, 9]:
+            passes.ttgpuir.add_fuse_nested_loops(pm)
+            passes.common.add_canonicalizer(pm)
+            passes.common.add_licm(pm)
             passes.ttgpuir.add_optimize_accumulator_init(pm)
+            passes.common.add_canonicalizer(pm)
             passes.ttgpuir.add_combine_tensor_select_and_if(pm)
-            passes.ttgpuir.add_ws_task_partition(pm, opt.num_consumer_groups)
-            passes.ttgpuir.add_taskid_propagate(pm, opt.num_consumer_groups)
-            passes.ttgpuir.add_ws_data_partition(pm, opt.num_consumer_groups)
-            passes.ttgpuir.add_ws_code_partition(pm, opt.num_buffers_warp_spec, opt.num_consumer_groups,
-                                                 opt.reg_dec_producer, opt.reg_inc_consumer)
-            passes.ttgpuir.add_pipeline(pm, opt.num_stages)
-            passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups)
+            passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
+        elif capability // 10 >= 10:
+            passes.ttgpuir.add_fuse_nested_loops(pm)
+            passes.common.add_canonicalizer(pm)
+            passes.common.add_licm(pm)
+            passes.ttgpuir.add_optimize_accumulator_init(pm)
+            passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
+            passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+            nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
+            nvidia.passes.ttnvgpuir.add_keep_acc_in_tmem(pm)
+            passes.common.add_canonicalizer(pm)
+        else:
+            passes.common.add_licm(pm)
         passes.ttgpuir.add_prefetch(pm)
         passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
+        passes.ttgpuir.add_coalesce_async_copy(pm)
         passes.ttgpuir.add_remove_layout_conversions(pm)
         passes.ttgpuir.add_reduce_data_duplication(pm)
         passes.ttgpuir.add_reorder_instructions(pm)
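The TTGIR pipeline is now selected by capability band: the old warp-specialization passes are gone, sm_80/sm_90 targets get software pipelining, and sm_100+ targets additionally get the tensor-memory passes. A tiny sketch restating just the branch predicate (the function and return strings are illustrative, not part of the wheel):

def pipeline_branch(capability: int) -> str:
    # Restates the branch condition in make_ttgir above.
    if capability // 10 in [8, 9]:
        return "sm80/sm90: fuse loops + software pipelining"
    elif capability // 10 >= 10:
        return "sm100+: pipelining plus tensor-memory (TMEM) passes"
    return "pre-sm80: LICM only"

assert pipeline_branch(86) == "sm80/sm90: fuse loops + software pipelining"
assert pipeline_branch(120) == "sm100+: pipelining plus tensor-memory (TMEM) passes"
assert pipeline_branch(75) == "pre-sm80: LICM only"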
@@ -270,31 +312,26 @@ class CUDABackend(BaseBackend):
         metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
         return mod
 
-    @staticmethod
-    def make_llir(src, metadata, options, capability):
-        ptx_version = get_ptx_version_from_options(options)
+    def make_llir(self, src, metadata, options, capability):
+        ptx_version = get_ptx_version_from_options(options, self.target.arch)
 
-        # warp-specialization mutates num_warps
-        num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta")
-        if num_warp_groups is not None:
-            metadata["num_warps"] *= num_warp_groups
         mod = src
         # TritonGPU -> LLVM-IR (MLIR)
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
-        # Set up Diagnostic
-        if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1":
-            srcMgr = llvm.source_mgr()
-            diag = ir.source_mgr_diag(srcMgr, mod.context)
-            mod.context.printOpOnDiagnostic(True)
-        nvidia.passes.ttgpuir.add_decompose_unsupported_conversions(pm)
+
+        nvidia.passes.ttnvgpuir.add_lower_mma(pm)
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+        passes.ttgpuir.add_allocate_warp_groups(pm)
         passes.convert.add_scf_to_cf(pm)
-        passes.convert.add_index_to_llvmir(pm)
         passes.ttgpuir.add_allocate_shared_memory(pm)
+        nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
+        passes.ttgpuir.add_allocate_global_scratch_memory(pm)
         nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
+        passes.common.add_canonicalizer(pm)
+        passes.common.add_cse(pm)
         nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
-        passes.convert.add_arith_to_llvmir(pm)
+        nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
@@ -304,10 +341,12 @@ class CUDABackend(BaseBackend):
         # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
         llvm.init_targets()
         context = llvm.context()
-
+        if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+            raise RuntimeError(
+                "Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
         llvm_mod = llvm.to_module(mod, context)
-        proc = 'sm_90a' if capability == 90 else f'sm_{capability}'
-        features = get_features(options)
+        proc = sm_arch_from_capability(capability)
+        features = get_features(options, self.target.arch)
         triple = 'nvptx64-nvidia-cuda'
         llvm.attach_datalayout(llvm_mod, triple, proc, features)
         nvidia.set_nvvm_reflect_ftz(llvm_mod)
@@ -325,19 +364,25 @@ class CUDABackend(BaseBackend):
         llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)
 
         # Get some metadata
-        metadata["shared"] = src.get_int_attr("triton_gpu.shared")
+        # warp-specialization mutates num_warps
+        total_num_warps = src.get_int_attr("ttg.total-num-warps")
+        if total_num_warps is not None:
+            metadata["num_warps"] = total_num_warps
+        metadata["shared"] = src.get_int_attr("ttg.shared")
+        metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
+        metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
+        metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
         ret = str(llvm_mod)
         del llvm_mod
         del context
         return ret
 
-    @staticmethod
-    def make_ptx(src, metadata, opt, capability):
-        ptx_version = get_ptx_version_from_options(opt)
+    def make_ptx(self, src, metadata, opt, capability):
+        ptx_version = get_ptx_version_from_options(opt, self.target.arch)
 
         triple = 'nvptx64-nvidia-cuda'
-        proc = 'sm_90a' if capability == 90 else f'sm_{capability}'
-        features = get_features(opt)
+        proc = sm_arch_from_capability(capability)
+        features = get_features(opt, self.target.arch)
         ret = llvm.translate_to_asm(src, triple, proc, features, ['nvptx-short-ptr'], opt.enable_fp_fusion, False)
         # Find kernel names (there should only be one)
         names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
@@ -346,6 +391,7 @@ class CUDABackend(BaseBackend):
         # post-process
         ptx_version = f'{ptx_version//10}.{ptx_version%10}'
         ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
+        ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
         # Remove the debug flag that prevents ptxas from optimizing the code
         ret = re.sub(r",\s*debug|debug,\s*", "", ret)
         if os.environ.get("NVPTX_ENABLE_DUMP", "0") == "1":
  if os.environ.get("NVPTX_ENABLE_DUMP", "0") == "1":
@@ -353,9 +399,8 @@ class CUDABackend(BaseBackend):
353
399
  print(ret)
354
400
  return ret
355
401
 
356
- @staticmethod
357
- def make_cubin(src, metadata, opt, capability):
358
- ptxas, _ = _path_to_binary("ptxas")
402
+ def make_cubin(self, src, metadata, opt, capability):
403
+ ptxas, _ = get_ptxas(self.target.arch)
359
404
  # On Windows, we need to set delete=False, close the temp file before reading it, and manually remove it
360
405
  with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
361
406
  tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
@@ -363,23 +408,19 @@ class CUDABackend(BaseBackend):
             fsrc.close()
             fbin = fsrc.name + '.o'
 
-            line_info = [] if os.environ.get('TRITON_DISABLE_LINE_INFO') else ['-lineinfo']
+            line_info = ["-lineinfo", "-suppress-debug-info"] if os.environ.get("TRITON_DISABLE_LINE_INFO",
+                                                                                "0") == "1" else ["-lineinfo"]
             fmad = [] if opt.enable_fp_fusion else ['--fmad=false']
-            suffix = 'a' if capability == 90 else ''
+            arch = sm_arch_from_capability(capability)
             opt_level = ['--opt-level', '0'] if os.environ.get("DISABLE_PTXAS_OPT", "0") == "1" else []
-            ptxas_cmd = [
-                ptxas, *line_info, *fmad, '-v', *opt_level, f'--gpu-name=sm_{capability}{suffix}', fsrc.name, '-o', fbin
-            ]
+            ptxas_cmd = [ptxas, *line_info, *fmad, '-v', *opt_level, f'--gpu-name={arch}', fsrc.name, '-o', fbin]
             try:
                 subprocess.run(ptxas_cmd, check=True, close_fds=False, stdout=flog, stderr=flog)
-                try_remove(fsrc.name)
                 flog.close()
-                try_remove(flog.name)
             except subprocess.CalledProcessError as e:
                 flog.close()
                 with open(flog.name) as log_file:
                     log = log_file.read()
-                try_remove(flog.name)
 
                 if e.returncode == 255:
                     error = 'Internal Triton PTX codegen error'
@@ -388,9 +429,12 @@ class CUDABackend(BaseBackend):
                 else:
                     error = f'`ptxas` failed with error code {e.returncode}'
 
-                raise RuntimeError(f'{error}\n'
-                                   f'`ptxas` stderr:\n{log}\n'
-                                   f'Repro command: {" ".join(ptxas_cmd)}\n')
+                raise PTXASError(f"{error}\n"
+                                 f"`ptxas` stderr:\n{log}\n"
+                                 f'Repro command: {" ".join(ptxas_cmd)}\n')
+            finally:
+                try_remove(fsrc.name)
+                try_remove(flog.name)
 
             with open(fbin, 'rb') as f:
                 cubin = f.read()
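Since ptxas failures now raise `triton.runtime.errors.PTXASError` (imported at the top of this file and added in triton/runtime/errors.py per this diff) instead of a bare `RuntimeError`, and the temp files are cleaned up in a `finally` block, callers can target the new exception type specifically. A hedged sketch, assuming triton-windows 3.3.0a0 is installed; `compile_fn` stands for any callable that triggers Triton compilation:

from triton.runtime.errors import PTXASError  # new exception type per this diff

def compile_or_report(compile_fn):
    # Illustrative wrapper: surface ptxas failures without masking other errors.
    try:
        return compile_fn()
    except PTXASError as err:
        print(f"ptxas failed:\n{err}")
        return None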
@@ -398,13 +442,14 @@ class CUDABackend(BaseBackend):
         return cubin
 
     def add_stages(self, stages, options):
+        capability = self._parse_arch(options.arch)
         stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
-        stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
-        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
-        stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.capability)
-        stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability)
+        stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
+        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
+        stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.target.arch)
+        stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch)
 
     @functools.lru_cache()
     def hash(self):
-        version = get_ptxas_version()
-        return f'{version}-{self.capability}'
+        version = get_ptxas_version(self.target.arch)
+        return f'{version}-{self.target.arch}'
triton/backends/nvidia/driver.c
@@ -10,7 +10,6 @@
 #include <stdbool.h>
 #define PY_SSIZE_T_CLEAN
 #include <Python.h>
-// #include <stdatomic.h>
 
 // Raises a Python exception and returns false if code is not CUDA_SUCCESS.
 static bool gpuAssert(CUresult code, const char *file, int line) {