triton-windows 3.3.1.post21__cp311-cp311-win_amd64.whl → 3.4.0.post21__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: the registry flags this version of triton-windows as potentially problematic.

Files changed (68)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +143 -46
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +94 -94
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +296 -125
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +73 -9
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/testing.py +16 -12
  56. triton/tools/disasm.py +3 -4
  57. triton/tools/tensor_descriptor.py +36 -0
  58. triton/windows_utils.py +47 -83
  59. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/METADATA +7 -2
  60. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/RECORD +64 -41
  61. triton_windows-3.4.0.post21.dist-info/entry_points.txt +3 -0
  62. triton_windows-3.4.0.post21.dist-info/licenses/LICENSE +23 -0
  63. triton_windows-3.4.0.post21.dist-info/top_level.txt +1 -0
  64. triton/language/_utils.py +0 -21
  65. triton/language/extra/cuda/_experimental_tma.py +0 -106
  66. triton/tools/experimental_descriptor.py +0 -32
  67. triton_windows-3.3.1.post21.dist-info/top_level.txt +0 -14
  68. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/WHEEL +0 -0
triton/backends/nvidia/compiler.py

@@ -1,5 +1,6 @@
- from triton.backends.compiler import BaseBackend, GPUTarget
+ from triton.backends.compiler import BaseBackend, GPUTarget, Language
  from triton._C.libtriton import ir, passes, llvm, nvidia
+ from triton import knobs
  from triton.runtime.errors import PTXASError
 
  from dataclasses import dataclass
@@ -13,7 +14,6 @@ import signal
  import os
  import subprocess
  from pathlib import Path
- import sysconfig
 
 
  def min_dot_size(target: GPUTarget):
@@ -30,46 +30,16 @@ def min_dot_size(target: GPUTarget):
      return check_dot_compatibility
 
 
- @functools.lru_cache()
- def _path_to_binary(binary: str):
-     paths = [
-         os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
-     ]
-     binary += sysconfig.get_config_var("EXE")
-     paths += [
-         os.path.join(os.path.dirname(__file__), "bin", binary),
-     ]
-     if os.name == "nt":
-         from triton.windows_utils import find_cuda
-         cuda_bin_path, _, _ = find_cuda()
-         if cuda_bin_path:
-             paths += [os.path.join(cuda_bin_path, binary)]
-
-     for path in paths:
-         if os.path.exists(path) and os.path.isfile(path):
-             result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
-             if result is not None:
-                 version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
-                 if version is not None:
-                     return path, version.group(1)
-     raise RuntimeError(f"Cannot find {binary}")
+ def get_ptxas() -> knobs.NvidiaTool:
+     return knobs.nvidia.ptxas
 
 
  @functools.lru_cache()
- def get_ptxas(arch: int):
-     if os.name == "nt":
-         name = "ptxas"
-     else:
-         name = "ptxas-blackwell" if arch >= 100 else "ptxas"
-     return _path_to_binary(name)
-
-
- @functools.lru_cache()
- def get_ptxas_version(arch: int):
-     mock_ver = os.environ.get('TRITON_MOCK_PTX_VERSION')
+ def get_ptxas_version():
+     mock_ver = knobs.nvidia.mock_ptx_version
      if mock_ver is not None:
          return mock_ver  # This is not really a version of ptxas, but it is good enough for testing
-     version = subprocess.check_output([get_ptxas(arch)[0], "--version"]).decode("utf-8")
+     version = subprocess.check_output([get_ptxas().path, "--version"]).decode("utf-8")
      return version
 
 
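This hunk sets the theme for the rest of the backend changes in 3.4.0: ptxas discovery moves out of the backend and into the new centralized triton.knobs module, so get_ptxas() now returns a knobs.NvidiaTool instead of a (path, version) tuple, and the Blackwell-specific ptxas-blackwell lookup disappears. A minimal usage sketch based only on the attributes visible in this diff (.path and .version); the printed values are illustrative:

    from triton.backends.nvidia.compiler import get_ptxas, get_ptxas_version

    tool = get_ptxas()            # a knobs.NvidiaTool resolved by triton.knobs
    print(tool.path)              # e.g. ".../triton/backends/nvidia/bin/ptxas.exe"
    print(tool.version)           # CUDA release string such as "12.8", consumed by ptx_get_version()
    print(get_ptxas_version())    # raw `ptxas --version` output (or knobs.nvidia.mock_ptx_version)
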
@@ -95,7 +65,7 @@ def ptx_get_version(cuda_version) -> int:
  def get_ptx_version_from_options(options, arch: int):
      ptx_version = options.ptx_version
      if ptx_version is None:
-         _, cuda_version = get_ptxas(arch)
+         cuda_version = get_ptxas().version
          ptx_version = ptx_get_version(cuda_version)
      return ptx_version
 
@@ -141,19 +111,18 @@ class CUDAOptions:
      num_warps: int = 4
      num_ctas: int = 1
      num_stages: int = 3
-     num_buffers_warp_spec: int = 0
-     num_consumer_groups: int = 0
-     reg_dec_producer: int = 0
-     reg_inc_consumer: int = 0
      # maxnreg corresponds to the ptx parameter .maxnreg, which controls the
      # maximum number of 32-bit registers used by one thread.
      maxnreg: Optional[int] = None
      cluster_dims: tuple = (1, 1, 1)
      ptx_version: int = None
+     ptx_options: str = None
+     ir_override: Optional[str] = None  # filename of a user-defined IR (*.{ttir|ttgir|llir|ptx})
      enable_fp_fusion: bool = True
      launch_cooperative_grid: bool = False
+     launch_pdl: bool = False
      supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e4b15")
-     deprecated_fp8_dtypes: Tuple[str] = ()
+     deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
      default_dot_input_precision: str = "tf32"
      allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
      max_num_imprecise_acc_default: bool = None
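The warp-specialization tuning fields (num_buffers_warp_spec, num_consumer_groups, reg_dec_producer, reg_inc_consumer) are gone, and three options are new: ptx_options (extra flags forwarded to ptxas, see make_cubin below), ir_override, and launch_pdl. A hedged sketch of passing the new option through a kernel launch; it assumes the JIT forwards backend options by keyword the same way it forwards num_warps and num_stages, which this diff does not show:

    import triton
    import triton.language as tl

    @triton.jit
    def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

    # Hypothetical launch: the extra string would end up in CUDAOptions.ptx_options.
    # copy_kernel[(grid,)](src, dst, n, BLOCK=1024, num_warps=4,
    #                      ptx_options="--warn-on-spills")
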
@@ -167,7 +136,8 @@ class CUDAOptions:
          default_libdir = Path(__file__).parent / 'lib'
          extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
          if not extern_libs.get('libdevice', None):
-             extern_libs['libdevice'] = os.getenv("TRITON_LIBDEVICE_PATH", str(default_libdir / 'libdevice.10.bc'))
+             extern_libs['libdevice'] = knobs.nvidia.libdevice_path or str(default_libdir / 'libdevice.10.bc')
+
          object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
          assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
              "num_warps must be a power of 2"
@@ -192,12 +162,16 @@ class CUDABackend(BaseBackend):
              raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}")
          return int(match.group(1))
 
+     def get_target_name(self, options) -> str:
+         capability = self._parse_arch(options.arch)
+         return f"cuda:{capability}"
+
      def __init__(self, target: GPUTarget) -> None:
          super().__init__(target)
          self.binary_ext = "cubin"
 
      def parse_options(self, opts) -> Any:
-         args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", f"sm{self.target.arch}")}
+         args = {'arch': knobs.runtime.override_arch or f"sm{self.target.arch}"}
          args.update({k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None})
          capability = int(self._parse_arch(args["arch"]))
 
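Two small behavioural notes here, sketched below under the assumption that knobs.runtime.override_arch is still backed by the TRITON_OVERRIDE_ARCH environment variable (the error message above still names it):

    # Before: arch = os.getenv("TRITON_OVERRIDE_ARCH", f"sm{self.target.arch}")
    # After:  arch = knobs.runtime.override_arch or f"sm{self.target.arch}"
    # The new get_target_name() simply reports the parsed capability, e.g. for an sm_90 target:
    capability = 90
    print(f"cuda:{capability}")   # -> "cuda:90"
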
@@ -205,12 +179,12 @@ class CUDABackend(BaseBackend):
              supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
              args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
 
-         if "deprecated_fp8_dtypes" not in args:
+         if "deprecated_fp8_dot_operand_dtypes" not in args:
              if capability >= 90:
-                 args["deprecated_fp8_dtypes"] = ("fp8e4b15", )
+                 args["deprecated_fp8_dot_operand_dtypes"] = ("fp8e4b15", )
 
          if "enable_fp_fusion" not in args:
-             args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
+             args["enable_fp_fusion"] = knobs.language.default_fp_fusion
 
          args["max_num_imprecise_acc_default"] = 2**30 if capability == 90 else 0
 
@@ -244,11 +218,13 @@ class CUDABackend(BaseBackend):
          nvidia.load_dialects(ctx)
 
      @staticmethod
-     def make_ttir(mod, metadata, opt):
+     def make_ttir(mod, metadata, opt, capability):
          pm = ir.pass_manager(mod.context)
          pm.enable_debug()
          passes.common.add_inliner(pm)
          passes.ttir.add_rewrite_tensor_pointer(pm)
+         if capability // 10 < 9:
+             passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
          passes.common.add_canonicalizer(pm)
          passes.ttir.add_combine(pm)
          passes.ttir.add_reorder_broadcast(pm)
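make_ttir now needs the compute capability so that tensor-descriptor (TMA) operations can be rewritten into ordinary pointer accesses on GPUs without TMA hardware. The capability // 10 < 9 test selects everything below Hopper, as this small check illustrates:

    for capability in (80, 86, 89, 90, 100, 120):
        rewrite = capability // 10 < 9
        print(f"sm_{capability}: rewrite tensor descriptors to pointers -> {rewrite}")
    # sm_80/86/89 -> True (pre-Hopper fallback), sm_90/100/120 -> False (native TMA path)
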
@@ -260,6 +236,10 @@ class CUDABackend(BaseBackend):
 
      @staticmethod
      def make_ttgir(mod, metadata, opt, capability):
+         # Set maxnreg on all kernels, if it was provided.
+         if opt.maxnreg is not None:
+             mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg))
+
          cluster_info = nvidia.ClusterInfo()
          if opt.cluster_dims is not None:
              cluster_info.clusterDimX = opt.cluster_dims[0]
@@ -279,56 +259,69 @@ class CUDABackend(BaseBackend):
          passes.ttgpuir.add_accelerate_matmul(pm)
          passes.ttgpuir.add_remove_layout_conversions(pm)
          passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
-         passes.common.add_cse(pm)
+         nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
+         passes.ttir.add_loop_aware_cse(pm)
          if capability // 10 in [8, 9]:
              passes.ttgpuir.add_fuse_nested_loops(pm)
              passes.common.add_canonicalizer(pm)
-             passes.common.add_licm(pm)
-             passes.ttgpuir.add_optimize_accumulator_init(pm)
+             passes.ttir.add_triton_licm(pm)
              passes.common.add_canonicalizer(pm)
              passes.ttgpuir.add_combine_tensor_select_and_if(pm)
-             passes.ttgpuir.add_ws_task_partition(pm, opt.num_consumer_groups)
-             passes.ttgpuir.add_taskid_propagate(pm, opt.num_consumer_groups)
-             passes.ttgpuir.add_ws_data_partition(pm, opt.num_consumer_groups)
-             passes.ttgpuir.add_ws_code_partition(pm, opt.num_buffers_warp_spec, opt.num_consumer_groups,
-                                                  opt.reg_dec_producer, opt.reg_inc_consumer)
+             nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
+             passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
+             passes.ttgpuir.add_schedule_loops(pm)
              passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
-             passes.ttgpuir.add_ping_pong_sync(pm, opt.num_consumer_groups)
-             passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups)
          elif capability // 10 >= 10:
              passes.ttgpuir.add_fuse_nested_loops(pm)
              passes.common.add_canonicalizer(pm)
-             passes.common.add_licm(pm)
+             passes.ttir.add_triton_licm(pm)
              passes.ttgpuir.add_optimize_accumulator_init(pm)
-             passes.ttgpuir.add_ws_task_partition(pm, opt.num_consumer_groups)
-             passes.ttgpuir.add_taskid_propagate(pm, opt.num_consumer_groups)
-             passes.ttgpuir.add_ws_data_partition(pm, opt.num_consumer_groups)
-             passes.ttgpuir.add_ws_code_partition(pm, opt.num_buffers_warp_spec, opt.num_consumer_groups,
-                                                  opt.reg_dec_producer, opt.reg_inc_consumer)
+             passes.ttgpuir.add_hoist_tmem_alloc(pm)
+             nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
+             passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
+             passes.ttgpuir.add_schedule_loops(pm)
+             passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
              passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
              passes.ttgpuir.add_combine_tensor_select_and_if(pm)
-             nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
-             nvidia.passes.ttnvgpuir.add_keep_acc_in_tmem(pm)
-             passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups)
-             passes.common.add_canonicalizer(pm)
+             nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
          else:
-             passes.common.add_licm(pm)
+             passes.ttir.add_triton_licm(pm)
+         passes.common.add_canonicalizer(pm)
+         passes.ttir.add_loop_aware_cse(pm)
          passes.ttgpuir.add_prefetch(pm)
          passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
          passes.ttgpuir.add_coalesce_async_copy(pm)
+         nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
          passes.ttgpuir.add_remove_layout_conversions(pm)
+         nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
          passes.ttgpuir.add_reduce_data_duplication(pm)
          passes.ttgpuir.add_reorder_instructions(pm)
-         passes.common.add_cse(pm)
+         passes.ttir.add_loop_aware_cse(pm)
          passes.common.add_symbol_dce(pm)
          if capability // 10 >= 9:
-             nvidia.passes.ttnvgpuir.add_fence_insertion(pm)
              nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
+             nvidia.passes.ttnvgpuir.add_fence_insertion(pm)
+         passes.common.add_sccp(pm)
          passes.common.add_canonicalizer(pm)
-         if capability // 10 >= 9:
-             passes.ttgpuir.add_ws_canonicalization(pm, opt.num_consumer_groups)
          pm.run(mod)
          metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
+         tensordesc_meta = mod.get_tensordesc_metadata()
+         metadata["tensordesc_meta"] = tensordesc_meta
+         return mod
+
+     def ttgir_opt(self, src, metadata, options, capability):
+         mod = src
+         pm = ir.pass_manager(mod.context)
+         pm.enable_debug()
+
+         passes.ttgpuir.add_inliner(pm)
+         passes.common.add_sccp(pm)
+         passes.ttir.add_loop_aware_cse(pm)
+         passes.ttgpuir.add_canonicalizer(pm)
+         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+
+         pm.run(mod)
+         metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
          return mod
 
      def make_llir(self, src, metadata, options, capability):
@@ -354,28 +347,23 @@ class CUDABackend(BaseBackend):
          passes.common.add_canonicalizer(pm)
          passes.common.add_cse(pm)
          passes.common.add_symbol_dce(pm)
-         if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
+         if not knobs.compilation.disable_line_info:
              passes.llvmir.add_di_scope(pm)
          pm.run(mod)
          # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
          llvm.init_targets()
          context = llvm.context()
-         if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+         if knobs.compilation.enable_asan:
              raise RuntimeError(
                  "Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
          llvm_mod = llvm.to_module(mod, context)
          proc = sm_arch_from_capability(capability)
          features = get_features(options, self.target.arch)
          triple = 'nvptx64-nvidia-cuda'
+         nvidia.set_short_ptr()
          llvm.attach_datalayout(llvm_mod, triple, proc, features)
          nvidia.set_nvvm_reflect_ftz(llvm_mod)
 
-         # Set maxnreg on all kernels, if it was provided.
-         if options.maxnreg is not None:
-             for k in llvm_mod.get_functions():
-                 if not k.is_declaration() and k.is_external_linkage():
-                     k.set_nvvm_maxnreg(options.maxnreg)
-
          if options.extern_libs:
              paths = [path for (name, path) in options.extern_libs]
              llvm.link_extern_libs(llvm_mod, paths)
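As elsewhere in this file, the os.environ checks are replaced by attributes of the new triton/knobs.py module (a new 481-line file in this wheel, not shown in this diff). The sketch below only illustrates how an env-var-backed boolean knob could look; the names mirror the knobs.compilation attributes used above, but the implementation is hypothetical, not the actual knobs.py code:

    import os
    from dataclasses import dataclass, field

    def env_bool(name: str, default: bool = False) -> bool:
        # Treat "1" as true, anything else (or unset) as the default.
        return os.environ.get(name, "1" if default else "0") == "1"

    @dataclass
    class CompilationKnobs:
        # Hypothetical stand-ins for knobs.compilation.* referenced in make_llir above.
        disable_line_info: bool = field(default_factory=lambda: env_bool("TRITON_DISABLE_LINE_INFO"))
        enable_asan: bool = field(default_factory=lambda: env_bool("TRITON_ENABLE_ASAN"))

    compilation = CompilationKnobs()
    print("line info enabled:", not compilation.disable_line_info)
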
@@ -402,7 +390,7 @@ class CUDABackend(BaseBackend):
          triple = 'nvptx64-nvidia-cuda'
          proc = sm_arch_from_capability(capability)
          features = get_features(opt, self.target.arch)
-         ret = llvm.translate_to_asm(src, triple, proc, features, ['nvptx-short-ptr'], opt.enable_fp_fusion, False)
+         ret = llvm.translate_to_asm(src, triple, proc, features, [], opt.enable_fp_fusion, False)
          # Find kernel names (there should only be one)
          names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
          assert len(names) == 1
@@ -413,29 +401,38 @@ class CUDABackend(BaseBackend):
          ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
          # Remove the debug flag that prevents ptxas from optimizing the code
          ret = re.sub(r",\s*debug|debug,\s*", "", ret)
-         if os.environ.get("NVPTX_ENABLE_DUMP", "0") == "1":
+         if knobs.nvidia.dump_nvptx:
              print("// -----// NVPTX Dump //----- //")
              print(ret)
          return ret
 
      def make_cubin(self, src, metadata, opt, capability):
-         ptxas, _ = get_ptxas(self.target.arch)
+         ptxas = get_ptxas().path
          with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
              tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
              fsrc.write(src)
              fsrc.flush()
              fbin = fsrc.name + '.o'
 
-             line_info = ["-lineinfo", "-suppress-debug-info"] if os.environ.get("TRITON_DISABLE_LINE_INFO",
-                                                                                 "0") == "1" else ["-lineinfo"]
+             line_info = ["-lineinfo", "-suppress-debug-info"] if knobs.compilation.disable_line_info else ["-lineinfo"]
              fmad = [] if opt.enable_fp_fusion else ['--fmad=false']
              arch = sm_arch_from_capability(capability)
-             opt_level = ['--opt-level', '0'] if os.environ.get("DISABLE_PTXAS_OPT", "0") == "1" else []
-             ptxas_cmd = [ptxas, *line_info, *fmad, '-v', *opt_level, f'--gpu-name={arch}', fsrc.name, '-o', fbin]
+
+             # Disable ptxas optimizations if requested
+             disable_opt = ['--opt-level', '0'] if knobs.nvidia.disable_ptxas_opt else []
+
+             # Accept more ptxas options if provided
+             ptx_extra_options = opt.ptx_options.split(" ") if opt.ptx_options else []
+
+             ptxas_cmd = [
+                 ptxas, *line_info, *fmad, '-v', *disable_opt, *ptx_extra_options, f'--gpu-name={arch}', fsrc.name, '-o',
+                 fbin
+             ]
              try:
                  # close_fds=True on Windows and False on Linux, see https://github.com/triton-lang/triton/pull/4357
                  # On Windows, both stdout and stderr need to be redirected to flog
-                 subprocess.run(ptxas_cmd, check=True, close_fds=True if os.name == 'nt' else False, stdout=flog, stderr=flog)
+                 subprocess.run(ptxas_cmd, check=True, close_fds=True if os.name == 'nt' else False, stdout=flog,
+                                stderr=flog)
              except subprocess.CalledProcessError as e:
                  with open(flog.name) as log_file:
                      log = log_file.read()
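Functionally, make_cubin changes in two ways: the optimization-disable switch moves from the DISABLE_PTXAS_OPT environment variable to knobs.nvidia.disable_ptxas_opt, and the new ptx_options string is split on spaces and spliced into the ptxas invocation. A small sketch of the resulting command list; the flag values and file names are illustrative only:

    ptxas = "/path/to/ptxas"                       # get_ptxas().path
    line_info, fmad, disable_opt = ["-lineinfo"], [], []
    ptx_options = "--warn-on-spills --allow-expensive-optimizations=true"   # illustrative
    ptx_extra_options = ptx_options.split(" ") if ptx_options else []
    arch, src_name, fbin = "sm_90", "kernel.ptx", "kernel.ptx.o"
    ptxas_cmd = [ptxas, *line_info, *fmad, '-v', *disable_opt, *ptx_extra_options,
                 f'--gpu-name={arch}', src_name, '-o', fbin]
    print(ptxas_cmd)
    # ['/path/to/ptxas', '-lineinfo', '-v', '--warn-on-spills',
    #  '--allow-expensive-optimizations=true', '--gpu-name=sm_90',
    #  'kernel.ptx', '-o', 'kernel.ptx.o']
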
@@ -460,15 +457,18 @@ class CUDABackend(BaseBackend):
              try_remove(flog.name)
          return cubin
 
-     def add_stages(self, stages, options):
+     def add_stages(self, stages, options, language):
          capability = self._parse_arch(options.arch)
-         stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
-         stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
+         if language == Language.TRITON:
+             stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, capability)
+             stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
+         elif language == Language.GLUON:
+             stages["ttgir"] = lambda src, metadata: self.ttgir_opt(src, metadata, options, capability)
          stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
          stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.target.arch)
          stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch)
 
      @functools.lru_cache()
      def hash(self):
-         version = get_ptxas_version(self.target.arch)
+         version = get_ptxas_version()
          return f'{version}-{self.target.arch}'
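add_stages now receives the source language: Triton programs take the full TTIR -> TTGIR path, while Gluon programs (the new triton.experimental.gluon packages listed above) enter the pipeline at TTGIR and only get the lighter ttgir_opt cleanup. A condensed sketch of the resulting stage lists, with no claims beyond what the hunk above shows:

    def stage_names(language):
        # Mirrors CUDABackend.add_stages above, ignoring the actual callables.
        if language == "triton":
            return ["ttir", "ttgir", "llir", "ptx", "cubin"]
        if language == "gluon":
            return ["ttgir", "llir", "ptx", "cubin"]   # Gluon source is already TTGIR-level

    print(stage_names("triton"))   # ['ttir', 'ttgir', 'llir', 'ptx', 'cubin']
    print(stage_names("gluon"))    # ['ttgir', 'llir', 'ptx', 'cubin']
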
triton/backends/nvidia/driver.c

@@ -10,7 +10,6 @@
 
  #include <stdbool.h>
  #define PY_SSIZE_T_CLEAN
- #define Py_LIMITED_API 0x03090000
  #include <Python.h>
 
  // Raises a Python exception and returns false if code is not CUDA_SUCCESS.
@@ -112,6 +111,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
    CUmodule mod;
    int32_t n_regs = 0;
    int32_t n_spills = 0;
+   int32_t n_max_threads = 0;
    // create driver handles
    CUcontext pctx = 0;
 
@@ -132,6 +132,8 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
        cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
    n_spills /= 4;
+   CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
+       &n_max_threads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun));
    // set dynamic shared memory if necessary
    int shared_optin;
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
@@ -155,8 +157,8 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
    if (PyErr_Occurred()) {
      return NULL;
    }
-   return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs,
-                        n_spills);
+   return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs,
+                        n_spills, n_max_threads);
  }
 
  typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
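loadBinary now also queries CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK and returns a 5-tuple instead of a 4-tuple, so the Python driver must unpack one more value. The exact call site lives in triton/backends/nvidia/driver.py (changed in this release but not shown here); the following is only a sketch with assumed argument and variable names:

    # 3.3.x: mod, func, n_regs, n_spills = cuda_utils.load_binary(name, kernel, shared, device)
    # 3.4.0 (sketch; names assumed):
    # mod, func, n_regs, n_spills, n_max_threads = cuda_utils.load_binary(name, kernel, shared, device)
    #
    # n_max_threads comes from CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK and lets the runtime
    # reject launches whose block size exceeds what the compiled kernel supports.
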
@@ -308,112 +310,103 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
    return Py_None;
  }
 
- // Simple helper to experiment creating TMA descriptors on the host.
- // This is a useful to test TMA operations independently.
- static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) {
-   unsigned long long global_address;
-   uint64_t dim;
-   uint32_t tensorDim;
-   int elementSize;
+ static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
    unsigned long long desc_address;
-   if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim,
-                         &elementSize, &desc_address)) {
+   unsigned long long global_address;
+   int swizzle;
+   int elemSize;
+   int elemType;
+   PyObject *blockSize;
+   PyObject *shape;
+   PyObject *strides;
+
+   if (!PyArg_ParseTuple(args, "KKiiiOOO", &desc_address, &global_address,
+                         &swizzle, &elemSize, &elemType, &blockSize, &shape,
+                         &strides)) {
      return NULL;
    }
-   uint64_t dims[1] = {dim};
-   uint64_t globalStrides[1] = {dim * elementSize};
-   uint32_t boxDim[1] = {tensorDim};
-   uint32_t elementStrides[1] = {1};
-   CUtensorMapDataType type;
-   switch (elementSize) {
-   case 1:
-     type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
-     break;
-   case 2:
-     type = CU_TENSOR_MAP_DATA_TYPE_UINT16;
-     break;
-   case 4:
-     type = CU_TENSOR_MAP_DATA_TYPE_UINT32;
-     break;
-   default:
-     PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4");
-     return NULL;
+
+   PyObject *blockSizeFast = NULL;
+   PyObject *shapeFast = NULL;
+   PyObject *stridesFast = NULL;
+   PyObject *result = NULL;
+
+   uint32_t blockSizeInt[5];
+   uint64_t shapeInt[5];
+   uint64_t stridesLL[5];
+
+   blockSizeFast = PySequence_Fast(blockSize, "blockSize must be a sequence");
+   if (!blockSizeFast)
+     goto cleanup;
+   int rank = PySequence_Fast_GET_SIZE(blockSizeFast);
+
+   for (int i = 0; i < rank; ++i) {
+     PyObject *item = PySequence_Fast_GET_ITEM(blockSizeFast, i);
+     if (!PyLong_Check(item)) {
+       PyErr_SetString(PyExc_TypeError, "block size must be an int");
+       goto cleanup;
+     }
+     blockSizeInt[rank - i - 1] = PyLong_AsLongLong(item);
    }
-   assert((elementSize * tensorDim) >= 32 && "block size too small.");
-   int rank = 1;
-   static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
-   INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
-                                       getCuTensorMapEncodeTiledHandle);
-   CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled(
-       (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims,
-       globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
-       CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE,
-       CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
-   Py_INCREF(Py_None);
-   return Py_None;
- }
 
- // Simple helper to experiment creating TMA descriptors on the host.
- // This is a useful to test TMA operations independently.
- static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) {
-   unsigned long long global_address;
-   uint64_t dims[2];
-   uint32_t tensorDims[2];
-   int elementSize;
-   unsigned long long desc_address;
-   if (!PyArg_ParseTuple(args, "KKKiiiK", &global_address, &dims[1], &dims[0],
-                         &tensorDims[1], &tensorDims[0], &elementSize,
-                         &desc_address)) {
-     return NULL;
+   shapeFast = PySequence_Fast(shape, "shape must be a sequence");
+   if (!shapeFast)
+     goto cleanup;
+
+   if (rank != PySequence_Fast_GET_SIZE(shapeFast)) {
+     PyErr_SetString(PyExc_RuntimeError, "Rank mismatch");
+     goto cleanup;
    }
-   uint64_t globalStrides[2] = {dims[0] * elementSize,
-                                dims[0] * dims[1] * elementSize};
-   uint32_t elementStrides[2] = {1, 1};
-   CUtensorMapDataType type;
-   switch (elementSize) {
-   case 1:
-     type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
-     break;
-   case 2:
-     type = CU_TENSOR_MAP_DATA_TYPE_UINT16;
-     break;
-   case 4:
-     type = CU_TENSOR_MAP_DATA_TYPE_UINT32;
-     break;
-   default:
-     PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4");
+   for (int i = 0; i < rank; ++i) {
+     PyObject *item = PySequence_Fast_GET_ITEM(shapeFast, i);
+     if (!PyLong_Check(item)) {
+       PyErr_SetString(PyExc_TypeError, "shape must be an int");
+       goto cleanup;
+     }
+     shapeInt[rank - i - 1] = PyLong_AsLong(item);
    }
-   int rank = 2;
-   // Swizzling should be picked in codegen but since we need to set it on the
-   // descriptor we rely on a convention between this function and codegen.
-   CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
-   uint32_t contigDimSizeInByte = elementSize * tensorDims[0];
-   if (contigDimSizeInByte >= 128) {
-     swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
-   } else if (contigDimSizeInByte >= 64) {
-     swizzle = CU_TENSOR_MAP_SWIZZLE_64B;
-   } else if (contigDimSizeInByte >= 32) {
-     swizzle = CU_TENSOR_MAP_SWIZZLE_32B;
-   } else {
-     assert(false && "block size too small.");
+
+   stridesFast = PySequence_Fast(strides, "strides must be a sequence");
+   if (!stridesFast)
+     goto cleanup;
+
+   if (rank != PySequence_Fast_GET_SIZE(stridesFast)) {
+     PyErr_SetString(PyExc_RuntimeError, "Rank mismatch");
+     goto cleanup;
    }
-   // The bounding box inner dimension must be less than or equal to the swizzle
-   // size.
-   // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
-   // We clamp the block size and the codegen will emit multiple copy operations.
-   if (contigDimSizeInByte > 128) {
-     tensorDims[0] = 128 / elementSize;
+   for (int i = 0; i + 1 < rank; ++i) {
+     PyObject *item = PySequence_Fast_GET_ITEM(stridesFast, i);
+     if (!PyLong_Check(item)) {
+       PyErr_SetString(PyExc_TypeError, "shape must be an int");
+       goto cleanup;
+     }
+     stridesLL[rank - i - 2] = elemSize * PyLong_AsLongLong(item);
    }
+   stridesLL[rank - 1] =
+       shapeInt[rank - 1] * (rank == 1 ? elemSize : stridesLL[rank - 2]);
+   Py_DECREF(blockSizeFast);
+   blockSizeFast = NULL;
+   Py_DECREF(shapeFast);
+   shapeFast = NULL;
+   Py_DECREF(stridesFast);
+   stridesFast = NULL;
+
+   uint32_t elementStrides[5] = {1, 1, 1, 1, 1};
    static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
    INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
                                        getCuTensorMapEncodeTiledHandle);
    CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled(
-       (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims,
-       globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
-       swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
-       CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
-   Py_INCREF(Py_None);
-   return Py_None;
+       (CUtensorMap *)desc_address, elemType, rank, (void *)global_address,
+       shapeInt, stridesLL, blockSizeInt, elementStrides,
+       CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
+       CU_TENSOR_MAP_L2_PROMOTION_L2_128B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
+   Py_RETURN_NONE;
+
+ cleanup:
+   Py_XDECREF(blockSizeFast);
+   Py_XDECREF(shapeFast);
+   Py_XDECREF(stridesFast);
+   return result;
  }
 
  static PyMethodDef ModuleMethods[] = {
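The two fixed-rank helpers collapse into a single fillTMADescriptor whose argument order follows the PyArg_ParseTuple("KKiiiOOO", ...) call above; the swizzle mode and element type are now chosen by the caller (codegen) rather than hard-coded in C. A hedged sketch of what a host-side call could look like; every value below is a placeholder, since the real call site sits in the Python runtime, not in this diff:

    # Placeholder values only (addresses, enum values and shapes are illustrative).
    desc_address = 0x7f0000000000     # host pointer to a 128-byte CUtensorMap
    global_address = 0x0700000000     # device pointer to the tensor data
    swizzle = 3                       # a CUtensorMapSwizzle value, e.g. 128B swizzle
    elem_size, elem_type = 2, 6       # bytes per element / a CUtensorMapDataType value
    block_size = [64, 64]             # box shape, outermost dimension first (the C code reverses it)
    shape = [1024, 256]               # tensor shape, outermost dimension first
    strides = [256, 1]                # element strides; the innermost one is recomputed in C
    # cuda_utils.fill_tma_descriptor(desc_address, global_address, swizzle,
    #                                elem_size, elem_type, block_size, shape, strides)
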
@@ -429,8 +422,7 @@ static PyMethodDef ModuleMethods[] = {
       "being dropped. This inherits all the limitations of this call; in "
       "particular it's an error to change this value after launching any kernel "
       "that calls printf()."},
-     {"fill_1d_tma_descriptor", fill1DTMADescriptor, METH_VARARGS, "doc"},
-     {"fill_2d_tma_descriptor", fill2DTMADescriptor, METH_VARARGS, "doc"},
+     {"fill_tma_descriptor", fillTMADescriptor, METH_VARARGS, "doc"},
 
      {NULL, NULL, 0, NULL} // sentinel
  };