triton-windows 3.1.0.post17__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +73 -0
- triton/backends/__init__.py +50 -0
- triton/backends/amd/compiler.py +262 -0
- triton/backends/amd/driver.c +211 -0
- triton/backends/amd/driver.py +497 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
- triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
- triton/backends/amd/include/hip/channel_descriptor.h +39 -0
- triton/backends/amd/include/hip/device_functions.h +38 -0
- triton/backends/amd/include/hip/driver_types.h +468 -0
- triton/backends/amd/include/hip/hip_bf16.h +36 -0
- triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
- triton/backends/amd/include/hip/hip_common.h +100 -0
- triton/backends/amd/include/hip/hip_complex.h +38 -0
- triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
- triton/backends/amd/include/hip/hip_deprecated.h +95 -0
- triton/backends/amd/include/hip/hip_ext.h +159 -0
- triton/backends/amd/include/hip/hip_fp16.h +36 -0
- triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
- triton/backends/amd/include/hip/hip_hcc.h +24 -0
- triton/backends/amd/include/hip/hip_math_constants.h +36 -0
- triton/backends/amd/include/hip/hip_profile.h +27 -0
- triton/backends/amd/include/hip/hip_runtime.h +75 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
- triton/backends/amd/include/hip/hip_texture_types.h +29 -0
- triton/backends/amd/include/hip/hip_vector_types.h +41 -0
- triton/backends/amd/include/hip/hip_version.h +17 -0
- triton/backends/amd/include/hip/hiprtc.h +421 -0
- triton/backends/amd/include/hip/library_types.h +78 -0
- triton/backends/amd/include/hip/math_functions.h +42 -0
- triton/backends/amd/include/hip/surface_types.h +63 -0
- triton/backends/amd/include/hip/texture_types.h +194 -0
- triton/backends/amd/include/hsa/Brig.h +1131 -0
- triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
- triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
- triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
- triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
- triton/backends/amd/include/hsa/hsa.h +5729 -0
- triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
- triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
- triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
- triton/backends/amd/include/roctracer/roctracer.h +779 -0
- triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
- triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
- triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
- triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
- triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
- triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
- triton/backends/amd/include/roctracer/roctx.h +229 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +76 -0
- triton/backends/driver.py +34 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +347 -0
- triton/backends/nvidia/driver.c +451 -0
- triton/backends/nvidia/driver.py +430 -0
- triton/backends/nvidia/include/cuda.h +24359 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +4 -0
- triton/compiler/code_generator.py +1302 -0
- triton/compiler/compiler.py +416 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/language/__init__.py +284 -0
- triton/language/core.py +2621 -0
- triton/language/extra/__init__.py +4 -0
- triton/language/extra/cuda/__init__.py +8 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +3 -0
- triton/language/extra/hip/libdevice.py +468 -0
- triton/language/extra/libdevice.py +1213 -0
- triton/language/math.py +250 -0
- triton/language/random.py +207 -0
- triton/language/semantic.py +1621 -0
- triton/language/standard.py +441 -0
- triton/ops/__init__.py +7 -0
- triton/ops/blocksparse/__init__.py +7 -0
- triton/ops/blocksparse/matmul.py +432 -0
- triton/ops/blocksparse/softmax.py +228 -0
- triton/ops/cross_entropy.py +96 -0
- triton/ops/flash_attention.py +466 -0
- triton/ops/matmul.py +219 -0
- triton/ops/matmul_perf_model.py +171 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/autotuner.py +361 -0
- triton/runtime/build.py +129 -0
- triton/runtime/cache.py +289 -0
- triton/runtime/driver.py +60 -0
- triton/runtime/errors.py +26 -0
- triton/runtime/interpreter.py +1127 -0
- triton/runtime/jit.py +956 -0
- triton/runtime/tcc/include/_mingw.h +170 -0
- triton/runtime/tcc/include/assert.h +57 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +57 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/limits.h +111 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +737 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdarg.h +79 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +54 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +580 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +118 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/winbase.h +2951 -0
- triton/runtime/tcc/include/winapi/wincon.h +301 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnt.h +5835 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +496 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.c +67 -0
- triton/tools/compile.h +14 -0
- triton/tools/compile.py +145 -0
- triton/tools/disasm.py +142 -0
- triton/tools/link.py +322 -0
- triton/windows_utils.py +373 -0
- triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
- triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
- triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
- triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import subprocess
|
|
4
|
+
|
|
5
|
+
from abc import ABCMeta, abstractmethod, abstractclassmethod
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Union
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(frozen=True)
class GPUTarget:
    """Immutable description of the device a kernel is compiled for.

    Attributes:
        backend: target backend name, e.g. ``"cuda"`` or ``"hip"``.
        arch: target architecture -- an int for CUDA (compute capability,
            e.g. ``90``) or a string for HIP (e.g. ``"gfx940"``).
        warp_size: warp size of the target (presumably threads per
            warp/wavefront -- confirm against backend usage).
    """

    backend: str
    arch: Union[int, str]
    warp_size: int
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BaseBackend(metaclass=ABCMeta):
    """Abstract interface implemented by every Triton compiler backend.

    A backend is bound to one ``GPUTarget``; it converts option dictionaries
    into backend-specific option objects and registers the ordered
    compilation stages.
    """

    def __init__(self, target: GPUTarget) -> None:
        self.target = target
        # Each subclass decides which targets it accepts.
        assert self.supports_target(target)

    @staticmethod
    def _path_to_binary(binary: str):
        """Locate ``binary`` and return ``(path, "major.minor")``.

        Candidates, in order: the ``TRITON_<BINARY>_PATH`` environment
        variable, then ``third_party/cuda/bin/<binary>`` under the parent of
        this file's directory. A candidate is accepted when it is a regular
        file whose ``--version`` output contains ``release X.Y``.

        Raises:
            RuntimeError: if no candidate matches.
        """
        base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
        paths = [
            os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
            os.path.join(base_dir, "third_party", "cuda", "bin", binary),
        ]
        for p in paths:
            # Try the candidate verbatim first so that paths containing spaces
            # (e.g. under "Program Files") work; otherwise keep the legacy
            # behaviour of using only the first space-separated token.
            exe = p if os.path.isfile(p) else p.split(" ")[0]
            # isfile() is False for "" (unset env var) and for missing paths.
            if not os.path.isfile(exe):
                continue
            result = subprocess.check_output([exe, "--version"], stderr=subprocess.STDOUT)
            version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
            if version is not None:
                return p, version.group(1)
        raise RuntimeError(f"Cannot find {binary}")

    # `abstractclassmethod` is deprecated; stacking classmethod over
    # abstractmethod is the documented replacement. Subclasses may override
    # this with a plain @staticmethod taking only `target`.
    @classmethod
    @abstractmethod
    def supports_target(cls, target: GPUTarget):
        """Return whether this backend can compile for ``target``."""
        raise NotImplementedError

    @abstractmethod
    def hash(self) -> str:
        """Returns a unique identifier for this backend"""
        raise NotImplementedError

    @abstractmethod
    def parse_options(self, options: dict) -> object:
        """
        Converts an `options` dictionary into an arbitrary object and returns it.
        This function may contain target-specific heuristics and check the legality of the provided options
        """
        raise NotImplementedError

    @abstractmethod
    def add_stages(self, stages: dict, options: object) -> None:
        """
        Populates `stages` dictionary with entries of the form:
        ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes]
        The value of each entry may populate a `metadata` dictionary.
        Stages will be run sequentially (in insertion order) and can communicate using `metadata`.
        All stages are expected to return a `str` object, except for the last stage which returns
        a `bytes` object for execution by the launcher.
        """
        raise NotImplementedError

    @abstractmethod
    def load_dialects(self, context):
        """
        Load additional MLIR dialects into the provided `context`
        """
        raise NotImplementedError
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from abc import ABCMeta, abstractmethod, abstractclassmethod
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class DriverBase(metaclass=ABCMeta):
    """Abstract base class for Triton runtime drivers."""

    # `abstractclassmethod` is deprecated; classmethod stacked over
    # abstractmethod is the documented replacement. Subclasses may override
    # this with a @staticmethod.
    @classmethod
    @abstractmethod
    def is_active(cls):
        """Report whether this driver is usable (implemented by subclasses)."""
        pass

    @abstractmethod
    def get_current_target(self):
        """Return the compilation target for the current device (presumably a GPUTarget)."""
        pass

    def __init__(self) -> None:
        pass
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class GPUDriver(DriverBase):
    """Driver base whose device and stream queries are delegated to ``torch.cuda``."""

    def __init__(self):
        # TODO: support other frameworks than torch
        import torch

        cuda = torch.cuda
        self.get_device_capability = cuda.get_device_capability
        try:
            from torch._C import _cuda_getCurrentRawStream as raw_stream_getter
        except ImportError:
            # Private fast-path accessor unavailable: use the public stream API.
            def raw_stream_getter(idx):
                return torch.cuda.current_stream(idx).cuda_stream
        self.get_current_stream = raw_stream_getter
        self.get_current_device = cuda.current_device
        self.set_current_device = cuda.set_device

    # TODO: remove once TMA is cleaned up
    def assemble_tensormap_to_arg(self, tensormaps_info, args):
        """Return ``args`` unchanged; ``tensormaps_info`` is ignored."""
        return args
|
|
File without changes
|
|
Binary file
|
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
from triton.backends.compiler import BaseBackend, GPUTarget
|
|
2
|
+
from triton._C.libtriton import ir, passes, llvm, nvidia
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
import functools
|
|
6
|
+
from typing import Any, Tuple, Optional
|
|
7
|
+
import hashlib
|
|
8
|
+
import re
|
|
9
|
+
import tempfile
|
|
10
|
+
import signal
|
|
11
|
+
import os
|
|
12
|
+
import subprocess
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@functools.lru_cache()
def _path_to_binary(binary: str):
    """Locate a CUDA tool such as ``ptxas`` and return ``(path, version)``.

    Candidates are tried in order:

    1. the ``TRITON_<BINARY>_PATH`` environment variable;
    2. ``bin/<binary>`` next to this file;
    3. on Windows only, ``<binary>`` in the CUDA bin directory reported by
       ``triton.windows_utils.find_cuda()``.

    On Windows ``.exe`` is appended to ``binary`` before candidates 2 and 3
    are built. The first candidate that is a regular file and whose
    ``--version`` output contains ``release X.Y`` wins; ``version`` is that
    ``"X.Y"`` string. Results are memoised per ``binary``, so later changes
    to the environment variable are not observed.

    Raises:
        RuntimeError: if no candidate is found.
        subprocess.CalledProcessError: if a candidate exits non-zero on ``--version``.
    """
    paths = [os.environ.get(f"TRITON_{binary.upper()}_PATH", "")]
    if os.name == "nt":
        binary += ".exe"
    paths.append(os.path.join(os.path.dirname(__file__), "bin", binary))
    if os.name == "nt":
        from triton.windows_utils import find_cuda

        cuda_bin_path, _, _ = find_cuda()
        if cuda_bin_path:
            paths.append(os.path.join(cuda_bin_path, binary))

    for exe_path in paths:
        # isfile() is False for "" (unset env var) and for missing paths.
        if not os.path.isfile(exe_path):
            continue
        output = subprocess.check_output([exe_path, "--version"], stderr=subprocess.STDOUT)
        # errors="replace": a stray non-UTF-8 byte must not abort discovery.
        version = re.search(r".*release (\d+\.\d+).*", output.decode("utf-8", errors="replace"), flags=re.MULTILINE)
        if version is not None:
            return exe_path, version.group(1)
    raise RuntimeError(f"Cannot find {binary}")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@functools.lru_cache()
def get_ptxas_version():
    """Return the full ``--version`` banner of the resolved ``ptxas`` as text.

    The result is cached after the first successful call.
    """
    ptxas_path = _path_to_binary("ptxas")[0]
    raw = subprocess.check_output([ptxas_path, "--version"])
    return raw.decode("utf-8")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@functools.lru_cache()
def ptx_get_version(cuda_version) -> int:
    '''
    Map a CUDA toolkit version string (as extracted from ``ptxas --version``,
    e.g. ``"12.4"``) to the highest PTX ISA version it supports, encoded as
    ``major * 10 + minor`` (``"12.4"`` -> ``84``, i.e. PTX 8.4).

    Components after ``major.minor`` (e.g. the ``.1`` in ``"12.4.1"``) are
    ignored.

    Raises:
        RuntimeError: for CUDA major versions other than 10, 11 or 12.
    '''
    assert isinstance(cuda_version, str)
    # Only major.minor matter; tolerate a trailing patch component.
    major, minor = map(int, cuda_version.split('.')[:2])
    if major == 12:
        # From 12.6 on the offset drops by one, so 12.5 and 12.6 both map to 85.
        return 80 + minor if minor < 6 else 79 + minor
    if major == 11:
        return 70 + minor
    if major == 10:
        return 63 + minor
    raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@functools.lru_cache(None)
def file_hash(path):
    """Return the hex SHA-256 digest of the file at ``path``.

    Memoised per path with an unbounded cache, so later modifications of the
    file within the same process are not reflected.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        # Stream in 64 KiB chunks; the digest equals hashing the whole file at once.
        for chunk in iter(lambda: stream.read(1 << 16), b""):
            digest.update(chunk)
    return digest.hexdigest()
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# The file may be accessed in parallel
def try_remove(path):
    """Best-effort deletion of ``path``.

    Because the file may be accessed in parallel (another party may delete it
    first), a missing file is silently ignored. Any other ``OSError`` is
    reported by printing the traceback rather than raising.
    """
    try:
        os.remove(path)
    except FileNotFoundError:
        # Already gone. Checking os.path.exists() first would race with a
        # concurrent remover and print a spurious traceback.
        pass
    except OSError:
        import traceback
        traceback.print_exc()
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@dataclass(frozen=True)
class CUDAOptions:
    """Immutable option set for the CUDA backend.

    ``__post_init__`` normalises ``extern_libs`` (adding a ``libdevice`` entry
    when none is given) and validates ``num_warps``. :meth:`hash` produces a
    digest of all option values.
    """

    num_warps: int = 4
    num_ctas: int = 1
    num_stages: int = 3
    # maxnreg corresponds to the ptx parameter .maxnreg, which controls the
    # maximum number of 32-bit registers used by one thread.
    maxnreg: Optional[int] = None
    cluster_dims: tuple = (1, 1, 1)
    # PTX ISA version encoded as major*10 + minor (e.g. 84); when None the
    # backend derives it from the installed ptxas.
    ptx_version: Optional[int] = None
    enable_fp_fusion: bool = True
    allow_fp8e4nv: bool = False
    allow_fp8e4b15: bool = False
    default_dot_input_precision: str = "tf32"
    allowed_dot_input_precisions: Tuple[str, ...] = ("tf32", "tf32x3", "ieee")
    # CUDABackend.parse_options stores an int here (2**30 or 0), not a bool.
    max_num_imprecise_acc_default: Optional[int] = None
    # Accepted as a {name: path} mapping (or iterable of pairs); __post_init__
    # replaces it with a tuple of (name, path) pairs -- a tuple, unlike a
    # dict, is hashable.
    extern_libs: Optional[dict] = None
    debug: bool = False
    backend_name: str = 'cuda'

    def __post_init__(self):
        default_libdir = Path(__file__).parent / 'lib'
        extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
        # Use the bundled libdevice (or TRITON_LIBDEVICE_PATH) unless the
        # caller supplied a non-empty path.
        if not extern_libs.get('libdevice', None):
            extern_libs['libdevice'] = os.getenv("TRITON_LIBDEVICE_PATH", str(default_libdir / 'libdevice.10.bc'))
        # Frozen dataclass: bypass the generated __setattr__ to store the normalised value.
        object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
        assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
            "num_warps must be a power of 2"

    def hash(self):
        """Return a SHA-256 hex digest over all option values.

        Library paths in ``extern_libs`` are replaced by a digest of the file
        contents (via ``file_hash``, which memoises per path), so the result
        depends on the libraries' bytes, not just their paths.
        """
        hash_dict = dict(self.__dict__)
        hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"]))
        key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())])
        return hashlib.sha256(key.encode("utf-8")).hexdigest()
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class CUDABackend(BaseBackend):
    """Compiler backend that lowers Triton IR to CUDA ``cubin`` binaries.

    Stages registered by :meth:`add_stages`: ttir -> ttgir -> llir -> ptx -> cubin.
    """

    @staticmethod
    def supports_target(target: GPUTarget):
        return target.backend == 'cuda'

    def __init__(self, target: GPUTarget) -> None:
        super().__init__(target)
        # For CUDA the target arch is the integer compute capability, e.g. 90.
        self.capability = target.arch
        assert isinstance(self.capability, int)
        self.binary_ext = "cubin"

    def parse_options(self, opts) -> Any:
        """Build ``CUDAOptions`` from ``opts``, ignoring keys that are not option fields.

        The fp8 flags and ``max_num_imprecise_acc_default`` are always derived
        from the compute capability, overriding any value in ``opts``.
        """
        args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts}
        args["allow_fp8e4nv"] = self.capability >= 89
        args["allow_fp8e4b15"] = self.capability < 90
        args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0
        return CUDAOptions(**args)

    def pack_metadata(self, metadata):
        """Flatten warps, CTAs, shared-memory size and cluster dims into a tuple."""
        return (
            metadata.num_warps,
            metadata.num_ctas,
            metadata.shared,
            metadata.cluster_dims[0],
            metadata.cluster_dims[1],
            metadata.cluster_dims[2],
        )

    def get_codegen_implementation(self):
        import triton.language.extra.cuda as cuda
        codegen_fns = {
            "convert_custom_types":
            cuda.convert_custom_float8_sm80 if self.capability >= 80 else cuda.convert_custom_float8_sm70
        }
        return codegen_fns

    def load_dialects(self, ctx):
        nvidia.load_dialects(ctx)

    @staticmethod
    def make_ttir(mod, metadata, opt):
        """Run the TTIR simplification passes on ``mod`` and return it."""
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)
        passes.ttir.add_rewrite_tensor_pointer(pm)
        passes.ttir.add_combine(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_reorder_broadcast(pm)
        passes.common.add_cse(pm)
        passes.common.add_licm(pm)
        passes.common.add_symbol_dce(pm)
        pm.run(mod)
        return mod

    @staticmethod
    def make_ttgir(mod, metadata, opt, capability):
        """Convert TTIR to TritonGPU IR for ``capability`` and optimise it.

        Stores ``metadata["cluster_dims"]`` as read back from the cluster info
        after the pipeline has run.
        """
        cluster_info = nvidia.ClusterInfo()
        if opt.cluster_dims is not None:
            cluster_info.clusterDimX = opt.cluster_dims[0]
            cluster_info.clusterDimY = opt.cluster_dims[1]
            cluster_info.clusterDimZ = opt.cluster_dims[2]
        # TTIR -> TTGIR
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
        # optimize TTGIR
        passes.ttgpuir.add_coalesce(pm)
        if capability // 10 >= 8:
            passes.ttgpuir.add_f32_dot_tc(pm)
        # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
        nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_thread_locality(pm)
        passes.ttgpuir.add_accelerate_matmul(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
        passes.common.add_cse(pm)
        if capability // 10 >= 8:
            passes.ttgpuir.add_combine_tensor_select_and_if(pm)
            passes.ttgpuir.add_pipeline(pm, opt.num_stages)
        passes.ttgpuir.add_prefetch(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_reduce_data_duplication(pm)
        passes.ttgpuir.add_reorder_instructions(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        if capability // 10 >= 9:
            nvidia.passes.ttnvgpuir.add_fence_insertion(pm)
            nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
        passes.common.add_canonicalizer(pm)
        pm.run(mod)
        metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
        return mod

    @staticmethod
    def make_llir(src, metadata, options, capability):
        """Lower TritonGPU IR to optimised LLVM IR text.

        Updates ``metadata["num_warps"]`` for warp groups and stores
        ``metadata["shared"]``.
        """
        # warp-specialization mutates num_warps
        num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta")
        if num_warp_groups is not None:
            metadata["num_warps"] *= num_warp_groups
        mod = src
        # TritonGPU -> LLVM-IR (MLIR)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        nvidia.passes.ttgpuir.add_decompose_unsupported_conversions(pm)
        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
        passes.convert.add_scf_to_cf(pm)
        passes.convert.add_index_to_llvmir(pm)
        passes.ttgpuir.add_allocate_shared_memory(pm)
        nvidia.passes.ttgpuir.add_to_llvmir(pm, capability)
        nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
        passes.convert.add_arith_to_llvmir(pm)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
            passes.llvmir.add_di_scope(pm)
        pm.run(mod)
        # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
        llvm.init_targets()
        context = llvm.context()
        llvm_mod = llvm.to_module(mod, context)
        nvidia.set_nvvm_reflect_ftz(llvm_mod)

        # Set maxnreg on all kernels, if it was provided.
        if options.maxnreg is not None:
            for k in llvm_mod.get_functions():
                if not k.is_declaration() and k.is_external_linkage():
                    k.set_nvvm_maxnreg(options.maxnreg)

        if options.extern_libs:
            paths = [path for (name, path) in options.extern_libs]
            llvm.link_extern_libs(llvm_mod, paths)

        llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)

        # Get some metadata
        metadata["shared"] = src.get_int_attr("triton_gpu.shared")
        ret = str(llvm_mod)
        del llvm_mod
        del context
        return ret

    @staticmethod
    def make_ptx(src, metadata, opt, capability):
        """Translate LLVM IR text to PTX and record the kernel name in ``metadata``."""
        ptx_version = opt.ptx_version
        if ptx_version is None:
            _, cuda_version = _path_to_binary("ptxas")
            ptx_version = ptx_get_version(cuda_version)

        # PTX 8.3 is the max version supported by llvm 3a83162168.
        #
        # To check if a newer PTX version is supported, increase this value
        # and run a test. If it's not supported, LLVM will print a warning
        # like "+ptx8.4 is not a recognized feature for this target".
        llvm_ptx_version = min(83, ptx_version)

        triple = 'nvptx64-nvidia-cuda'
        proc = 'sm_90a' if capability == 90 else f'sm_{capability}'
        features = f'+ptx{llvm_ptx_version}'
        ret = llvm.translate_to_asm(src, triple, proc, features, ['nvptx-short-ptr'], opt.enable_fp_fusion, False)
        # Find kernel names (there should only be one)
        names = re.findall(r"\.visible \.entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
        assert len(names) == 1
        metadata["name"] = names[0]
        # post-process: stamp the (possibly newer than LLVM's) PTX version into the header
        ptx_version = f'{ptx_version//10}.{ptx_version%10}'
        ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
        # Remove the debug flag that prevents ptxas from optimizing the code
        ret = re.sub(r",\s*debug|debug,\s*", "", ret)
        if os.environ.get("NVPTX_ENABLE_DUMP", "0") == "1":
            print("// -----// NVPTX Dump //----- //")
            print(ret)
        return ret

    @staticmethod
    def make_cubin(src, metadata, opt, capability):
        """Assemble PTX ``src`` with ``ptxas`` and return the cubin bytes.

        Raises:
            RuntimeError: if ``ptxas`` fails; the message contains its log.
                On failure the temporary ``.ptx`` and ``.log`` files are left
                on disk (the segfault message points the user at the ``.ptx``).
        """
        ptxas, _ = _path_to_binary("ptxas")
        with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
            tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
            fsrc.write(src)
            fsrc.flush()
            fbin = fsrc.name + '.o'

            # Same rule as make_llir: line info stays enabled unless
            # TRITON_DISABLE_LINE_INFO is set to something other than "0".
            # (Previously any non-empty value, including "0", disabled it here only.)
            line_info = ['-lineinfo'] if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0" else []
            fmad = [] if opt.enable_fp_fusion else ['--fmad=false']
            suffix = 'a' if capability == 90 else ''
            opt_level = ['--opt-level', '0'] if os.environ.get("DISABLE_PTXAS_OPT", "0") == "1" else []
            ptxas_cmd = [
                ptxas, *line_info, *fmad, '-v', *opt_level, f'--gpu-name=sm_{capability}{suffix}', fsrc.name, '-o', fbin
            ]
            try:
                # close_fds=True on Windows and False on Linux, see https://github.com/triton-lang/triton/pull/4357
                # On Windows, both stdout and stderr need to be redirected to flog
                subprocess.run(ptxas_cmd, check=True, close_fds=(os.name == 'nt'), stdout=flog, stderr=flog)
            except subprocess.CalledProcessError as e:
                with open(flog.name) as log_file:
                    log = log_file.read()

                if e.returncode == 255:
                    raise RuntimeError(f'Internal Triton PTX codegen error: \n{log}') from e
                elif e.returncode == 128 + signal.SIGSEGV:
                    raise RuntimeError(
                        f'Please run `ptxas {fsrc.name}` to confirm that this is a bug in `ptxas`\n{log}') from e
                else:
                    raise RuntimeError(f'`ptxas` failed with error code {e.returncode}: \n{log}') from e

            with open(fbin, 'rb') as f:
                cubin = f.read()
            try_remove(fbin)

        # It's better to remove the temp files outside the context managers
        try_remove(fsrc.name)
        try_remove(flog.name)
        return cubin

    def add_stages(self, stages, options):
        stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
        stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
        stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.capability)
        stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability)

    def hash(self):
        """Identify this backend by the ptxas version banner and compute capability.

        Not memoised on the instance: get_ptxas_version() is already cached,
        and an lru_cache on a bound method would keep every backend instance
        alive for the lifetime of the cache.
        """
        version = get_ptxas_version()
        return f'{version}-{self.capability}'
|