triton-windows 3.3.1.post19__cp312-cp312-win_amd64.whl → 3.4.0.post20__cp312-cp312-win_amd64.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +4 -1
- triton/_filecheck.py +87 -0
- triton/_internal_testing.py +26 -15
- triton/_utils.py +110 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +112 -78
- triton/backends/amd/driver.c +5 -2
- triton/backends/amd/driver.py +149 -47
- triton/backends/compiler.py +7 -21
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +92 -93
- triton/backends/nvidia/driver.c +90 -98
- triton/backends/nvidia/driver.py +303 -128
- triton/compiler/code_generator.py +212 -111
- triton/compiler/compiler.py +110 -25
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +4 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +99 -0
- triton/experimental/gluon/language/__init__.py +18 -0
- triton/experimental/gluon/language/_core.py +312 -0
- triton/experimental/gluon/language/_layouts.py +230 -0
- triton/experimental/gluon/language/_math.py +12 -0
- triton/experimental/gluon/language/_semantic.py +287 -0
- triton/experimental/gluon/language/_standard.py +47 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +40 -0
- triton/knobs.py +481 -0
- triton/language/__init__.py +39 -14
- triton/language/core.py +794 -537
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/libdevice.py +113 -104
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1706 -1770
- triton/language/standard.py +116 -51
- triton/runtime/autotuner.py +117 -59
- triton/runtime/build.py +76 -12
- triton/runtime/cache.py +18 -47
- triton/runtime/driver.py +32 -29
- triton/runtime/interpreter.py +72 -35
- triton/runtime/jit.py +146 -110
- triton/testing.py +16 -12
- triton/tools/disasm.py +3 -4
- triton/tools/tensor_descriptor.py +36 -0
- triton/windows_utils.py +14 -6
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
- triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
- triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
- triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +0 -0
triton/backends/amd/driver.py
CHANGED
@@ -1,16 +1,16 @@
 import functools
 import os
-import hashlib
 import subprocess
-import tempfile
+import re
 from pathlib import Path
-from triton.runtime.build import _build
-from triton.runtime.cache import get_cache_manager
+from triton import knobs
 from triton.backends.compiler import GPUTarget
 from triton.backends.driver import GPUDriver
+from triton.runtime.build import compile_module_from_src
+from triton.tools.tensor_descriptor import TensorDescriptor
 
 dirname = os.path.dirname(os.path.realpath(__file__))
-include_dir = [os.path.join(dirname, "include")]
+include_dirs = [os.path.join(dirname, "include")]
 
 
 def _find_already_mmapped_dylib_on_linux(lib_name):
@@ -66,8 +66,7 @@ def _get_path_to_hip_runtime_dylib():
     lib_name = "libamdhip64.so"
 
     # If we are told explicitly what HIP runtime dynamic library to use, obey that.
-    env_libhip_path = os.getenv("TRITON_LIBHIP_PATH")
-    if env_libhip_path:
+    if env_libhip_path := knobs.amd.libhip_path:
         if env_libhip_path.endswith(lib_name) and os.path.exists(env_libhip_path):
             return env_libhip_path
         raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}")
@@ -81,6 +80,12 @@ def _get_path_to_hip_runtime_dylib():
 
     paths = []
 
+    # Check backend
+    local_lib = os.path.join(os.path.dirname(__file__), "lib", lib_name)
+    if os.path.exists(local_lib):
+        return local_lib
+    paths.append(local_lib)
+
     import site
     # First search the HIP runtime dynamic library packaged with PyTorch. It's very likely
     # that we run Triton together with PyTorch. This makes sure we use the same dynamic
@@ -124,25 +129,6 @@ def _get_path_to_hip_runtime_dylib():
     raise RuntimeError(f"cannot locate {lib_name} after attempted paths {paths}")
 
 
-def compile_module_from_src(src, name):
-    key = hashlib.sha256(src.encode("utf-8")).hexdigest()
-    cache = get_cache_manager(key)
-    cache_path = cache.get_file(f"{name}.so")
-    if cache_path is None:
-        with tempfile.TemporaryDirectory() as tmpdir:
-            src_path = os.path.join(tmpdir, f"{name}.c")
-            with open(src_path, "w") as f:
-                f.write(src)
-            so = _build(name, src_path, tmpdir, [], include_dir, [])
-            with open(so, "rb") as f:
-                cache_path = cache.put(f.read(), f"{name}.so", binary=True)
-    import importlib.util
-    spec = importlib.util.spec_from_file_location(name, cache_path)
-    mod = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(mod)
-    return mod
-
-
 class HIPUtils(object):
 
     def __new__(cls):
@@ -157,7 +143,7 @@ class HIPUtils(object):
         # This way we don't need to escape-quote C code curly brackets and we can replace
         # exactly once.
         src = src.replace('/*py_libhip_search_path*/', libhip_path, 1)
-        mod = compile_module_from_src(src, "hip_utils")
+        mod = compile_module_from_src(src=src, name="hip_utils", include_dirs=include_dirs)
         self.load_binary = mod.load_binary
         self.get_device_properties = mod.get_device_properties
 
@@ -177,16 +163,60 @@ def ty_to_cpp(ty):
         "u16": "uint16_t",
         "u32": "uint32_t",
         "u64": "uint64_t",
-        "fp16": "float",
-        "bf16": "float",
-        "fp32": "float",
-        "f32": "float",
+        "fp16": "double",
+        "bf16": "double",
+        "fp32": "double",
+        "f32": "double",
         "fp64": "double",
     }[ty]
 
 
+FLOAT_STORAGE_TYPE = {
+    "fp16": "uint16_t",
+    "bf16": "uint16_t",
+    "fp32": "uint32_t",
+    "f32": "uint32_t",
+    "fp64": "uint64_t",
+}
+FLOAT_PACK_FUNCTION = {
+    "fp16": "pack_fp16",
+    "bf16": "pack_bf16",
+    "fp32": "pack_fp32",
+    "f32": "pack_fp32",
+    "fp64": "pack_fp64",
+}
+
+_BASE_ARGS_FORMAT = "piiiKKOOOO"
+
+
 def make_launcher(constants, signature, warp_size):
 
+    def _expand_signature(signature):
+        output = []
+        # Expand tensor descriptor arguments into base pointer, shape, and
+        # strides
+        for sig in signature:
+            if isinstance(sig, str) and sig.startswith("tensordesc"):
+                ndim = sig.count(",") + 1
+                dtype = re.match("tensordesc<([^[>]*)", sig).group(1)
+
+                output.append("*" + dtype)
+                for _ in range(2 * ndim):
+                    output.append("i64")
+                # Currently the host side tensor descriptors get passed in as a
+                # tensor desc, shape, and strides. We have no way to use these
+                # shape and strides when processing tensor descriptors which is
+                # why we provide our own decomposition above. Sadly this means
+                # we have to pass the shape and strides twice.
+                for _ in range(ndim):
+                    output.append("i32")
+                for _ in range(ndim):
+                    output.append("i64")
+            else:
+                output.append(sig)
+
+        return output
+
     def _serialize_signature(sig):
         if isinstance(sig, tuple):
             return ','.join(map(_serialize_signature, sig))
@@ -198,7 +228,7 @@ def make_launcher(constants, signature, warp_size):
             return f"[{val}]"
         if ty[0] == '*':
            return "PyObject*"
-        if ty in ("constexpr", "nvTmaDesc"):
+        if ty == "constexpr":
            return "PyObject*"
         return ty_to_cpp(ty)
 
@@ -208,10 +238,9 @@ def make_launcher(constants, signature, warp_size):
             return f"({val})"
         if ty[0] == '*':
             return "O"
-        if ty in ("constexpr", "nvTmaDesc"):
+        if ty == "constexpr":
             return "O"
         return {
-            "float": "f",
             "double": "d",
             "long": "l",
             "int8_t": "b",
@@ -224,21 +253,40 @@ def make_launcher(constants, signature, warp_size):
             "uint64_t": "K",
         }[ty_to_cpp(ty)]
 
+    signature = {idx: s for idx, s in enumerate(_expand_signature(signature.values()))}
+
     args_format = ''.join([format_of(ty) for ty in signature.values()])
-    format = "piiiKKOOOO" + args_format
+    format = _BASE_ARGS_FORMAT + args_format
     signature = ','.join(map(_serialize_signature, signature.values()))
     signature = list(filter(bool, signature.split(',')))
     signature = {i: s for i, s in enumerate(signature)}
     args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
     # Record the end of regular arguments;
     # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
-    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
+    arg_decl_list = []
+    for i, ty in signature.items():
+        if ty == "constexpr":
+            continue
+        if ty in FLOAT_STORAGE_TYPE:
+            arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
+        else:
+            arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
+    arg_decls = ', '.join(arg_decl_list)
     internal_args_list = []
     for i, ty in signature.items():
         if ty[0] == "*":
             internal_args_list.append(f"ptr_info{i}.dev_ptr")
+        elif ty in FLOAT_STORAGE_TYPE:
+            internal_args_list.append(f"_arg{i}_storage")
         elif ty != "constexpr":
             internal_args_list.append(f"_arg{i}")
+
+    float_storage_decls = [
+        f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
+        for i, ty in signature.items()
+        if ty in FLOAT_STORAGE_TYPE
+    ]
+
     libhip_path = _get_path_to_hip_runtime_dylib()
 
     # generate glue code
@@ -291,9 +339,6 @@ static struct HIPSymbolTable hipSymbolTable;
 bool initSymbolTable() {{
   // Use the HIP runtime library loaded into the existing process if it exits.
   void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD);
-  if (lib) {{
-    // printf("[triton] chosen loaded libamdhip64.so in the process\\n");
-  }}
 
   // Otherwise, go through the list of search paths to dlopen the first HIP
   // driver library.
@@ -303,7 +348,6 @@ bool initSymbolTable() {{
       void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
       if (handle) {{
         lib = handle;
-        // printf("[triton] chosen %s\\n", hipLibSearchPaths[i]);
       }}
     }}
  }}
@@ -345,7 +389,6 @@ static inline void gpuAssert(hipError_t code, const char *file, int line)
 #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
 
 static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
-  // printf("_launch hip kernel\\n");
   hipDeviceptr_t global_scratch = 0;
   void *params[] = {{ {', '.join(params)} }};
   if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{
@@ -383,11 +426,14 @@ static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
    if (!PyLong_Check(ret)) {{
      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
      ptr_info.valid = false;
+      Py_DECREF(ret);
      return ptr_info;
    }}
    ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
-    if(!ptr_info.dev_ptr)
+    if(!ptr_info.dev_ptr) {{
+      Py_DECREF(ret);
      return ptr_info;
+    }}
    uint64_t dev_ptr;
    hipError_t status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
    if (status == hipErrorInvalidValue) {{
@@ -403,8 +449,33 @@ static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
  return ptr_info;
 }}
 
+static uint16_t pack_fp16(double f) {{
+  uint16_t result;
+  // from https://github.com/python/pythoncapi-compat
+#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
+  _PyFloat_Pack2(f, (unsigned char*)&result, 1);
+#else
+  PyFloat_Pack2(f, (unsigned char*)&result, 1);
+#endif
+  return result;
+}}
+
+static uint16_t pack_bf16(double f) {{
+  float f32 = (float)f;
+  uint32_t u32 = *(uint32_t*)&f32;
+  return (uint16_t)(u32 >> 16);
+}}
+
+static uint32_t pack_fp32(double f) {{
+  float f32 = (float)f;
+  return *(uint32_t*)&f32;
+}}
+
+static uint64_t pack_fp64(double f) {{
+  return *(uint64_t*)&f;
+}}
+
 static PyObject* launch(PyObject* self, PyObject* args) {{
-  // printf("launch\\n");
  int gridX, gridY, gridZ;
  uint64_t _stream;
  uint64_t _function;
@@ -421,6 +492,8 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
    return NULL;
  }}
 
+  {' '.join(float_storage_decls)}
+
  // extract kernel metadata
  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
  if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
@@ -433,6 +506,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
    Py_DECREF(args);
    if (!ret)
      return NULL;
+    Py_DECREF(ret);
  }}
 
 
@@ -446,6 +520,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
    Py_DECREF(args);
    if (!ret)
      return NULL;
+    Py_DECREF(ret);
  }}
 
  if(PyErr_Occurred()) {{
@@ -484,6 +559,31 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
     return src
 
 
+def wrap_handle_tensor_descriptor(launcher):
+    """
+    Replace all tensor descriptors with the base ptr, shape, and strides
+    """
+
+    def inner(*args):
+        meta_args = args[:len(_BASE_ARGS_FORMAT)]
+        raw_kernel_args = args[len(_BASE_ARGS_FORMAT):]
+        final_args = []
+        for arg in raw_kernel_args:
+            if isinstance(arg, TensorDescriptor):
+                # Currently the host side tensor descriptors get decomposed in
+                # the frontend to tensor desc, shape, and strides. We have no
+                # way to use these shape and strides when processing tensor
+                # descriptors which is why we provide our own decomposition
+                # above. Sadly this means we have to pass the shape and strides
+                # twice.
+                final_args.extend([arg.base, *arg.shape, *arg.strides, *arg.shape, *arg.strides])
+            else:
+                final_args.append(arg)
+        return launcher(*meta_args, *final_args)
+
+    return inner
+
+
 class HIPLauncher(object):
 
     def __init__(self, src, metadata):
@@ -492,8 +592,10 @@ class HIPLauncher(object):
         constants = {arg_idx(idx): value for idx, value in constants.items()}
         signature = {idx: value for idx, value in src.signature.items()}
         src = make_launcher(constants, signature, metadata.warp_size)
-        mod = compile_module_from_src(src, "__triton_launcher")
-        self.launch = mod.launch
+        mod = compile_module_from_src(src=src, name="__triton_launcher", include_dirs=include_dirs)
+        has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
+
+        self.launch = wrap_handle_tensor_descriptor(mod.launch) if has_tensor_desc_arg else mod.launch
         self.launch_cooperative_grid = metadata.launch_cooperative_grid
 
     def __call__(self, *args):
@@ -515,14 +617,14 @@ class HIPDriver(GPUDriver):
     def is_active():
         try:
             import torch
-            return torch.version.hip is not None
+            return torch.cuda.is_available() and (torch.version.hip is not None)
         except ImportError:
             return False
 
     def get_current_target(self):
         device = self.get_current_device()
         device_properties = self.utils.get_device_properties(device)
-        arch = device_properties['arch']
+        arch = knobs.runtime.override_arch or device_properties['arch']
         warp_size = device_properties['warpSize']
         return GPUTarget("hip", arch.split(':')[0], warp_size)
 
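Two of the changes above are easy to misread from the diff alone. First, the tensor-descriptor plumbing: _expand_signature rewrites each tensordesc entry in the kernel signature into a typed base pointer plus per-dimension shape/stride slots (passed twice, as the in-code comments explain), and wrap_handle_tensor_descriptor performs the matching flattening of TensorDescriptor arguments at launch time. A standalone sketch of the expansion, assuming the tensordesc<dtype[shape]> spelling implied by the regex:

import re

def expand_signature(signature):
    # Mirror of _expand_signature in the diff: a tensordesc argument becomes
    # a typed base pointer, 2*ndim i64 slots (device-side shape + strides),
    # then ndim i32 (host-side shape) and ndim i64 (host-side strides).
    output = []
    for sig in signature:
        if isinstance(sig, str) and sig.startswith("tensordesc"):
            ndim = sig.count(",") + 1
            dtype = re.match(r"tensordesc<([^[>]*)", sig).group(1)
            output.append("*" + dtype)
            output.extend(["i64"] * (2 * ndim))
            output.extend(["i32"] * ndim)
            output.extend(["i64"] * ndim)
        else:
            output.append(sig)
    return output

print(expand_signature(["i32", "tensordesc<fp16[64,64]>"]))
# ['i32', '*fp16', 'i64', 'i64', 'i64', 'i64', 'i32', 'i32', 'i64', 'i64']

Second, scalar float arguments no longer cross the Python/C boundary as C float: they are now parsed as double (the "d" format, per the new ty_to_cpp mapping) and bit-packed into the integer storage types from FLOAT_STORAGE_TYPE by the generated pack_* helpers. The Python-side equivalent of that packing, a sketch using struct that is not part of the package:

import struct

def pack_fp32(f: float) -> int:
    # Reinterpret the float32 bit pattern as a uint32, like the C helper.
    return struct.unpack("<I", struct.pack("<f", f))[0]

def pack_bf16(f: float) -> int:
    # bf16 keeps the top 16 bits of the float32 representation (truncation,
    # matching the generated pack_bf16 above).
    return pack_fp32(f) >> 16

assert pack_bf16(1.0) == 0x3F80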
triton/backends/compiler.py
CHANGED
@@ -1,9 +1,6 @@
-import os
-import re
-import subprocess
-import sysconfig
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
+from enum import Enum
 from typing import Dict, Union
 from types import ModuleType
 
@@ -17,6 +14,12 @@ class GPUTarget(object):
     warp_size: int
 
 
+class Language(Enum):
+    """The input language being compiled by the backend."""
+    TRITON = 0
+    GLUON = 1
+
+
 class BaseBackend(metaclass=ABCMeta):
 
     def __init__(self, target: GPUTarget) -> None:
@@ -24,23 +27,6 @@ class BaseBackend(metaclass=ABCMeta):
         assert self.supports_target(target)
 
     @staticmethod
-    def _path_to_binary(binary: str):
-        binary += sysconfig.get_config_var("EXE")
-        base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
-        paths = [
-            os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
-            os.path.join(base_dir, "third_party", "cuda", "bin", binary),
-        ]
-        for path in paths:
-            if os.path.exists(path) and os.path.isfile(path):
-                result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
-                if result is not None:
-                    version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
-                    if version is not None:
-                        return path, version.group(1)
-        raise RuntimeError(f"Cannot find {binary}")
-
-    @classmethod
     @abstractmethod
     def supports_target(target: GPUTarget):
         raise NotImplementedError