triton-windows 3.2.0.post11__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.

Files changed (154)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +85 -0
  3. triton/_internal_testing.py +123 -0
  4. triton/backends/__init__.py +50 -0
  5. triton/backends/amd/compiler.py +368 -0
  6. triton/backends/amd/driver.c +211 -0
  7. triton/backends/amd/driver.py +512 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  25. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  26. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  27. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  28. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  31. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  32. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  40. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  41. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  42. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  43. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  44. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  45. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  46. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  48. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  49. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  50. triton/backends/amd/include/hip/device_functions.h +38 -0
  51. triton/backends/amd/include/hip/driver_types.h +468 -0
  52. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  53. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  54. triton/backends/amd/include/hip/hip_common.h +100 -0
  55. triton/backends/amd/include/hip/hip_complex.h +38 -0
  56. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  57. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  58. triton/backends/amd/include/hip/hip_ext.h +159 -0
  59. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  60. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  61. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  62. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  63. triton/backends/amd/include/hip/hip_profile.h +27 -0
  64. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  65. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  66. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  67. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  68. triton/backends/amd/include/hip/hip_version.h +17 -0
  69. triton/backends/amd/include/hip/hiprtc.h +421 -0
  70. triton/backends/amd/include/hip/library_types.h +78 -0
  71. triton/backends/amd/include/hip/math_functions.h +42 -0
  72. triton/backends/amd/include/hip/surface_types.h +63 -0
  73. triton/backends/amd/include/hip/texture_types.h +194 -0
  74. triton/backends/amd/include/hsa/Brig.h +1131 -0
  75. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  76. triton/backends/amd/include/hsa/amd_hsa_elf.h +436 -0
  77. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  78. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  79. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  80. triton/backends/amd/include/hsa/hsa.h +5729 -0
  81. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  82. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  83. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  84. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  85. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  87. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  88. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  89. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  90. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  91. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  92. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  93. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  94. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  95. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  96. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  97. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  98. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  99. triton/backends/amd/include/roctracer/roctx.h +229 -0
  100. triton/backends/amd/lib/ockl.bc +0 -0
  101. triton/backends/amd/lib/ocml.bc +0 -0
  102. triton/backends/compiler.py +304 -0
  103. triton/backends/driver.py +48 -0
  104. triton/backends/nvidia/__init__.py +0 -0
  105. triton/backends/nvidia/bin/ptxas.exe +0 -0
  106. triton/backends/nvidia/compiler.py +410 -0
  107. triton/backends/nvidia/driver.c +451 -0
  108. triton/backends/nvidia/driver.py +524 -0
  109. triton/backends/nvidia/include/cuda.h +24359 -0
  110. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  111. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  112. triton/compiler/__init__.py +4 -0
  113. triton/compiler/code_generator.py +1303 -0
  114. triton/compiler/compiler.py +430 -0
  115. triton/compiler/errors.py +51 -0
  116. triton/compiler/make_launcher.py +0 -0
  117. triton/errors.py +5 -0
  118. triton/language/__init__.py +294 -0
  119. triton/language/_utils.py +21 -0
  120. triton/language/core.py +2694 -0
  121. triton/language/extra/__init__.py +26 -0
  122. triton/language/extra/cuda/__init__.py +13 -0
  123. triton/language/extra/cuda/_experimental_tma.py +108 -0
  124. triton/language/extra/cuda/libdevice.py +1629 -0
  125. triton/language/extra/cuda/utils.py +109 -0
  126. triton/language/extra/hip/__init__.py +3 -0
  127. triton/language/extra/hip/libdevice.py +475 -0
  128. triton/language/extra/libdevice.py +786 -0
  129. triton/language/math.py +250 -0
  130. triton/language/random.py +207 -0
  131. triton/language/semantic.py +1796 -0
  132. triton/language/standard.py +452 -0
  133. triton/runtime/__init__.py +23 -0
  134. triton/runtime/autotuner.py +408 -0
  135. triton/runtime/build.py +111 -0
  136. triton/runtime/cache.py +295 -0
  137. triton/runtime/driver.py +60 -0
  138. triton/runtime/errors.py +26 -0
  139. triton/runtime/interpreter.py +1235 -0
  140. triton/runtime/jit.py +951 -0
  141. triton/testing.py +511 -0
  142. triton/tools/__init__.py +0 -0
  143. triton/tools/build_extern.py +365 -0
  144. triton/tools/compile.c +67 -0
  145. triton/tools/compile.h +14 -0
  146. triton/tools/compile.py +155 -0
  147. triton/tools/disasm.py +144 -0
  148. triton/tools/experimental_descriptor.py +32 -0
  149. triton/tools/link.py +322 -0
  150. triton/windows_utils.py +375 -0
  151. triton_windows-3.2.0.post11.dist-info/METADATA +39 -0
  152. triton_windows-3.2.0.post11.dist-info/RECORD +154 -0
  153. triton_windows-3.2.0.post11.dist-info/WHEEL +5 -0
  154. triton_windows-3.2.0.post11.dist-info/top_level.txt +12 -0
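
The hunk below shows a single file from the list above; judging by its size (+410 lines) and its contents, it appears to correspond to triton/backends/nvidia/compiler.py (entry 106).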
@@ -0,0 +1,410 @@
+ from triton.backends.compiler import BaseBackend, GPUTarget
+ from triton._C.libtriton import ir, passes, llvm, nvidia
+
+ from dataclasses import dataclass
+ import functools
+ from typing import Any, Dict, Tuple, Optional
+ from types import ModuleType
+ import hashlib
+ import re
+ import tempfile
+ import signal
+ import os
+ import subprocess
+ from pathlib import Path
+
+
+ def min_dot_size(target: GPUTarget):
+     return lambda lhsType, rhsType: (16, 32, 16) if lhsType.is_int8() else (16, 16, 16)
+
+
+ @functools.lru_cache()
+ def _path_to_binary(binary: str):
+     if os.name == "nt":
+         binary += ".exe"
+     paths = [
+         os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
+         os.path.join(os.path.dirname(__file__), "bin", binary),
+     ]
+     if os.name == "nt":
+         from triton.windows_utils import find_cuda
+         cuda_bin_path, _, _ = find_cuda()
+         if cuda_bin_path:
+             paths += [os.path.join(cuda_bin_path, binary)]
+
+     for bin in paths:
+         if os.path.exists(bin) and os.path.isfile(bin):
+             result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
+             if result is not None:
+                 version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
+                 if version is not None:
+                     return bin, version.group(1)
+     raise RuntimeError(f"Cannot find {binary}")
+
+
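A minimal usage sketch of the helper above (not part of the packaged file), assuming this wheel is installed: the search order is the TRITON_<NAME>_PATH environment variable (note the lookup happens after ".exe" has been appended on Windows), the backend's bundled bin directory (this wheel ships ptxas.exe there, entry 105), and, on Windows, the CUDA toolkit located by triton.windows_utils.find_cuda().

    # Illustrative only; the printed values depend on the local installation.
    from triton.backends.nvidia.compiler import _path_to_binary

    ptxas_path, cuda_release = _path_to_binary("ptxas")
    print(ptxas_path)     # e.g. ...\triton\backends\nvidia\bin\ptxas.exe
    print(cuda_release)   # the release parsed from `ptxas --version`, e.g. "12.4"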
+ @functools.lru_cache()
+ def get_ptxas_version():
+     version = subprocess.check_output([_path_to_binary("ptxas")[0], "--version"]).decode("utf-8")
+     return version
+
+
+ @functools.lru_cache()
+ def ptx_get_version(cuda_version) -> int:
+     '''
+     Get the highest PTX version supported by the current CUDA driver.
+     '''
+     assert isinstance(cuda_version, str)
+     major, minor = map(int, cuda_version.split('.'))
+     if major == 12:
+         if minor < 6:
+             return 80 + minor
+         else:
+             return 79 + minor
+     if major == 11:
+         return 70 + minor
+     if major == 10:
+         return 63 + minor
+     raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version)
+
+
+ def get_ptx_version_from_options(options):
+     ptx_version = options.ptx_version
+     if ptx_version is None:
+         _, cuda_version = _path_to_binary("ptxas")
+         ptx_version = ptx_get_version(cuda_version)
+     return ptx_version
+
+
+ @functools.lru_cache()
+ def get_features(options):
+     ptx_version = get_ptx_version_from_options(options)
+
+     # PTX 8.3 is the max version supported by llvm 3a83162168.
+     #
+     # To check if a newer PTX version is supported, increase this value
+     # and run a test. If it's not supported, LLVM will print a warning
+     # like "+ptx8.4 is not a recognized feature for this target".
+     llvm_ptx_version = min(83, ptx_version)
+     features = f'+ptx{llvm_ptx_version}'
+     return features
+
+
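A worked example of the version mapping above (not part of the packaged file); the expected values follow directly from the arithmetic in ptx_get_version and the cap applied in get_features.

    from triton.backends.nvidia.compiler import ptx_get_version

    assert ptx_get_version("12.4") == 84   # 12.x with minor < 6  -> 80 + minor
    assert ptx_get_version("12.6") == 85   # 12.x with minor >= 6 -> 79 + minor
    assert ptx_get_version("11.8") == 78   # 11.x                 -> 70 + minor
    assert ptx_get_version("10.2") == 65   # 10.x                 -> 63 + minor
    # get_features then clamps to min(83, ptx_version), so even CUDA 12.4
    # (PTX 8.4) is compiled with the '+ptx83' LLVM feature string.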
+ @functools.lru_cache(None)
+ def file_hash(path):
+     with open(path, "rb") as f:
+         return hashlib.sha256(f.read()).hexdigest()
+
+
+ # The file may be accessed in parallel
+ def try_remove(path):
+     if os.path.exists(path):
+         try:
+             os.remove(path)
+         except OSError:
+             import traceback
+             traceback.print_exc()
+
+
+ @dataclass(frozen=True)
+ class CUDAOptions:
+     num_warps: int = 4
+     num_ctas: int = 1
+     num_stages: int = 3
+     num_buffers_warp_spec: int = 0
+     num_consumer_groups: int = 0
+     reg_dec_producer: int = 0
+     reg_inc_consumer: int = 0
+     # maxnreg corresponds to the ptx parameter .maxnreg, which controls the
+     # maximum number of 32-bit registers used by one thread.
+     maxnreg: Optional[int] = None
+     cluster_dims: tuple = (1, 1, 1)
+     ptx_version: int = None
+     enable_fp_fusion: bool = True
+     supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15")
+     deprecated_fp8_dtypes: Tuple[str] = ()
+     default_dot_input_precision: str = "tf32"
+     allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
+     max_num_imprecise_acc_default: bool = None
+     extern_libs: dict = None
+     debug: bool = False
+     backend_name: str = 'cuda'
+     sanitize_overflow: bool = True
+
+     def __post_init__(self):
+         default_libdir = Path(__file__).parent / 'lib'
+         extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
+         if not extern_libs.get('libdevice', None):
+             extern_libs['libdevice'] = os.getenv("TRITON_LIBDEVICE_PATH", str(default_libdir / 'libdevice.10.bc'))
+         object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
+         assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
+                "num_warps must be a power of 2"
+
+     def hash(self):
+         hash_dict = dict(self.__dict__)
+         hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"]))
+         key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())])
+         return hashlib.sha256(key.encode("utf-8")).hexdigest()
+
+
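A small sketch of how these options behave (not part of the packaged file), assuming the wheel is installed so that lib/libdevice.10.bc exists when hash() reads it:

    from triton.backends.nvidia.compiler import CUDAOptions

    opts = CUDAOptions(num_warps=8, num_stages=4)
    # __post_init__ points extern_libs at the bundled libdevice.10.bc unless
    # TRITON_LIBDEVICE_PATH overrides it, and rejects non-power-of-two num_warps.
    print(dict(opts.extern_libs))   # {'libdevice': '...libdevice.10.bc'}
    print(opts.hash())              # sha256 over the option values plus the libdevice file hash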
+ class CUDABackend(BaseBackend):
+
+     @staticmethod
+     def supports_target(target: GPUTarget):
+         return target.backend == 'cuda'
+
+     def __init__(self, target: GPUTarget) -> None:
+         super().__init__(target)
+         self.capability = target.arch
+         assert isinstance(self.capability, int)
+         self.binary_ext = "cubin"
+
+     def parse_options(self, opts) -> Any:
+         args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts}
+         if "supported_fp8_dtypes" not in args:
+             supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
+             if self.capability >= 89:
+                 supported_fp8_dtypes.add("fp8e4nv")
+             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
+
+         if "deprecated_fp8_dtypes" not in args:
+             if self.capability >= 90:
+                 args["deprecated_fp8_dtypes"] = ("fp8e4b15", )
+
+         if "enable_fp_fusion" not in args:
+             args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
+         args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0
+         return CUDAOptions(**args)
+
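A sketch of how option parsing depends on the compute capability (not part of the packaged file); the GPUTarget(backend, arch, warp_size) constructor is assumed to match triton/backends/compiler.py in this wheel.

    from triton.backends.compiler import GPUTarget
    from triton.backends.nvidia.compiler import CUDABackend

    backend = CUDABackend(GPUTarget("cuda", 90, 32))   # sm_90, warp size 32 (assumed signature)
    opts = backend.parse_options({"num_warps": 8})
    print(opts.supported_fp8_dtypes)            # ('fp8e4b15', 'fp8e4nv', 'fp8e5') on capability >= 89
    print(opts.deprecated_fp8_dtypes)           # ('fp8e4b15',) on capability >= 90
    print(opts.max_num_imprecise_acc_default)   # 2**30 on sm_90, 0 otherwise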
+     def pack_metadata(self, metadata):
+         return (
+             metadata.num_warps,
+             metadata.num_ctas,
+             metadata.shared,
+             metadata.cluster_dims[0],
+             metadata.cluster_dims[1],
+             metadata.cluster_dims[2],
+         )
+
+     def get_codegen_implementation(self):
+         import triton.language.extra.cuda as cuda
+         codegen_fns = {
+             "convert_custom_types":
+             cuda.convert_custom_float8_sm80 if self.capability >= 80 else cuda.convert_custom_float8_sm70,
+             "min_dot_size": min_dot_size(self.target)
+         }
+         return codegen_fns
+
+     def get_module_map(self) -> Dict[str, ModuleType]:
+         from triton.language.extra.cuda import libdevice
+         return {"triton.language.extra.libdevice": libdevice}
+
+     def load_dialects(self, ctx):
+         nvidia.load_dialects(ctx)
+
+     @staticmethod
+     def make_ttir(mod, metadata, opt):
+         pm = ir.pass_manager(mod.context)
+         pm.enable_debug()
+         passes.common.add_inliner(pm)
+         passes.ttir.add_rewrite_tensor_pointer(pm)
+         passes.ttir.add_combine(pm)
+         passes.common.add_canonicalizer(pm)
+         passes.ttir.add_reorder_broadcast(pm)
+         passes.common.add_cse(pm)
+         passes.common.add_licm(pm)
+         passes.common.add_symbol_dce(pm)
+         passes.ttir.add_loop_unroll(pm)
+         pm.run(mod)
+         return mod
+
+     @staticmethod
+     def make_ttgir(mod, metadata, opt, capability):
+         cluster_info = nvidia.ClusterInfo()
+         if opt.cluster_dims is not None:
+             cluster_info.clusterDimX = opt.cluster_dims[0]
+             cluster_info.clusterDimY = opt.cluster_dims[1]
+             cluster_info.clusterDimZ = opt.cluster_dims[2]
+         # Set up Diagnostic
+         if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1":
+             srcMgr = llvm.source_mgr()
+             diag = ir.source_mgr_diag(srcMgr, mod.context)
+             mod.context.printOpOnDiagnostic(True)
+         # TTIR -> TTGIR
+         pm = ir.pass_manager(mod.context)
+         pm.enable_debug()
+         passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
+         # optimize TTGIR
+         passes.ttgpuir.add_coalesce(pm)
+         if capability // 10 >= 8:
+             passes.ttgpuir.add_f32_dot_tc(pm)
+         # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
+         nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
+         passes.ttgpuir.add_remove_layout_conversions(pm)
+         passes.ttgpuir.add_optimize_thread_locality(pm)
+         passes.ttgpuir.add_accelerate_matmul(pm)
+         passes.ttgpuir.add_remove_layout_conversions(pm)
+         passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
+         passes.common.add_cse(pm)
+         if capability // 10 >= 8:
+             passes.ttgpuir.add_optimize_accumulator_init(pm)
+             passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+             passes.ttgpuir.add_ws_task_partition(pm, opt.num_consumer_groups)
+             passes.ttgpuir.add_taskid_propagate(pm, opt.num_consumer_groups)
+             passes.ttgpuir.add_ws_data_partition(pm, opt.num_consumer_groups)
+             passes.ttgpuir.add_ws_code_partition(pm, opt.num_buffers_warp_spec, opt.num_consumer_groups,
+                                                  opt.reg_dec_producer, opt.reg_inc_consumer)
+             passes.ttgpuir.add_pipeline(pm, opt.num_stages)
+             passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups)
+         passes.ttgpuir.add_prefetch(pm)
+         passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
+         passes.ttgpuir.add_remove_layout_conversions(pm)
+         passes.ttgpuir.add_reduce_data_duplication(pm)
+         passes.ttgpuir.add_reorder_instructions(pm)
+         passes.common.add_cse(pm)
+         passes.common.add_symbol_dce(pm)
+         if capability // 10 >= 9:
+             nvidia.passes.ttnvgpuir.add_fence_insertion(pm)
+             nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
+         passes.common.add_canonicalizer(pm)
+         pm.run(mod)
+         metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
+         return mod
+
+     @staticmethod
+     def make_llir(src, metadata, options, capability):
+         ptx_version = get_ptx_version_from_options(options)
+
+         # warp-specialization mutates num_warps
+         num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta")
+         if num_warp_groups is not None:
+             metadata["num_warps"] *= num_warp_groups
+         mod = src
+         # TritonGPU -> LLVM-IR (MLIR)
+         pm = ir.pass_manager(mod.context)
+         pm.enable_debug()
+         # Set up Diagnostic
+         if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1":
+             srcMgr = llvm.source_mgr()
+             diag = ir.source_mgr_diag(srcMgr, mod.context)
+             mod.context.printOpOnDiagnostic(True)
+         nvidia.passes.ttgpuir.add_decompose_unsupported_conversions(pm)
+         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+         passes.convert.add_scf_to_cf(pm)
+         passes.convert.add_index_to_llvmir(pm)
+         passes.ttgpuir.add_allocate_shared_memory(pm)
+         nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
+         nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
+         passes.convert.add_arith_to_llvmir(pm)
+         passes.common.add_canonicalizer(pm)
+         passes.common.add_cse(pm)
+         passes.common.add_symbol_dce(pm)
+         if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
+             passes.llvmir.add_di_scope(pm)
+         pm.run(mod)
+         # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
+         llvm.init_targets()
+         context = llvm.context()
+
+         llvm_mod = llvm.to_module(mod, context)
+         proc = 'sm_90a' if capability == 90 else f'sm_{capability}'
+         features = get_features(options)
+         triple = 'nvptx64-nvidia-cuda'
+         llvm.attach_datalayout(llvm_mod, triple, proc, features)
+         nvidia.set_nvvm_reflect_ftz(llvm_mod)
+
+         # Set maxnreg on all kernels, if it was provided.
+         if options.maxnreg is not None:
+             for k in llvm_mod.get_functions():
+                 if not k.is_declaration() and k.is_external_linkage():
+                     k.set_nvvm_maxnreg(options.maxnreg)
+
+         if options.extern_libs:
+             paths = [path for (name, path) in options.extern_libs]
+             llvm.link_extern_libs(llvm_mod, paths)
+
+         llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)
+
+         # Get some metadata
+         metadata["shared"] = src.get_int_attr("triton_gpu.shared")
+         ret = str(llvm_mod)
+         del llvm_mod
+         del context
+         return ret
+
+     @staticmethod
+     def make_ptx(src, metadata, opt, capability):
+         ptx_version = get_ptx_version_from_options(opt)
+
+         triple = 'nvptx64-nvidia-cuda'
+         proc = 'sm_90a' if capability == 90 else f'sm_{capability}'
+         features = get_features(opt)
+         ret = llvm.translate_to_asm(src, triple, proc, features, ['nvptx-short-ptr'], opt.enable_fp_fusion, False)
+         # Find kernel names (there should only be one)
+         names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
+         assert len(names) == 1
+         metadata["name"] = names[0]
+         # post-process
+         ptx_version = f'{ptx_version//10}.{ptx_version%10}'
+         ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
+         # Remove the debug flag that prevents ptxas from optimizing the code
+         ret = re.sub(r",\s*debug|debug,\s*", "", ret)
+         if os.environ.get("NVPTX_ENABLE_DUMP", "0") == "1":
+             print("// -----// NVPTX Dump //----- //")
+             print(ret)
+         return ret
+
+     @staticmethod
+     def make_cubin(src, metadata, opt, capability):
+         ptxas, _ = _path_to_binary("ptxas")
+         # On Windows, we need to set delete=False, close the temp file before reading it, and manually remove it
+         with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
+                 tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
+             fsrc.write(src)
+             fsrc.close()
+             fbin = fsrc.name + '.o'
+
+             line_info = [] if os.environ.get('TRITON_DISABLE_LINE_INFO') else ['-lineinfo']
+             fmad = [] if opt.enable_fp_fusion else ['--fmad=false']
+             suffix = 'a' if capability == 90 else ''
+             opt_level = ['--opt-level', '0'] if os.environ.get("DISABLE_PTXAS_OPT", "0") == "1" else []
+             ptxas_cmd = [
+                 ptxas, *line_info, *fmad, '-v', *opt_level, f'--gpu-name=sm_{capability}{suffix}', fsrc.name, '-o', fbin
+             ]
+             try:
+                 subprocess.run(ptxas_cmd, check=True, close_fds=False, stdout=flog, stderr=flog)
+                 try_remove(fsrc.name)
+                 flog.close()
+                 try_remove(flog.name)
+             except subprocess.CalledProcessError as e:
+                 flog.close()
+                 with open(flog.name) as log_file:
+                     log = log_file.read()
+                 try_remove(flog.name)
+
+                 if e.returncode == 255:
+                     error = 'Internal Triton PTX codegen error'
+                 elif e.returncode == 128 + signal.SIGSEGV:
+                     error = '`ptxas` raised SIGSEGV'
+                 else:
+                     error = f'`ptxas` failed with error code {e.returncode}'
+
+                 raise RuntimeError(f'{error}\n'
+                                    f'`ptxas` stderr:\n{log}\n'
+                                    f'Repro command: {" ".join(ptxas_cmd)}\n')
+
+             with open(fbin, 'rb') as f:
+                 cubin = f.read()
+             try_remove(fbin)
+             return cubin
+
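For reference (not part of the packaged file): with default options (FP fusion on, line info enabled, ptxas optimizations not disabled), the command assembled above for capability 90 reduces to roughly `ptxas -lineinfo -v --gpu-name=sm_90a <tmp>.ptx -o <tmp>.ptx.o`, where the temporary file names are placeholders.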
+     def add_stages(self, stages, options):
+         stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
+         stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
+         stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
+         stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.capability)
+         stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability)
+
+     @functools.lru_cache()
+     def hash(self):
+         version = get_ptxas_version()
+         return f'{version}-{self.capability}'
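
A rough sketch of how these stages compose (not part of the packaged file). The real driver in triton/compiler/compiler.py (entry 114) also handles caching and metadata; this only illustrates the ordering that add_stages establishes:

    def run_stages(backend, src, options):
        # Each stage's output feeds the next: ttir -> ttgir -> llir -> ptx -> cubin.
        stages = {}
        backend.add_stages(stages, options)
        metadata = {}
        module = src
        for stage_fn in stages.values():
            module = stage_fn(module, metadata)
        return module, metadata   # the final module is the cubin bytes loaded by the CUDA driver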