triton-windows 3.1.0.post17__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +73 -0
- triton/backends/__init__.py +50 -0
- triton/backends/amd/compiler.py +262 -0
- triton/backends/amd/driver.c +211 -0
- triton/backends/amd/driver.py +497 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
- triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
- triton/backends/amd/include/hip/channel_descriptor.h +39 -0
- triton/backends/amd/include/hip/device_functions.h +38 -0
- triton/backends/amd/include/hip/driver_types.h +468 -0
- triton/backends/amd/include/hip/hip_bf16.h +36 -0
- triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
- triton/backends/amd/include/hip/hip_common.h +100 -0
- triton/backends/amd/include/hip/hip_complex.h +38 -0
- triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
- triton/backends/amd/include/hip/hip_deprecated.h +95 -0
- triton/backends/amd/include/hip/hip_ext.h +159 -0
- triton/backends/amd/include/hip/hip_fp16.h +36 -0
- triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
- triton/backends/amd/include/hip/hip_hcc.h +24 -0
- triton/backends/amd/include/hip/hip_math_constants.h +36 -0
- triton/backends/amd/include/hip/hip_profile.h +27 -0
- triton/backends/amd/include/hip/hip_runtime.h +75 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
- triton/backends/amd/include/hip/hip_texture_types.h +29 -0
- triton/backends/amd/include/hip/hip_vector_types.h +41 -0
- triton/backends/amd/include/hip/hip_version.h +17 -0
- triton/backends/amd/include/hip/hiprtc.h +421 -0
- triton/backends/amd/include/hip/library_types.h +78 -0
- triton/backends/amd/include/hip/math_functions.h +42 -0
- triton/backends/amd/include/hip/surface_types.h +63 -0
- triton/backends/amd/include/hip/texture_types.h +194 -0
- triton/backends/amd/include/hsa/Brig.h +1131 -0
- triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
- triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
- triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
- triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
- triton/backends/amd/include/hsa/hsa.h +5729 -0
- triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
- triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
- triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
- triton/backends/amd/include/roctracer/roctracer.h +779 -0
- triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
- triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
- triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
- triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
- triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
- triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
- triton/backends/amd/include/roctracer/roctx.h +229 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +76 -0
- triton/backends/driver.py +34 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +347 -0
- triton/backends/nvidia/driver.c +451 -0
- triton/backends/nvidia/driver.py +430 -0
- triton/backends/nvidia/include/cuda.h +24359 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +4 -0
- triton/compiler/code_generator.py +1302 -0
- triton/compiler/compiler.py +416 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/language/__init__.py +284 -0
- triton/language/core.py +2621 -0
- triton/language/extra/__init__.py +4 -0
- triton/language/extra/cuda/__init__.py +8 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +3 -0
- triton/language/extra/hip/libdevice.py +468 -0
- triton/language/extra/libdevice.py +1213 -0
- triton/language/math.py +250 -0
- triton/language/random.py +207 -0
- triton/language/semantic.py +1621 -0
- triton/language/standard.py +441 -0
- triton/ops/__init__.py +7 -0
- triton/ops/blocksparse/__init__.py +7 -0
- triton/ops/blocksparse/matmul.py +432 -0
- triton/ops/blocksparse/softmax.py +228 -0
- triton/ops/cross_entropy.py +96 -0
- triton/ops/flash_attention.py +466 -0
- triton/ops/matmul.py +219 -0
- triton/ops/matmul_perf_model.py +171 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/autotuner.py +361 -0
- triton/runtime/build.py +129 -0
- triton/runtime/cache.py +289 -0
- triton/runtime/driver.py +60 -0
- triton/runtime/errors.py +26 -0
- triton/runtime/interpreter.py +1127 -0
- triton/runtime/jit.py +956 -0
- triton/runtime/tcc/include/_mingw.h +170 -0
- triton/runtime/tcc/include/assert.h +57 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +57 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/limits.h +111 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +737 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdarg.h +79 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +54 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +580 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +118 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/winbase.h +2951 -0
- triton/runtime/tcc/include/winapi/wincon.h +301 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnt.h +5835 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +496 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.c +67 -0
- triton/tools/compile.h +14 -0
- triton/tools/compile.py +145 -0
- triton/tools/disasm.py +142 -0
- triton/tools/link.py +322 -0
- triton/windows_utils.py +373 -0
- triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
- triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
- triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
- triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
|
@@ -0,0 +1,416 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import hashlib
|
|
3
|
+
import json
|
|
4
|
+
from .._C.libtriton import get_cache_invalidating_env_vars, ir
|
|
5
|
+
from ..backends import backends
|
|
6
|
+
from ..backends.compiler import GPUTarget
|
|
7
|
+
from .. import __version__
|
|
8
|
+
from ..runtime.autotuner import OutOfResources
|
|
9
|
+
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
|
|
10
|
+
from ..runtime.driver import driver
|
|
11
|
+
# TODO: this shouldn't be here
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
from .code_generator import ast_to_ttir
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
import re
|
|
16
|
+
import functools
|
|
17
|
+
import os
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class AttrsDescriptor:
    """Compile-time facts recorded about a kernel's arguments.

    ``divisible_by_16`` and ``equal_to_1`` are sets of arguments flagged with
    the corresponding property (presumably argument indices -- confirm against
    the caller). Passing ``None`` for either field yields a fresh empty set.
    """
    divisible_by_16: set = None
    equal_to_1: set = None

    def __post_init__(self):
        # Replace ``None`` defaults with per-instance empty sets (in field order).
        for attr_name in ("divisible_by_16", "equal_to_1"):
            if getattr(self, attr_name) is None:
                setattr(self, attr_name, set())

    def to_dict(self):
        """Return a JSON-friendly mapping in which every set is turned into a list."""
        return {
            'divisible_by_16': list(self.divisible_by_16),
            'equal_to_1': list(self.equal_to_1),
        }

    @staticmethod
    def from_dict(data):
        """Rebuild a descriptor from a ``to_dict()`` mapping; absent keys become empty sets."""
        divisible = set(data.get('divisible_by_16', []))
        ones = set(data.get('equal_to_1', []))
        return AttrsDescriptor(divisible_by_16=divisible, equal_to_1=ones)

    def hash(self):
        """Return the SHA-256 hex digest of every instance attribute's sorted contents."""
        sorted_fields = [sorted(values) for values in vars(self).values()]
        digest = hashlib.sha256(str(sorted_fields).encode("utf-8"))
        return digest.hexdigest()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Kernel-prototype regexes, grouped below by IR dialect.
#
# mlir_prototype_pattern (textual TTIR / TTGIR), used with re.MULTILINE:
#   - ^\s*tt\.func\s+       : start of a line, optional leading whitespace, the keyword `tt.func`,
#                             then whitespace
#   - (?:public\s+)?        : optional `public` keyword plus whitespace (not captured)
#   - (@\w+)                : group 1, the function name: `@` followed by word characters
#                             (letters, digits, underscores)
#   - (\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))
#                           : group 2, the parenthesised argument list. Each argument's type is
#                             matched by `[\S\s]+`, which also accepts whitespace, so argument
#                             types containing spaces are accepted
#   - (attributes \{[\S\s]+\})? : group 3, an optional function attribute dictionary
#   - \s+\{\s*$             : the `{` that opens the function body, ending the line
mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
# Per-argument type inside group 2: runs of characters other than `,`, whitespace, `<` and `)`,
# possibly interleaved with `<...>` groups (e.g. `!tt.ptr<f32>`). Captures the type only.
mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+),?'

# ptx_prototype_pattern: `.visible`/`.extern` then `.entry`/`.func` NAME ( PARAMS )
#   group 1 = kernel name, group 2 = raw parameter list between the parentheses
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
# PTX parameter types, e.g. `.param .u64 arg0` -> "u64".
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"

# IR file extension -> regex locating the kernel prototype.
prototype_pattern = dict(
    ttir=mlir_prototype_pattern,
    ttgir=mlir_prototype_pattern,
    ptx=ptx_prototype_pattern,
)
# IR file extension -> regex extracting argument types from the prototype's argument list.
arg_type_pattern = dict(
    ttir=mlir_arg_type_pattern,
    ttgir=mlir_arg_type_pattern,
    ptx=ptx_arg_type_pattern,
)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def convert_type_repr(x):
    """Convert an MLIR argument type string into Triton's signature notation.

    ``!tt.ptr<T>`` and ``!tt.ptr<T, addrspace>`` become ``'*' + convert_type_repr(T)``.
    The conversion is recursive, so a pointer-to-pointer gains one ``*`` per level.
    Any other type string is returned unchanged.

    Currently we only capture the pointer type and assume the pointer is on global
    memory: the address space is discarded.
    TODO: Capture and support shared memory space
    """
    # A greedy `.+` runs to the final `>`, which keeps nested angle brackets in the
    # pointee balanced (e.g. `!tt.ptr<!tt.ptr<f32>>`). The previous `[^,]+` also
    # swallowed the pointer's own closing `>`, turning `!tt.ptr<f32>` into `*f32>`.
    match = re.search(r'!tt\.ptr<(.+)>', x)
    if match is None:
        return x
    # Drop a trailing `, <address space>` suffix, if present.
    pointee = re.sub(r',\s*\d+\s*$', '', match.group(1))
    return '*' + convert_type_repr(pointee)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _get_num_warps_from_ir_str(src: str):
|
|
79
|
+
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
|
|
80
|
+
# TODO(jlebar): Using a regex to get num-warps is a hack, and will break if
|
|
81
|
+
# e.g. someone has an instruction (not module) attribute named "num-warps".
|
|
82
|
+
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
|
|
83
|
+
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
|
|
84
|
+
num_warps = int(num_warps_matches[0])
|
|
85
|
+
return num_warps
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class ASTSource:
    """A kernel to be compiled from a Python function's AST (via ``ast_to_ttir``).

    ``signature`` maps argument position to type string. A comma-separated
    string is also accepted and is split into ``{index: stripped_type}``.
    ``constants`` defaults to an empty dict and ``attrs`` to an empty
    ``AttrsDescriptor``.
    """

    def __init__(self, fn, signature, constants=None, attrs=None) -> None:
        self.fn = fn
        self.ext = "ttir"
        self.name = fn.__name__
        if isinstance(signature, str):
            signature = {idx: ty.strip() for idx, ty in enumerate(signature.split(","))}
        self.signature = signature
        self.constants = dict() if constants is None else constants
        self.attrs = AttrsDescriptor() if attrs is None else attrs

    def hash(self):
        """SHA-256 hex digest of the function's cache key, attrs, signature and constants."""
        sig_values = [ty for _, ty in sorted(self.signature.items())]
        # Keys are stringified so that constants with a mix of int and str keys can still be sorted.
        const_items = sorted((str(name), value) for name, value in self.constants.items())
        key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sig_values}-{const_items}"
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    def make_ir(self, options, codegen_fns, context):
        """Lower the Python function to a TTIR module."""
        return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)

    def parse_options(self):
        """AST sources contribute no extra backend options."""
        return dict()
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class IRSource:
    """A kernel to be compiled from an IR file on disk.

    The file extension (``ttir``, ``ttgir`` or ``ptx``) selects the regexes in
    ``prototype_pattern`` / ``arg_type_pattern``. Those regexes recover the
    kernel name and argument types from the prototype.

    Raises:
        ValueError: if the extension is unsupported or no prototype is found.
    """

    def __init__(self, path):
        self.path = path
        path = Path(path)
        self.ext = path.suffix[1:]
        # IR text is read as UTF-8 explicitly. Without `encoding`, read_text uses the
        # locale encoding, which is not UTF-8 by default on Windows.
        self.src = path.read_text(encoding="utf-8")
        if self.ext not in prototype_pattern:
            raise ValueError(f"unsupported IR file extension '.{self.ext}' for {self.path} "
                             f"(expected one of: {', '.join(prototype_pattern)})")
        match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
        if match is None:
            # Without this check the failure surfaces as an opaque AttributeError on `None.group`.
            raise ValueError(f"could not find a kernel prototype in {self.path}")
        self.name = match.group(1)
        signature = match.group(2)
        types = re.findall(arg_type_pattern[self.ext], signature)
        self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}

    def hash(self):
        """SHA-256 hex digest of the raw IR text."""
        return hashlib.sha256(self.src.encode("utf-8")).hexdigest()

    def make_ir(self, options, codegen_fns, context):
        """Parse the file into an MLIR module bound to ``context``."""
        module = ir.parse_mlir_module(self.path, context)
        module.context = context
        return module

    def parse_options(self):
        """Return options embedded in the IR: ``num_warps`` for TTGIR, nothing otherwise."""
        if self.ext == "ttgir":
            return {'num_warps': _get_num_warps_from_ir_str(self.src)}
        return dict()
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@functools.lru_cache()
def triton_key():
    """Return a string fingerprinting this Triton installation (memoised).

    The key is ``__version__`` followed by '-'-joined SHA-256 digests of, in order:
    this file, every module ``pkgutil.walk_packages`` finds under ``compiler`` and
    ``backends``, the native ``_C/libtriton`` extension, and every module
    ``pkgutil.iter_modules`` finds under ``language``.
    """
    import pkgutil

    triton_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    def digest_of(file_path):
        with open(file_path, "rb") as fh:
            return hashlib.sha256(fh.read()).hexdigest()

    def module_origin(info):
        return info.module_finder.find_spec(info.name).origin

    # frontend: this file
    digests = [digest_of(__file__)]
    # compiler + backends (walked recursively; each file hashed as it is discovered)
    walked_roots = (
        (os.path.join(triton_root, "compiler"), "triton.compiler."),
        (os.path.join(triton_root, "backends"), "triton.backends."),
    )
    for sub_dir, prefix in walked_roots:
        digests.extend(digest_of(module_origin(info)) for info in pkgutil.walk_packages([sub_dir], prefix=prefix))
    # backend: the native extension, hashed in 1 MiB chunks
    native_name = "libtriton.pyd" if os.name == "nt" else "libtriton.so"
    native_hash = hashlib.sha256()
    with open(os.path.join(triton_root, f"_C/{native_name}"), "rb") as fh:
        for chunk in iter(lambda: fh.read(1024**2), b""):
            native_hash.update(chunk)
    digests.append(native_hash.hexdigest())
    # language: top-level modules only
    language_dir = os.path.join(triton_root, 'language')
    digests.extend(digest_of(module_origin(info)) for info in pkgutil.iter_modules([language_dir]))
    return f'{__version__}' + '-'.join(digests)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def parse(full_name, ext, context):
    """Load a file that overrides one compilation stage's output.

    Args:
        full_name: path of the override file.
        ext: the stage name (file extension) being overridden.
        context: MLIR context used for ``ttir``/``ttgir`` files.

    Returns:
        An MLIR module for ``ttir``/``ttgir``; ``str`` for the textual stages
        ``llir``, ``ptx`` and ``amdgcn``; ``bytes`` for the binary stages
        ``cubin`` and ``hsaco``.

    Raises:
        ValueError: if ``ext`` is not a stage this function knows how to load.
    """
    if ext in ("ttir", "ttgir"):
        module = ir.parse_mlir_module(full_name, context)
        module.context = context
        return module
    if ext in ("llir", "ptx", "amdgcn"):
        return Path(full_name).read_text()
    if ext in ("cubin", "hsaco"):
        return Path(full_name).read_bytes()
    # Falling through used to return None, which then got fed into the next stage.
    raise ValueError(f"cannot parse override file '{full_name}': unsupported extension '{ext}'")
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def filter_traceback(e: BaseException):
    """
    Removes code_generator.py and related files from tracebacks.

    These are uninteresting to the user -- "just show me *my* code!"

    The exception's ``__cause__`` and ``__context__`` are filtered recursively.
    Filenames are compared after converting backslashes to ``/``, so frames are
    also recognised when paths use Windows separators.
    """
    if e.__cause__ is not None:
        filter_traceback(e.__cause__)
    if e.__context__ is not None:
        filter_traceback(e.__context__)

    # If a user has a file that matches one of these, they're out of luck.
    BAD_FILES = [
        "/triton/compiler/code_generator.py",
        "/ast.py",
    ]

    def is_bad(tb):
        # co_filename uses the native separator ("\\" on Windows), which would never
        # match the "/"-separated suffixes above without normalisation.
        filename = tb.tb_frame.f_code.co_filename.replace("\\", "/")
        return any(filename.endswith(suffix) for suffix in BAD_FILES)

    tb = e.__traceback__
    frames = []
    while tb is not None:
        if not is_bad(tb):
            frames.append(tb)
        tb = tb.tb_next

    # Relink the surviving traceback entries into a single chain.
    for (cur_frame, next_frame) in zip(frames, frames[1:]):
        cur_frame.tb_next = next_frame

    if not frames:
        e.__traceback__ = None
    else:
        frames[-1].tb_next = None
        e.__traceback__ = frames[0]
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def compile(src, target=None, options=None):
    """Compile ``src`` for ``target`` and return a ``CompiledKernel``.

    Args:
        src: an ``ASTSource``, or a path (``str``) to an IR file, which is wrapped
            in an ``IRSource``.
        target: the ``GPUTarget`` to compile for. Defaults to
            ``driver.active.get_current_target()``.
        options: backend option overrides (a mapping). They are merged with any
            options the source declares (e.g. ``num_warps`` from a ``.ttgir`` file)
            and the merged mapping is passed to ``backend.parse_options``.

    Results are cached under a key built from ``triton_key()``, the source hash,
    the backend hash, the options hash and the cache-invalidating environment
    variables. Environment switches:
        TRITON_ALWAYS_COMPILE=1   ignore cache hits.
        TRITON_KERNEL_DUMP=1      dump every stage via the dump manager.
        TRITON_KERNEL_OVERRIDE=1  replace stages with files from the override manager.
        USE_TTGIR_LOC=1           snapshot TTGIR locations from the cached file.
    """
    if target is None:
        target = driver.active.get_current_target()
    assert isinstance(target, GPUTarget), "target must be of GPUTarget type"
    # create backend
    backend = make_backend(target)
    ir_source = not isinstance(src, ASTSource)
    if ir_source:
        assert isinstance(src, str), "source must be either AST or a filepath"
        src = IRSource(src)
    extra_options = src.parse_options()
    options = backend.parse_options(dict(options or dict(), **extra_options))
    # create cache manager
    env_vars = get_cache_invalidating_env_vars()
    key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}"
    hash = hashlib.sha256(key.encode("utf-8")).hexdigest()
    fn_cache_manager = get_cache_manager(hash)
    # For dumping/overriding only hash the source as we want it to be independent of triton
    # core changes to make it easier to track kernels by hash.
    enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
    enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1"
    fn_override_manager = get_override_manager(src.hash()) if enable_override else None
    fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None
    metadata_filename = f"{src.name}.json"
    metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
    metadata_path = metadata_group.get(metadata_filename)
    always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1"
    if not always_compile and metadata_path is not None:
        # cache hit! CompiledKernel reads and parses the metadata file itself,
        # so loading it here as well would only be redundant I/O.
        return CompiledKernel(src, metadata_group, hash)
    # initialize metadata
    metadata = {
        "hash": hash,
        "target": target,
        **options.__dict__,
        **env_vars,
    }
    # run compilation pipeline and populate metadata
    stages = dict()
    backend.add_stages(stages, options)
    first_stage = list(stages.keys()).index(src.ext)
    # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
    if ir_source:
        first_stage += 1
    context = ir.context()
    ir.load_dialects(context)
    backend.load_dialects(context)
    codegen_fns = backend.get_codegen_implementation()
    try:
        module = src.make_ir(options, codegen_fns, context)
    except Exception as e:
        filter_traceback(e)
        raise
    use_ttgir_loc = os.environ.get("USE_TTGIR_LOC", "0") == "1"
    for ext, compile_ir in list(stages.items())[first_stage:]:
        next_module = compile_ir(module, metadata)
        ir_filename = f"{src.name}.{ext}"
        metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
        if fn_dump_manager is not None:
            fn_dump_manager.put(next_module, ir_filename)
        if (fn_override_manager is not None and fn_override_manager.has_file(ir_filename)):
            print(f"\nOverriding kernel with file {ir_filename}")
            full_name = fn_override_manager.get_file(ir_filename)
            next_module = parse(full_name, ext, context)
        # use an env variable to parse ttgir from file
        if use_ttgir_loc and ext == "ttgir":
            ttgir_full_name = fn_cache_manager.get_file(ir_filename)
            next_module.create_location_snapshot(ttgir_full_name)
            print(f"Create new locations for {ttgir_full_name}")
        module = next_module
    # write-back metadata
    metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
                                                             binary=False)
    fn_cache_manager.put_group(metadata_filename, metadata_group)
    # return handle to compiled kernel
    return CompiledKernel(src, metadata_group, hash)
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def make_backend(target):
    """Instantiate the one registered backend compiler whose ``supports_target`` accepts ``target``.

    Raises:
        RuntimeError: if no backend, or more than one, supports the target.
    """
    candidates = []
    for registered in backends.values():
        if registered.compiler.supports_target(target):
            candidates.append(registered.compiler)
    if len(candidates) == 1:
        return candidates[0](target)
    raise RuntimeError(
        f"{len(candidates)} compatible backends for target ({target.backend}) ({candidates}). There should only be one.")
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
class LazyDict:
    """A mapping whose extra entries are computed only when ``get`` is called.

    ``add(func, args)`` queues ``func(*args)``. ``get`` calls every queued function
    in insertion order and merges each result into ``data`` with ``|`` (later keys
    win). It then clears the queue and returns the merged data.
    """

    def __init__(self, data):
        self.data = data
        self.extras = []

    def get(self) -> dict:
        # Annotation fixed: this returns the merged data, not None.
        for func, args in self.extras:
            self.data = self.data | func(*args)
        self.extras.clear()
        return self.data

    def add(self, func, args):
        """Queue ``func(*args)`` to be merged into the data on the next ``get``."""
        self.extras.append((func, args))
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
class CompiledKernel:
    """Handle to a compiled kernel whose artifacts live in the cache.

    The metadata JSON and every generated IR/binary are read from the paths in
    ``metadata_group``. Driver handles (module, function, launcher) are created
    lazily by ``_init_handles`` on first launch or first access to ``run``.
    """

    # Hooks for external tools to monitor the execution of triton kernels
    # TODO: move out of this namespace since it's a runtime thing
    launch_enter_hook = None
    launch_exit_hook = None

    def __init__(self, src, metadata_group, hash):
        from collections import namedtuple
        metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
        metadata = json.loads(metadata_path.read_text())
        metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
        # JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
        target = metadata['target']
        metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])
        KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys())))
        self.metadata = KernelMetadata(**metadata)
        backend = make_backend(self.metadata.target)
        self.packed_metadata = backend.pack_metadata(self.metadata)
        self.src = src
        self.hash = hash
        self.name = self.metadata.name
        # stores the text of each level of IR that was generated during compilation
        asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")]
        binary_ext = backend.binary_ext
        self.asm = {
            file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
            for file in asm_files
        }
        self.kernel = self.asm[binary_ext]
        # binaries are lazily initialized
        # because it involves doing runtime things
        # (e.g., checking amount of shared memory on current device)
        self.module = None
        self.function = None

    def _init_handles(self):
        """Create the launcher and load the binary on the current device (once).

        Raises:
            OutOfResources: if the kernel needs more shared memory than the device offers.
        """
        if self.module is not None:
            return
        device = driver.active.get_current_device()
        # create launcher
        self.run = driver.active.launcher_cls(self.src, self.metadata)
        # not enough shared memory to run the kernel
        max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
        if self.metadata.shared > max_shared:
            raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
        # TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
        self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
            self.name, self.kernel, self.metadata.shared, device)

    def __getattribute__(self, name):
        # Accessing `run` first makes sure the launcher/binary are initialised.
        if name == 'run':
            self._init_handles()
        return super().__getattribute__(name)

    def launch_metadata(self, grid, stream, *args):
        """Build the LazyDict handed to launch hooks, or None when no enter hook is set."""
        if CompiledKernel.launch_enter_hook is None:
            return None
        ret = LazyDict({"name": self.name, "function": self.function, "stream": stream})
        if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None:
            return ret
        arg_dict = {}
        arg_idx = 0
        for i, arg_name in enumerate(self.src.fn.arg_names):
            if i in self.src.fn.constexprs:
                arg_dict[arg_name] = self.src.constants[arg_name]
            else:
                arg_dict[arg_name] = args[arg_idx]
                arg_idx += 1
        ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict))
        return ret

    def __getitem__(self, grid):
        """Return a launcher ``runner(*args, stream=None)`` for ``grid``.

        ``grid`` may have fewer than three dimensions; missing trailing dimensions
        are launched as 1. Only the first three entries are used.
        """
        self._init_handles()

        def runner(*args, stream=None):
            if stream is None:
                device = driver.active.get_current_device()
                stream = driver.active.get_current_stream(device)
            # Hooks see the grid exactly as the caller passed it.
            launch_metadata = self.launch_metadata(grid, stream, *args)
            # Pad with 1s so 1-D / 2-D grids no longer raise IndexError; truncate to 3 dims.
            grid_0, grid_1, grid_2 = (tuple(grid) + (1, 1, 1))[:3]
            self.run(grid_0, grid_1, grid_2, stream, self.function, self.packed_metadata, launch_metadata,
                     CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args)

        return runner
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from ..errors import TritonError
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CompilationError(TritonError):
    """Base class for all errors raised during compilation.

    The message shows up to ``source_line_count_max_in_message`` source lines ending
    at the offending AST node's line, followed by a caret under its column. The
    optional ``error_message`` is appended after that excerpt.
    """
    source_line_count_max_in_message = 12

    def _format_message(self) -> str:
        node = self.node
        has_location = hasattr(node, 'lineno')
        if self.src is None:
            excerpt = " <source unavailable>"
        elif not has_location:
            excerpt = self.src
        else:
            lines = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:]
            if lines:
                # caret marks the node's column on its own line
                lines.append(' ' * node.col_offset + '^')
                excerpt = '\n'.join(lines)
            else:
                excerpt = " <source empty>"

        if has_location:
            message = "at {}:{}:\n{}".format(node.lineno, node.col_offset, excerpt)
        else:
            message = excerpt
        if self.error_message:
            message += '\n' + self.error_message
        return message

    def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str] = None):
        self.src = src
        self.node = node
        self.error_message = error_message
        self.message = self._format_message()

    def __str__(self):
        return self.message

    def __reduce__(self):
        # Rebuild from the constructor arguments so the exception can be pickled.
        return type(self), (self.src, self.node, self.error_message)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class CompileTimeAssertionFailure(CompilationError):
    """Specific exception for failed tests in `static_assert` invocations.

    Message formatting and pickling are inherited unchanged from CompilationError.
    """
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class UnsupportedLanguageConstruct(CompilationError):
    """CompilationError subclass for source constructs the compiler rejects as unsupported."""
|
|
File without changes
|