triton-windows 3.2.0.post11__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows has been flagged as potentially problematic; see the registry's advisory page for details.

Files changed (154)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +85 -0
  3. triton/_internal_testing.py +123 -0
  4. triton/backends/__init__.py +50 -0
  5. triton/backends/amd/compiler.py +368 -0
  6. triton/backends/amd/driver.c +211 -0
  7. triton/backends/amd/driver.py +512 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  25. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  26. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  27. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  28. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  31. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  32. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  40. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  41. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  42. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  43. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  44. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  45. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  46. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  48. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  49. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  50. triton/backends/amd/include/hip/device_functions.h +38 -0
  51. triton/backends/amd/include/hip/driver_types.h +468 -0
  52. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  53. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  54. triton/backends/amd/include/hip/hip_common.h +100 -0
  55. triton/backends/amd/include/hip/hip_complex.h +38 -0
  56. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  57. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  58. triton/backends/amd/include/hip/hip_ext.h +159 -0
  59. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  60. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  61. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  62. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  63. triton/backends/amd/include/hip/hip_profile.h +27 -0
  64. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  65. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  66. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  67. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  68. triton/backends/amd/include/hip/hip_version.h +17 -0
  69. triton/backends/amd/include/hip/hiprtc.h +421 -0
  70. triton/backends/amd/include/hip/library_types.h +78 -0
  71. triton/backends/amd/include/hip/math_functions.h +42 -0
  72. triton/backends/amd/include/hip/surface_types.h +63 -0
  73. triton/backends/amd/include/hip/texture_types.h +194 -0
  74. triton/backends/amd/include/hsa/Brig.h +1131 -0
  75. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  76. triton/backends/amd/include/hsa/amd_hsa_elf.h +436 -0
  77. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  78. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  79. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  80. triton/backends/amd/include/hsa/hsa.h +5729 -0
  81. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  82. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  83. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  84. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  85. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  87. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  88. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  89. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  90. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  91. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  92. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  93. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  94. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  95. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  96. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  97. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  98. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  99. triton/backends/amd/include/roctracer/roctx.h +229 -0
  100. triton/backends/amd/lib/ockl.bc +0 -0
  101. triton/backends/amd/lib/ocml.bc +0 -0
  102. triton/backends/compiler.py +304 -0
  103. triton/backends/driver.py +48 -0
  104. triton/backends/nvidia/__init__.py +0 -0
  105. triton/backends/nvidia/bin/ptxas.exe +0 -0
  106. triton/backends/nvidia/compiler.py +410 -0
  107. triton/backends/nvidia/driver.c +451 -0
  108. triton/backends/nvidia/driver.py +524 -0
  109. triton/backends/nvidia/include/cuda.h +24359 -0
  110. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  111. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  112. triton/compiler/__init__.py +4 -0
  113. triton/compiler/code_generator.py +1303 -0
  114. triton/compiler/compiler.py +430 -0
  115. triton/compiler/errors.py +51 -0
  116. triton/compiler/make_launcher.py +0 -0
  117. triton/errors.py +5 -0
  118. triton/language/__init__.py +294 -0
  119. triton/language/_utils.py +21 -0
  120. triton/language/core.py +2694 -0
  121. triton/language/extra/__init__.py +26 -0
  122. triton/language/extra/cuda/__init__.py +13 -0
  123. triton/language/extra/cuda/_experimental_tma.py +108 -0
  124. triton/language/extra/cuda/libdevice.py +1629 -0
  125. triton/language/extra/cuda/utils.py +109 -0
  126. triton/language/extra/hip/__init__.py +3 -0
  127. triton/language/extra/hip/libdevice.py +475 -0
  128. triton/language/extra/libdevice.py +786 -0
  129. triton/language/math.py +250 -0
  130. triton/language/random.py +207 -0
  131. triton/language/semantic.py +1796 -0
  132. triton/language/standard.py +452 -0
  133. triton/runtime/__init__.py +23 -0
  134. triton/runtime/autotuner.py +408 -0
  135. triton/runtime/build.py +111 -0
  136. triton/runtime/cache.py +295 -0
  137. triton/runtime/driver.py +60 -0
  138. triton/runtime/errors.py +26 -0
  139. triton/runtime/interpreter.py +1235 -0
  140. triton/runtime/jit.py +951 -0
  141. triton/testing.py +511 -0
  142. triton/tools/__init__.py +0 -0
  143. triton/tools/build_extern.py +365 -0
  144. triton/tools/compile.c +67 -0
  145. triton/tools/compile.h +14 -0
  146. triton/tools/compile.py +155 -0
  147. triton/tools/disasm.py +144 -0
  148. triton/tools/experimental_descriptor.py +32 -0
  149. triton/tools/link.py +322 -0
  150. triton/windows_utils.py +375 -0
  151. triton_windows-3.2.0.post11.dist-info/METADATA +39 -0
  152. triton_windows-3.2.0.post11.dist-info/RECORD +154 -0
  153. triton_windows-3.2.0.post11.dist-info/WHEEL +5 -0
  154. triton_windows-3.2.0.post11.dist-info/top_level.txt +12 -0
@@ -0,0 +1,430 @@
1
+ from __future__ import annotations
2
+ import hashlib
3
+ import json
4
+ from .._C.libtriton import get_cache_invalidating_env_vars, ir
5
+ from ..backends import backends
6
+ from ..backends.compiler import GPUTarget, AttrsDescriptor
7
+ from .. import __version__
8
+ from ..runtime.autotuner import OutOfResources
9
+ from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
10
+ from ..runtime.driver import driver
11
+ from ..tools.disasm import get_sass
12
+ # TODO: this shouldn't be here
13
+ from .code_generator import ast_to_ttir
14
+ from pathlib import Path
15
+ import re
16
+ import functools
17
+ import os
18
+
19
# Regex for the prototype line of a kernel in MLIR (ttir/ttgir) text:
# - ^\s*tt\.func\s+ : match the start of a line, any leading whitespace, the
#   keyword tt.func, and any following whitespace
# - (?:public\s+)? : optionally match the keyword public (non-capturing) and
#   any following whitespace
# - (@\w+) : match an @ symbol followed by one or more word characters
#   (letters, digits, or underscores), and capture it as group 1 (the function name)
# - (\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of
#   parentheses enclosing zero or more "%name: type" arguments separated by
#   commas, and capture it as group 2 (the argument list)
# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces
#   and capture them as group 3
mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
# PTX equivalent: matches ".visible .entry name(params)" / ".extern .func name(params)",
# capturing the kernel name (group 1) and the raw parameter list (group 2).
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
# Prototype regex keyed by source-file extension.
prototype_pattern = {
    "ttir": mlir_prototype_pattern,
    "ttgir": mlir_prototype_pattern,
    "ptx": ptx_prototype_pattern,
}

# Extracts one argument type (optionally followed by a { ... } attribute group)
# from the MLIR argument list captured by the prototype pattern above.
mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?'
# Extracts the type token of one ".param .<type>" declaration in PTX.
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
# Argument-type regex keyed by source-file extension.
arg_type_pattern = {
    "ttir": mlir_arg_type_pattern,
    "ttgir": mlir_arg_type_pattern,
    "ptx": ptx_arg_type_pattern,
}
42
+
43
+
44
def convert_type_repr(x):
    """Convert an MLIR argument-type string into Triton's short notation.

    Arguments tagged ``tt.nv_tma_desc = 1`` become ``'nvTmaDesc'``; pointer
    types (``!tt.ptr<...>``) become ``'*' + <pointee>``, recursively; any
    trailing ``{...}`` attribute group is stripped from other types.
    """
    # Currently we only capture the pointer type and assume the pointer is on
    # global memory.
    # TODO: Capture and support shared memory space
    pointee = re.search(r'!tt\.ptr<([^,]+)', x)
    if re.search(r'tt.nv_tma_desc = 1', x) is not None:
        return 'nvTmaDesc'
    stripped = re.sub(r' {[^}]+}', '', x)
    if pointee is None:
        return stripped
    return '*' + convert_type_repr(pointee.group(1))
55
+
56
+
57
+ def _get_num_warps_from_ir_str(src: str):
58
+ ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
59
+ # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if
60
+ # e.g. someone has an instruction (not module) attribute named "num-warps".
61
+ num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
62
+ assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
63
+ num_warps = int(num_warps_matches[0])
64
+ return num_warps
65
+
66
+
67
class ASTSource:
    """Compilation source backed by a Python function (AST frontend).

    Normalizes the user-provided signature/constants/attrs and can lower the
    function to TTIR via :func:`ast_to_ttir`.
    """

    def __init__(self, fn, signature, constants=None, attrs=None) -> None:
        self.fn = fn
        self.ext = "ttir"
        self.name = fn.__name__
        self.signature = signature
        self.constants = constants
        self.attrs = attrs
        if isinstance(self.signature, str):
            # A comma-separated string becomes {arg_index: type_string}.
            parsed = {}
            for idx, type_str in enumerate(self.signature.split(",")):
                parsed[idx] = type_str.strip()
            self.signature = parsed
        else:
            for key in self.signature.keys():
                if isinstance(key, str):
                    continue
                raise TypeError("Signature keys must be string")
        if self.constants is None:
            self.constants = {}
        else:
            for key in self.constants.keys():
                if isinstance(key, str):
                    continue
                raise TypeError("Constants keys must be string")
        if self.attrs is None:
            self.attrs = AttrsDescriptor()

    def hash(self):
        """Stable digest of the kernel identity (fn, attrs, signature, constants)."""
        sorted_sig = [v for k, v in sorted(self.signature.items())]
        # Note - we stringify the keys here to allow sorting to work for cases
        # where constants have mixed int/str keys.
        sorted_constants = sorted((str(k), v) for k, v in self.constants.items())
        key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}"
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    def make_ir(self, options, codegen_fns, module_map, context):
        """Lower the Python function to a TTIR module."""
        return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
                           module_map=module_map)

    def parse_options(self):
        """AST sources embed no extra compile options."""
        return dict()
105
+
106
+
107
class IRSource:
    """Compilation source backed by an IR file on disk (.ttir/.ttgir/.ptx)."""

    def __init__(self, path):
        self.path = path
        ir_path = Path(path)
        self.ext = ir_path.suffix[1:]
        self.src = ir_path.read_text()
        # Recover the kernel name and argument types from the prototype line.
        proto = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
        self.name = proto.group(1)
        arg_types = re.findall(arg_type_pattern[self.ext], proto.group(2))
        self.signature = dict(enumerate(convert_type_repr(t) for t in arg_types))

    def hash(self):
        """Digest of the raw IR text."""
        return hashlib.sha256(self.src.encode("utf-8")).hexdigest()

    def make_ir(self, options, codegen_fns, module_map, context):
        """Parse the file into an MLIR module bound to *context*."""
        module = ir.parse_mlir_module(self.path, context)
        module.context = context
        return module

    def parse_options(self):
        """TTGIR embeds num_warps, which callers must honor; other IRs embed nothing."""
        if self.ext != "ttgir":
            return dict()
        return {'num_warps': _get_num_warps_from_ir_str(self.src)}
132
+
133
+
134
@functools.lru_cache()
def triton_key():
    """Return a cache key covering the installed Triton's own sources.

    Hashes this frontend file, every ``triton.compiler``/``triton.backends``/
    ``triton.language`` submodule, and the native ``libtriton`` extension, so
    kernel caches are invalidated whenever the installation itself changes.
    """
    import pkgutil
    TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    def _sha256_of(path):
        # One-shot read: Python module sources are small.
        with open(path, "rb") as f:
            return hashlib.sha256(f.read()).hexdigest()

    # frontend
    contents = [_sha256_of(__file__)]
    # compiler + backends
    path_prefixes = [
        (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."),
        (os.path.join(TRITON_PATH, "backends"), "triton.backends."),
    ]
    for path, prefix in path_prefixes:
        for lib in pkgutil.walk_packages([path], prefix=prefix):
            contents.append(_sha256_of(lib.module_finder.find_spec(lib.name).origin))

    # backend: the native extension can be large, so hash it in 1 MiB chunks
    so_name = "libtriton.pyd" if os.name == "nt" else "libtriton.so"
    libtriton_hash = hashlib.sha256()
    with open(os.path.join(TRITON_PATH, f"_C/{so_name}"), "rb") as f:
        while chunk := f.read(1024**2):
            libtriton_hash.update(chunk)
    contents.append(libtriton_hash.hexdigest())

    # language
    language_path = os.path.join(TRITON_PATH, 'language')
    for lib in pkgutil.walk_packages([language_path], prefix="triton.language."):
        contents.append(_sha256_of(lib.module_finder.find_spec(lib.name).origin))
    return f'{__version__}' + '-'.join(contents)
171
+
172
+
173
def parse(full_name, ext, context):
    """Load a compilation artifact from *full_name* according to *ext*.

    MLIR stages are parsed into a module bound to *context*, textual stages
    are read as ``str``, binaries as ``bytes``; unknown extensions yield None.
    """
    if ext in ("ttir", "ttgir"):
        module = ir.parse_mlir_module(full_name, context)
        module.context = context
        return module
    if ext in ("llir", "ptx"):
        return Path(full_name).read_text()
    if ext == "cubin":
        return Path(full_name).read_bytes()
182
+
183
+
184
def filter_traceback(e: BaseException):
    """
    Removes code_generator.py and related files from tracebacks.

    These are uninteresting to the user -- "just show me *my* code!"
    """
    # Escape hatch for Triton developers who do want the full traceback.
    if os.getenv("TRITON_FRONT_END_DEBUGGING", "0") == "1":
        return

    # Also scrub chained exceptions: explicit `raise ... from` causes and
    # implicit raised-while-handling contexts.
    if e.__cause__ is not None:
        filter_traceback(e.__cause__)
    if e.__context__ is not None:
        filter_traceback(e.__context__)

    # If a user has a file that matches one of these, they're out of luck.
    # NOTE(review): these suffixes use '/' separators; on Windows co_filename
    # may use '\\', in which case these frames would not be filtered — verify.
    BAD_FILES = [
        "/triton/compiler/code_generator.py",
        "/ast.py",
    ]

    # Collect the traceback frames that are NOT from the files above.
    tb = e.__traceback__
    frames = []
    while tb is not None:
        if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)):
            frames.append(tb)
        tb = tb.tb_next

    # Relink the surviving frames into one contiguous chain.
    for (cur_frame, next_frame) in zip(frames, frames[1:]):
        cur_frame.tb_next = next_frame

    # Splice the rebuilt chain back onto the exception (or drop it entirely
    # when every frame was filtered out).
    if not frames:
        e.__traceback__ = None
    else:
        frames[-1].tb_next = None
        e.__traceback__ = frames[0]
219
+
220
+
221
def compile(src, target=None, options=None):
    """Compile *src* (an ASTSource or a path to an IR file) for *target*.

    Returns a CompiledKernel, either straight from the on-disk cache or by
    running the backend's staged lowering pipeline and writing every
    intermediate artifact plus a metadata JSON back to the cache.
    """
    if target is None:
        target = driver.active.get_current_target()
    assert isinstance(target, GPUTarget), "target must be of GPUTarget type"
    backend = make_backend(target)
    ir_source = not isinstance(src, ASTSource)
    # create backend
    if ir_source:
        assert isinstance(src, str), "source must be either AST or a filepath"
        src = IRSource(src)
    extra_options = src.parse_options()
    options = backend.parse_options(dict(options or dict(), **extra_options))
    # create cache manager: the key folds in Triton's own sources, the kernel
    # source, backend, options, and any cache-invalidating env vars.
    env_vars = get_cache_invalidating_env_vars()
    key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}"
    hash = hashlib.sha256(key.encode("utf-8")).hexdigest()
    fn_cache_manager = get_cache_manager(hash)
    # For dumping/overriding only hash the source as we want it to be independent of triton
    # core changes to make it easier to track kernels by hash.
    enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
    enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1"
    fn_override_manager = get_override_manager(src.hash()) if enable_override else None
    fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None
    # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms.
    # The final file name in the cache will have a format of f"{file_name}.{ext}.tmp.pid_{pid}_{uuid}".
    # A PID string can be 5-character long. A UUID string has typically 36 characters. Let's truncate
    # the file name to 150 characters to be safe.
    file_name = src.name[:150]
    metadata_filename = f"{file_name}.json"
    metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
    metadata_path = metadata_group.get(metadata_filename)
    always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1"
    if not always_compile and metadata_path is not None:
        # cache hit!
        # NOTE(review): this parsed metadata is unused on the hit path —
        # CompiledKernel re-reads the JSON itself; presumably kept as a sanity
        # read, verify before removing.
        metadata = json.loads(Path(metadata_path).read_text())
        return CompiledKernel(src, metadata_group, hash)
    # initialize metadata
    metadata = {
        "hash": hash,
        "target": target,
        **options.__dict__,
        **env_vars,
    }
    # run compilation pipeline and populate metadata
    stages = dict()
    backend.add_stages(stages, options)
    first_stage = list(stages.keys()).index(src.ext)
    # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
    if ir_source:
        first_stage += 1
    context = ir.context()
    ir.load_dialects(context)
    backend.load_dialects(context)
    codegen_fns = backend.get_codegen_implementation()
    module_map = backend.get_module_map()
    try:
        module = src.make_ir(options, codegen_fns, module_map, context)
    except Exception as e:
        # Strip internal frames so the user sees their own code in the error.
        filter_traceback(e)
        raise
    use_ir_loc = os.environ.get("USE_IR_LOC", None)
    for ext, compile_ir in list(stages.items())[first_stage:]:
        next_module = compile_ir(module, metadata)
        ir_filename = f"{file_name}.{ext}"
        # Optionally swap in a user-provided IR file for this stage.
        if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None):
            print(f"\nOverriding kernel with file {full_name}")
            next_module = parse(full_name, ext, context)
        metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
        if fn_dump_manager is not None:
            fn_dump_manager.put(next_module, ir_filename)
        # use an env variable to parse ir from file
        if use_ir_loc == ext:
            ir_full_name = fn_cache_manager.get_file(ir_filename)
            next_module.create_location_snapshot(ir_full_name)
            print(f"Creating new locations for {ir_full_name}")
        module = next_module
    # write-back metadata
    metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
                                                             binary=False)
    fn_cache_manager.put_group(metadata_filename, metadata_group)
    # Compilation completed, disabling multithreading in context.
    # This is needed to safely finalize threads pool inside context: if current process forks before
    # python GC deletes context object, thread pool in child process will be invalid, which could
    # lead to child crash or hang.
    context.disable_multithreading()
    # return handle to compiled kernel
    return CompiledKernel(src, metadata_group, hash)
308
+
309
+
310
def make_backend(target):
    """Instantiate the single registered compiler backend supporting *target*.

    Raises RuntimeError when zero or several backends claim the target, since
    backend selection must be unambiguous.
    """
    actives = [entry.compiler for entry in backends.values() if entry.compiler.supports_target(target)]
    if len(actives) == 1:
        return actives[0](target)
    raise RuntimeError(
        f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.")
316
+
317
+
318
class LazyDict:
    """A dict payload plus deferred producers that are merged in on ``get()``."""

    def __init__(self, data):
        # Eagerly-known entries.
        self.data = data
        # Pending (callable, args) pairs whose results are merged lazily.
        self.extras = []

    def get(self) -> None:
        # Run every pending producer, folding its result into the payload,
        # then drop them so repeated calls do not re-run the callables.
        for producer, producer_args in self.extras:
            self.data = self.data | producer(*producer_args)
        self.extras.clear()
        return self.data

    def add(self, func, args):
        # Defer `func(*args)` until the next get().
        self.extras.append((func, args))
332
+
333
+
334
class AsmDict(dict):
    """Dict of compilation artifacts that disassembles SASS lazily on demand."""

    def __missing__(self, key):
        # Only "sass" can be derived (from the stored cubin); any other
        # missing key is a genuine error.
        if key != "sass":
            raise KeyError("Unknown key: '%s'" % key)
        value = get_sass(self["cubin"])
        # Memoize so disassembly runs at most once.
        self[key] = value
        return value
345
+
346
+
347
class CompiledKernel:
    """Handle to a compiled kernel: metadata, per-stage IR text, and a lazily
    loaded device binary/launcher."""

    # Hooks for external tools to monitor the execution of triton kernels
    # TODO: move out of this namespace since it's a runtime thing
    launch_enter_hook = None
    launch_exit_hook = None

    def __init__(self, src, metadata_group, hash):
        """Rehydrate a kernel from the cache group's metadata JSON and IR files."""
        from collections import namedtuple
        metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
        metadata = json.loads(metadata_path.read_text())
        metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
        # JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
        target = metadata['target']
        metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])
        # Freeze the metadata dict into an immutable attribute-access record.
        KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys())))
        self.metadata = KernelMetadata(**metadata)
        backend = make_backend(self.metadata.target)
        self.packed_metadata = backend.pack_metadata(self.metadata)
        self.src = src
        self.hash = hash
        self.name = self.metadata.name
        # stores the text of each level of IR that was generated during compilation
        asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")]
        binary_ext = backend.binary_ext
        # Keyed by file extension; the backend's binary stage is read as bytes,
        # everything else as text.
        self.asm = AsmDict({
            file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
            for file in asm_files
        })
        self.kernel = self.asm[binary_ext]
        # binaries are lazily initialized
        # because it involves doing runtime things
        # (e.g., checking amount of shared memory on current device)
        self.module = None
        self.function = None

    def _init_handles(self):
        """Load the binary onto the current device and build the launcher (idempotent)."""
        if self.module is not None:
            return
        device = driver.active.get_current_device()
        # create launcher
        self.run = driver.active.launcher_cls(self.src, self.metadata)
        # not enough shared memory to run the kernel
        max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
        if self.metadata.shared > max_shared:
            raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
        # TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
        self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
            self.name, self.kernel, self.metadata.shared, device)

    def __getattribute__(self, name):
        # Intercept access to `run` so the device handles are created on first
        # use instead of at construction time.
        if name == 'run':
            self._init_handles()
        return super().__getattribute__(name)

    def launch_metadata(self, grid, stream, *args):
        """Build the LazyDict handed to launch hooks; None when no hook is set."""
        if CompiledKernel.launch_enter_hook is None:
            return None
        ret = LazyDict({"name": self.name, "function": self.function, "stream": stream})
        if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None:
            return ret
        # Reassemble the full argument mapping: constexpr values come from the
        # source's constants, the rest from the runtime *args in order.
        arg_dict = {}
        arg_idx = 0
        for i, arg_name in enumerate(self.src.fn.arg_names):
            if i in self.src.fn.constexprs:
                arg_dict[arg_name] = self.src.constants[arg_name]
            else:
                arg_dict[arg_name] = args[arg_idx]
                arg_idx += 1
        ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict))
        return ret

    def __getitem__(self, grid):
        """Return a launcher bound to *grid*, i.e. ``kernel[grid](*args)``."""
        self._init_handles()

        def runner(*args, stream=None):
            # Default to the current device's stream when none is given.
            if stream is None:
                device = driver.active.get_current_device()
                stream = driver.active.get_current_stream(device)
            launch_metadata = self.launch_metadata(grid, stream, *args)
            self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata,
                     CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args)

        return runner
@@ -0,0 +1,51 @@
1
+ import ast
2
+ from typing import Optional
3
+ from ..errors import TritonError
4
+
5
+
6
class CompilationError(TritonError):
    """Base class for all errors raised during compilation"""
    # Cap on how many source lines are echoed back in the error message.
    source_line_count_max_in_message = 12

    def _format_message(self) -> str:
        """Render "at line:col" plus a source excerpt with a caret marker."""
        node = self.node
        if self.src is None:
            source_excerpt = " <source unavailable>"
        elif not hasattr(node, 'lineno'):
            # No location info: fall back to showing the whole source.
            source_excerpt = self.src
        else:
            # Show the last few lines up to the offending one, then a caret
            # under the offending column.
            excerpt_lines = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:]
            if excerpt_lines:
                excerpt_lines.append(' ' * node.col_offset + '^')
                source_excerpt = '\n'.join(excerpt_lines)
            else:
                source_excerpt = " <source empty>"

        if hasattr(node, 'lineno'):
            message = "at {}:{}:\n{}".format(node.lineno, node.col_offset, source_excerpt)
        else:
            message = source_excerpt
        if self.error_message:
            message += '\n' + self.error_message
        return message

    def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str] = None):
        self.src = src
        self.node = node
        self.error_message = error_message
        self.message = self._format_message()

    def __str__(self):
        return self.message

    def __reduce__(self):
        # this is necessary to make CompilationError picklable
        return type(self), (self.src, self.node, self.error_message)
43
+
44
+
45
class CompileTimeAssertionFailure(CompilationError):
    """Specific exception for failed tests in `static_assert` invocations"""
    pass
48
+
49
+
50
class UnsupportedLanguageConstruct(CompilationError):
    """Raised for Python constructs the Triton frontend does not support."""
    pass
File without changes
triton/errors.py ADDED
@@ -0,0 +1,5 @@
1
+ """Base class for all errors raised by Triton"""
2
+
3
+
4
+ class TritonError(Exception):
5
+ ...