triton-windows 3.3.0.post19__cp313-cp313-win_amd64.whl → 3.4.0.post20__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows has been flagged as possibly problematic by the registry.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +4 -1
- triton/_filecheck.py +87 -0
- triton/_internal_testing.py +26 -15
- triton/_utils.py +110 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +112 -78
- triton/backends/amd/driver.c +5 -2
- triton/backends/amd/driver.py +149 -47
- triton/backends/compiler.py +7 -21
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +92 -93
- triton/backends/nvidia/driver.c +90 -98
- triton/backends/nvidia/driver.py +303 -128
- triton/compiler/code_generator.py +212 -111
- triton/compiler/compiler.py +110 -25
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +4 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +99 -0
- triton/experimental/gluon/language/__init__.py +18 -0
- triton/experimental/gluon/language/_core.py +312 -0
- triton/experimental/gluon/language/_layouts.py +230 -0
- triton/experimental/gluon/language/_math.py +12 -0
- triton/experimental/gluon/language/_semantic.py +287 -0
- triton/experimental/gluon/language/_standard.py +47 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +40 -0
- triton/knobs.py +481 -0
- triton/language/__init__.py +39 -14
- triton/language/core.py +794 -537
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/libdevice.py +113 -104
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1706 -1770
- triton/language/standard.py +116 -51
- triton/runtime/autotuner.py +117 -59
- triton/runtime/build.py +76 -12
- triton/runtime/cache.py +18 -47
- triton/runtime/driver.py +32 -29
- triton/runtime/interpreter.py +72 -35
- triton/runtime/jit.py +146 -110
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/testing.py +16 -12
- triton/tools/disasm.py +3 -4
- triton/tools/tensor_descriptor.py +36 -0
- triton/windows_utils.py +14 -6
- {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
- triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
- {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +1 -1
- triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
- triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.0.post19.dist-info/RECORD +0 -253
- triton_windows-3.3.0.post19.dist-info/top_level.txt +0 -14
triton/runtime/build.py
CHANGED
@@ -1,14 +1,25 @@
+from __future__ import annotations
+
 import functools
-import
+import hashlib
+import importlib.util
+import logging
 import os
 import shutil
 import subprocess
+import sysconfig
+import tempfile
+
+from types import ModuleType
+
+from .cache import get_cache_manager
+from .. import knobs
 
 if os.name == "nt":
     from triton.windows_utils import find_msvc_winsdk, find_python
 
 
-@functools.
+@functools.lru_cache
 def get_cc():
     cc = os.environ.get("CC")
     if cc is None:
@@ -30,6 +41,11 @@ def get_cc():
     return cc
 
 
+def is_tcc(cc):
+    cc = os.path.basename(cc).lower()
+    return cc == "tcc" or cc == "tcc.exe"
+
+
 def is_msvc(cc):
     cc = os.path.basename(cc).lower()
     return cc == "cl" or cc == "cl.exe"
@@ -58,13 +74,18 @@ def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
     if not (os.name == "nt" and is_clang(cc)):
         # Clang does not support -fPIC on Windows
         cc_cmd += ["-fPIC"]
+    if is_tcc(cc):
+        cc_cmd += ["-D_Py_USE_GCC_BUILTIN_ATOMICS"]
     cc_cmd += [f'-l{lib}' for lib in libraries]
     cc_cmd += [f"-L{dir}" for dir in library_dirs]
     cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
     return cc_cmd
 
 
-def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
+def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str],
+           libraries: list[str]) -> str:
+    if impl := knobs.build.impl:
+        return impl(name, src, srcdir, library_dirs, include_dirs, libraries)
     suffix = sysconfig.get_config_var('EXT_SUFFIX')
     so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
     # try to avoid setuptools if possible
@@ -73,24 +94,25 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
     if hasattr(sysconfig, 'get_default_scheme'):
         scheme = sysconfig.get_default_scheme()
     else:
-        scheme = sysconfig._get_default_scheme()
+        scheme = sysconfig._get_default_scheme()  # type: ignore
     # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
     # path changes to include 'local'. This change is required to use triton with system-wide python.
     if scheme == 'posix_local':
         scheme = 'posix_prefix'
     py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
-    custom_backend_dirs =
+    custom_backend_dirs = knobs.build.backend_dirs
+    # Don't append in place
    include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs]
    if os.name == "nt":
-        library_dirs
-
-
-
-        libraries
+        library_dirs = library_dirs + find_python()
+        version = sysconfig.get_python_version().replace(".", "")
+        if sysconfig.get_config_var("Py_GIL_DISABLED"):
+            version += "t"
+        libraries = libraries + [f"python{version}"]
        if is_msvc(cc):
            _, msvc_winsdk_inc_dirs, msvc_winsdk_lib_dirs = find_msvc_winsdk()
-            include_dirs
-            library_dirs
+            include_dirs = include_dirs + msvc_winsdk_inc_dirs
+            library_dirs = library_dirs + msvc_winsdk_lib_dirs
    cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries)

    try:
@@ -100,3 +122,45 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
        raise e

    return so
+
+
+@functools.lru_cache
+def platform_key() -> str:
+    from platform import machine, system, architecture
+    return ",".join([machine(), system(), *architecture()])
+
+
+def _load_module_from_path(name: str, path: str) -> ModuleType:
+    # Loading module with relative path may cause error
+    path = os.path.abspath(path)
+    spec = importlib.util.spec_from_file_location(name, path)
+    if not spec or not spec.loader:
+        raise RuntimeError(f"Failed to load newly compiled {name} from {path}")
+    mod = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(mod)
+    return mod
+
+
+def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None,
+                            include_dirs: list[str] | None = None, libraries: list[str] | None = None) -> ModuleType:
+    key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
+    cache = get_cache_manager(key)
+    suffix = sysconfig.get_config_var("EXT_SUFFIX")
+    cache_path = cache.get_file(f"{name}{suffix}")
+
+    if cache_path is not None:
+        try:
+            return _load_module_from_path(name, cache_path)
+        except (RuntimeError, ImportError):
+            log = logging.getLogger(__name__)
+            log.warning(f"Triton cache error: compiled module {name}.so could not be loaded")
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        src_path = os.path.join(tmpdir, name + ".c")
+        with open(src_path, "w") as f:
+            f.write(src)
+        so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [])
+        with open(so, "rb") as f:
+            cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True)
+
+    return _load_module_from_path(name, cache_path)
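The new compile_module_from_src helper keys the build cache on a SHA-256 of the C source plus a platform fingerprint, reloads a previously built extension on a hit, and only invokes the compiler on a miss. The following minimal standalone sketch illustrates that pattern; the cached_build helper, the cache layout, and the direct "cc" invocation are illustrative assumptions, not the package's API.

# Minimal sketch of the source-keyed build cache pattern introduced above.
# The directory layout, helper names, and compiler command are illustrative, not Triton's API.
import hashlib
import os
import platform
import subprocess
import sysconfig
import tempfile


def _platform_key() -> str:
    # Same idea as platform_key(): tie the cache key to the host platform.
    return ",".join([platform.machine(), platform.system(), *platform.architecture()])


def cached_build(name: str, src: str, cache_root: str) -> str:
    """Return a path to a compiled extension for `src`, rebuilding only on cache miss."""
    key = hashlib.sha256((src + _platform_key()).encode("utf-8")).hexdigest()
    suffix = sysconfig.get_config_var("EXT_SUFFIX") or ".so"
    out = os.path.join(cache_root, key, f"{name}{suffix}")
    if os.path.exists(out):
        return out  # cache hit: reuse the previously built module
    os.makedirs(os.path.dirname(out), exist_ok=True)
    with tempfile.TemporaryDirectory() as tmpdir:
        src_path = os.path.join(tmpdir, f"{name}.c")
        with open(src_path, "w") as f:
            f.write(src)
        # A real build would also add Python include/library dirs, as _build() does.
        subprocess.check_call(["cc", "-shared", "-fPIC", src_path, "-o", out])
    return out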
triton/runtime/cache.py
CHANGED
@@ -1,33 +1,17 @@
-import importlib
 import json
 import os
 import uuid
 from abc import ABC, abstractmethod
-from pathlib import Path
 from typing import Dict, List, Optional
 import base64
 import hashlib
 
-
-def get_home_dir():
-    return os.getenv("TRITON_HOME", Path.home())
-
-
-def default_cache_dir():
-    return os.path.join(get_home_dir(), ".triton", "cache")
-
-
-def default_override_dir():
-    return os.path.join(get_home_dir(), ".triton", "override")
-
-
-def default_dump_dir():
-    return os.path.join(get_home_dir(), ".triton", "dump")
+from .. import knobs
 
 
 class CacheManager(ABC):
 
-    def __init__(self, key):
+    def __init__(self, key, override=False, dump=False):
         pass
 
     @abstractmethod
@@ -53,16 +37,16 @@ class FileCacheManager(CacheManager):
         self.key = key
         self.lock_path = None
         if dump:
-            self.cache_dir =
+            self.cache_dir = knobs.cache.dump_dir
             self.cache_dir = os.path.join(self.cache_dir, self.key)
             self.lock_path = os.path.join(self.cache_dir, "lock")
             os.makedirs(self.cache_dir, exist_ok=True)
         elif override:
-            self.cache_dir =
+            self.cache_dir = knobs.cache.override_dir
             self.cache_dir = os.path.join(self.cache_dir, self.key)
         else:
             # create cache directory if it doesn't exist
-            self.cache_dir =
+            self.cache_dir = knobs.cache.dir
             if self.cache_dir:
                 self.cache_dir = os.path.join(self.cache_dir, self.key)
                 self.lock_path = os.path.join(self.cache_dir, "lock")
@@ -166,10 +150,10 @@ class RedisRemoteCacheBackend(RemoteCacheBackend):
     def __init__(self, key):
         import redis
         self._key = key
-        self._key_fmt =
+        self._key_fmt = knobs.cache.redis.key_format
         self._redis = redis.Redis(
-            host=
-            port=
+            host=knobs.cache.redis.host,
+            port=knobs.cache.redis.port,
         )
 
     def _get_key(self, filename: str) -> str:
@@ -187,10 +171,10 @@ class RemoteCacheManager(CacheManager):
 
     def __init__(self, key, override=False, dump=False):
         # Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`.
-
-
-
-
+        remote_cache_cls = knobs.cache.remote_manager_class
+        if not remote_cache_cls:
+            raise RuntimeError(
+                "Unable to instantiate RemoteCacheManager, TRITON_REMOTE_CACHE_BACKEND doesn't point to a valid class")
         self._backend = remote_cache_cls(key)
 
         self._override = override
@@ -260,37 +244,24 @@ class RemoteCacheManager(CacheManager):
         return self.put(grp_contents, grp_filename)
 
 
-__cache_cls = FileCacheManager
-__cache_cls_nme = "DEFAULT"
-
-
 def _base32(key):
     # Assume key is a hex string.
     return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
 
 
 def get_cache_manager(key) -> CacheManager:
-
-
-    user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
-    global __cache_cls
-    global __cache_cls_nme
-
-    if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
-        module_path, clz_nme = user_cache_manager.split(":")
-        module = importlib.import_module(module_path)
-        __cache_cls = getattr(module, clz_nme)
-        __cache_cls_nme = user_cache_manager
-
-    return __cache_cls(_base32(key))
+    cls = knobs.cache.manager_class or FileCacheManager
+    return cls(_base32(key))
 
 
 def get_override_manager(key) -> CacheManager:
-
+    cls = knobs.cache.manager_class or FileCacheManager
+    return cls(_base32(key), override=True)
 
 
 def get_dump_manager(key) -> CacheManager:
-
+    cls = knobs.cache.manager_class or FileCacheManager
+    return cls(_base32(key), dump=True)
 
 
 def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
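The cache module now resolves directories and manager classes through triton.knobs instead of reading TRITON_HOME, TRITON_CACHE_MANAGER, and related environment variables directly. The sketch below shows the underlying "configured class or default" selection pattern; the Settings object and FileCache class are hypothetical stand-ins for knobs.cache and FileCacheManager, not Triton's API.

# Illustrative sketch of the "knob or default" selection used by get_cache_manager() above.
# Settings and FileCache are made-up names; only the selection pattern mirrors the diff.
import base64
import os
from typing import Optional, Type


class FileCache:
    """Minimal stand-in for FileCacheManager: one directory per cache key."""

    def __init__(self, key: str, root: str = os.path.expanduser("~/.triton/cache")):
        self.cache_dir = os.path.join(root, key)
        os.makedirs(self.cache_dir, exist_ok=True)

    def put(self, data: bytes, filename: str) -> str:
        path = os.path.join(self.cache_dir, filename)
        with open(path, "wb") as f:
            f.write(data)
        return path


class Settings:
    # A user-supplied manager class overrides the default, mirroring knobs.cache.manager_class.
    manager_class: Optional[Type[FileCache]] = None


def get_cache(key_hex: str, settings: Optional[Settings] = None) -> FileCache:
    settings = settings or Settings()
    key = base64.b32encode(bytes.fromhex(key_hex)).decode().rstrip("=")
    cls = settings.manager_class or FileCache
    return cls(key)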
triton/runtime/driver.py
CHANGED
@@ -1,59 +1,62 @@
-from
-from ..backends import DriverBase
+from __future__ import annotations
 
+from ..backends import backends, DriverBase
 
-
-actives = [x.driver for x in backends.values() if x.driver.is_active()]
-if len(actives) != 1:
-    raise RuntimeError(f"{len(actives)} active drivers ({actives}). There should only be one.")
-return actives[0]()
+from typing import Any, Callable, Generic, TypeVar, Union
 
 
-
+def _create_driver() -> DriverBase:
+    active_drivers = [x.driver for x in backends.values() if x.driver.is_active()]
+    if len(active_drivers) != 1:
+        raise RuntimeError(f"{len(active_drivers)} active drivers ({active_drivers}). There should only be one.")
+    return active_drivers[0]()
 
-
+
+T = TypeVar("T")
+
+
+class LazyProxy(Generic[T]):
+
+    def __init__(self, init_fn: Callable[[], T]) -> None:
         self._init_fn = init_fn
-        self._obj = None
+        self._obj: Union[T, None] = None
 
-    def _initialize_obj(self):
+    def _initialize_obj(self) -> T:
         if self._obj is None:
             self._obj = self._init_fn()
+        return self._obj
 
-    def __getattr__(self, name):
-        self._initialize_obj()
-        return getattr(self._obj, name)
+    def __getattr__(self, name) -> Any:
+        return getattr(self._initialize_obj(), name)
 
-    def __setattr__(self, name, value):
+    def __setattr__(self, name: str, value: Any) -> None:
         if name in ["_init_fn", "_obj"]:
             super().__setattr__(name, value)
         else:
-            self._initialize_obj()
-            setattr(self._obj, name, value)
+            setattr(self._initialize_obj(), name, value)
 
-    def __delattr__(self, name):
-        self._initialize_obj()
-        delattr(self._obj, name)
+    def __delattr__(self, name: str) -> None:
+        delattr(self._initialize_obj(), name)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         if self._obj is None:
             return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
         return repr(self._obj)
 
-    def __str__(self):
-        self._initialize_obj()
-        return str(self._obj)
+    def __str__(self) -> str:
+        return str(self._initialize_obj())
 
 
 class DriverConfig:
 
-    def __init__(self):
-        self.default = LazyProxy(_create_driver)
-        self.active = self.default
+    def __init__(self) -> None:
+        self.default: LazyProxy[DriverBase] = LazyProxy(_create_driver)
+        self.active: Union[LazyProxy[DriverBase], DriverBase] = self.default
 
-    def set_active(self, driver: DriverBase):
+    def set_active(self, driver: DriverBase) -> None:
         self.active = driver
 
-    def reset_active(self):
+    def reset_active(self) -> None:
         self.active = self.default
 
 
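LazyProxy defers driver construction until the first attribute access, which keeps the driver module importable on machines where no backend is active. A condensed usage sketch of the pattern, with a made-up ExpensiveDriver class in place of a real backend driver:

# Usage sketch of the LazyProxy pattern shown in the diff: construction is deferred
# until the first attribute access. ExpensiveDriver is an illustrative example class.
from typing import Callable, Generic, TypeVar, Union

T = TypeVar("T")


class LazyProxy(Generic[T]):

    def __init__(self, init_fn: Callable[[], T]) -> None:
        self._init_fn = init_fn
        self._obj: Union[T, None] = None

    def _initialize_obj(self) -> T:
        if self._obj is None:
            self._obj = self._init_fn()
        return self._obj

    def __getattr__(self, name):
        # Only reached for attributes not set on the proxy itself.
        return getattr(self._initialize_obj(), name)


class ExpensiveDriver:

    def __init__(self) -> None:
        print("initializing driver")  # runs only when the proxy is first used

    def launch(self) -> str:
        return "launched"


driver = LazyProxy(ExpensiveDriver)   # nothing is constructed yet
print(driver.launch())                # prints "initializing driver", then "launched"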
triton/runtime/interpreter.py
CHANGED
@@ -1,32 +1,36 @@
+from __future__ import annotations
 import ast
 import textwrap
 import inspect
-from typing import Tuple, List
+from typing import Tuple, List, Dict
 
 import math
 import numpy as np
 
 import triton
 import triton.language as tl
+import dataclasses
 from dataclasses import dataclass
+
+from triton.language.semantic import TritonSemantic
+from triton.tools.tensor_descriptor import TensorDescriptor
 from .errors import InterpreterError
 from functools import partial
 from .._C.libtriton import interpreter as _interpreter
 from .._C.libtriton import ir as _ir
 
 
+@dataclass
 class TensorHandle:
-
-
-
-
-
-
-
-
-
-        self.dtype = dtype
-        self.attr = {}
+    '''
+    data: numpy array
+    dtype: triton type, either pointer_type or scalar_type.
+        we don't store block_type here because the shape information is already available in the data field
+    attr: a dictionary of attributes
+    '''
+    data: np.array
+    dtype: tl.dtype
+    attr: Dict = dataclasses.field(default_factory=dict)
 
     def __bool__(self):
         return bool(self.data.all())
@@ -103,6 +107,7 @@ class TensorDescHandle:
             off = (offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
             ptrs = ptrs + (itemsize * off * self.strides[dim].data).astype(np.uint64)
             masks = masks & (0 <= off) & (off < self.shape[dim].data)
+        assert ptrs.dtype == np.uint64
         ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
         return ptrs, masks
 
@@ -114,7 +119,7 @@ class InterpreterOptions:
     sanitize_overflow: bool = True
     arch: str = None
     supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15")
-
+    deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "tf32"
     allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
     max_num_imprecise_acc_default: int = 0
@@ -248,8 +253,8 @@ np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64])
 class ExtraFunctions:
 
     @staticmethod
-    def _convert_custom_types(input, dst_ty, fp_downcast_rounding,
-        return tl.tensor(
+    def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _semantic):
+        return tl.tensor(_semantic.builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty)
 
 
 class InterpreterBuilder:
@@ -306,6 +311,9 @@ class InterpreterBuilder:
     def get_double_ty(self):
         return tl.float64
 
+    def get_int1_ty(self):
+        return tl.int1
+
     def get_int8_ty(self):
         return tl.int8
 
@@ -587,11 +595,18 @@ class InterpreterBuilder:
         b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16)
         return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar)
 
-    def create_make_range(self, start, stop):
+    def create_make_range(self, ret_ty, start, stop):
         return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32)
 
-    def create_histogram(self, data, bins):
-
+    def create_histogram(self, data, bins, mask):
+        if mask is None:
+            mask = TensorHandle(np.ones_like(data.data, dtype=bool), tl.int1)
+        # force all masked elements to zero
+        data = np.where(mask.data, data.data, np.zeros_like(data.data))
+        histogram = np.histogram(data, bins=bins, range=(0, bins))[0]
+        # remove overcounted elements
+        histogram[0] -= np.logical_not(mask.data).sum()
+        return TensorHandle(histogram, tl.int32)
 
     def create_gather(self, src, indices, axis):
         return TensorHandle(np.take_along_axis(src.data, indices.data, axis=axis), src.dtype.scalar)
@@ -641,7 +656,8 @@ class InterpreterBuilder:
         # Triton only supports splitting the original tensor into two along the last axis
         return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar))
 
-    def create_splat(self,
+    def create_splat(self, ret_ty, arg):
+        shape = ret_ty.shape
         if isinstance(arg.dtype, tl.block_type):
             return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
         else:  # scalar
@@ -715,6 +731,7 @@ class InterpreterBuilder:
         shape: List[TensorHandle],
         strides: List[TensorHandle],
         tensor_shape: List[int],
+        is_signed: bool,
     ):
         desc = TensorDescHandle(base, shape, strides, tensor_shape)
         desc.validate()
@@ -753,15 +770,18 @@
     np_type = _get_np_dtype(type)
     if "int" in np_type.name:
         return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar)
+    elif np_type == np.bool_:
+        return TensorHandle(np.full(1, True, dtype=np_type), type.scalar)
     else:
         raise TypeError(f"unsupported type {type}")
 
 
 def _patch_attr(obj, name, member, builder):
+    semantic = TritonSemantic(builder)
     new_member = lambda *args, member=member, **kwargs: (member(*args, **
                                                          {k: v
                                                           for k, v in kwargs.items()
-                                                          if k != "
+                                                          if k != "_semantic"}, _semantic=semantic))
     setattr(obj, name, new_member)
 
 
@@ -822,12 +842,10 @@ class ReduceScanOpInterface:
 
     def apply(self, input):
         if not isinstance(input, tuple):
-
+            return self.apply((input, ))[0]
         self.check_tensor(input)
-
-
-    def apply_impl(self, input):
-        raise NotImplementedError("apply_impl not implemented")
+        ret = self.apply_impl(input)
+        return tuple(ret) if isinstance(ret, (list, tuple)) else (ret, )
 
 
 class ReduceOps(ReduceScanOpInterface):
@@ -887,7 +905,7 @@ class ReduceOps(ReduceScanOpInterface):
             # Take a scalar
             data = data.item()
             ret.append(self.to_tensor(data, input[i].dtype))
-            return ret
+        return ret
 
 
     def min_max(self, input, val_reduce_op, idx_reduce_op=None):
@@ -985,7 +1003,7 @@ class ScanOps(ReduceScanOpInterface):
         if self.reverse:
             for arg in ret:
                 arg.handle.data = np.flip(arg.handle.data, axis=self.axis)
-        return
+        return ret
 
 
 def _patch_reduce_scan():
@@ -1092,7 +1110,7 @@ def _patch_lang(fn):
     _patch_builtin(lang.math, interpreter_builder)
     _patch_lang_tensor(lang.tensor)
     _patch_lang_core(lang)
-    _patch_builtin(tl.core.
+    _patch_builtin(tl.core.tensor_descriptor_base, interpreter_builder)
 
 
 def _tuple_create(arg, contents):
@@ -1127,10 +1145,22 @@ def _implicit_cvt(arg):
         return tl.tensor(handle, ty)
     elif isinstance(arg, tuple):
         return _tuple_create(arg, map(_implicit_cvt, arg))
+    elif isinstance(arg, TensorDescriptor):
+        strides = [_implicit_cvt(s) for s in arg.strides]
+        assert arg.strides[-1] == 1
+        strides[-1] = tl.constexpr(1)
+        semantic = TritonSemantic(InterpreterBuilder())
+        return semantic.make_tensor_descriptor(
+            base=_implicit_cvt(arg.base),
+            shape=[_implicit_cvt(s) for s in arg.shape],
+            strides=strides,
+            block_shape=[tl.constexpr(b) for b in arg.block_shape],
+        )
     return arg
 
 
 interpreter_builder = InterpreterBuilder()
+interpreter_semantic = TritonSemantic(interpreter_builder)
 
 
 def _unwrap_tensor(t):
@@ -1162,6 +1192,13 @@ class GridExecutor:
        def _to_cpu(arg):
            if isinstance(arg, tuple):
                return _tuple_create(arg, map(_to_cpu, arg))
+            elif isinstance(arg, TensorDescriptor):
+                return TensorDescriptor(
+                    _to_cpu(arg.base),
+                    arg.shape,
+                    arg.strides,
+                    arg.block_shape,
+                )
            elif not hasattr(arg, "data_ptr"):
                return arg

@@ -1195,6 +1232,8 @@ class GridExecutor:
            elif isinstance(arg_dev, tuple):
                for (arg_dev, arg_hst) in zip(arg_dev, arg_hst):
                    _from_cpu(arg_dev, arg_hst)
+            elif isinstance(arg_dev, TensorDescriptor):
+                _from_cpu(arg_dev.base, arg_hst.base)

        for arg_dev, arg_hst in zip(args_dev, args_hst):
            _from_cpu(arg_dev, arg_hst)
@@ -1235,6 +1274,8 @@ class GridExecutor:
            interpreter_builder.set_grid_idx(x, y, z)
            self.fn(**args)
        except Exception as e:
+            if triton.knobs.compilation.front_end_debugging:
+                raise
            raise InterpreterError(repr(e)) from e
        # copy arguments back to propagate side-effects
        self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)
@@ -1249,14 +1290,10 @@ class ASTTransformer(ast.NodeTransformer):
        if len(names) > 1:
            raise ValueError("Multiple assignments are not supported")
        # Modify the assignment x = value to
-        #
+        # interpreter_semantic.to_tensor(value, False)
        node.value = ast.Call(
-            func=ast.Attribute(
-
-                value=ast.Attribute(value=ast.Name(id='triton', ctx=ast.Load()), attr='language', ctx=ast.Load()),
-                attr='semantic', ctx=ast.Load()), attr='to_tensor', ctx=ast.Load()),
-            args=[node.value, ast.Name(id='interpreter_builder', ctx=ast.Load()),
-                  ast.Constant(value=False)], keywords=[])
+            func=ast.Attribute(value=ast.Name(id="interpreter_semantic", ctx=ast.Load()), attr="to_tensor",
+                               ctx=ast.Load()), args=[node.value, ast.Constant(value=False)], keywords=[])
        return node


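create_histogram now accepts a mask: masked-out elements are forced to zero so they fall into bin 0, and the resulting overcount is subtracted afterwards. A standalone NumPy check of that trick (the masked_histogram helper is illustrative, not part of Triton):

# Standalone NumPy check of the masked-histogram trick used by create_histogram above:
# masked-out values are forced to 0, then the spurious counts in bin 0 are subtracted.
import numpy as np


def masked_histogram(data: np.ndarray, bins: int, mask: np.ndarray) -> np.ndarray:
    zeroed = np.where(mask, data, np.zeros_like(data))        # masked elements land in bin 0
    hist = np.histogram(zeroed, bins=bins, range=(0, bins))[0]
    hist[0] -= np.logical_not(mask).sum()                     # remove the overcount in bin 0
    return hist


data = np.array([0, 1, 1, 3, 2])
mask = np.array([True, True, False, True, True])
print(masked_histogram(data, bins=4, mask=mask))  # [1 1 1 1]: the masked `1` is not counted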