triton-windows 3.3.1.post19__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/_C/libtriton.pyd
CHANGED
Binary file
triton/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 """isort:skip_file"""
-__version__ = '3.
+__version__ = '3.5.0'
 
 # ---------------------------------------
 # Note: import order is significant here.
@@ -17,7 +17,8 @@ from .runtime import (
     InterpreterError,
     MockTensor,
 )
-from .runtime.jit import jit
+from .runtime.jit import constexpr_function, jit
+from .runtime._async_compile import AsyncCompileMode, FutureKernel
 from .compiler import compile, CompilationError
 from .errors import TritonError
 from .runtime._allocation import set_allocator
@@ -26,12 +27,17 @@ from . import language
 from . import testing
 from . import tools
 
+must_use_result = language.core.must_use_result
+
 __all__ = [
+    "AsyncCompileMode",
     "autotune",
     "cdiv",
     "CompilationError",
     "compile",
     "Config",
+    "constexpr_function",
+    "FutureKernel",
     "heuristics",
     "InterpreterError",
    "jit",
@@ -39,6 +45,7 @@ __all__ = [
     "KernelInterface",
     "language",
     "MockTensor",
+    "must_use_result",
     "next_power_of_2",
     "OutOfResources",
     "reinterpret",
@@ -56,10 +63,12 @@
 # -------------------------------------
 
 
+@constexpr_function
 def cdiv(x: int, y: int):
     return (x + y - 1) // y
 
 
+@constexpr_function
 def next_power_of_2(n: int):
     """Return the smallest power of 2 greater than or equal to n"""
     n -= 1
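The 3.5.0 top-level module exports constexpr_function, AsyncCompileMode, FutureKernel and must_use_result, and wraps cdiv / next_power_of_2 in @constexpr_function. A minimal host-side sketch of the two helpers as they are typically used for launch-grid math (the kernel and tensor names are illustrative, and a CUDA device is assumed):

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)


src = torch.arange(1000, device="cuda", dtype=torch.float32)
dst = torch.empty_like(src)
BLOCK = triton.next_power_of_2(200)          # smallest power of two >= 200 -> 256
grid = (triton.cdiv(src.numel(), BLOCK),)    # ceil(1000 / 256) -> 4 program instances
copy_kernel[grid](src, dst, src.numel(), BLOCK=BLOCK)

The @constexpr_function decorator is what also lets these helpers be evaluated at compile time on constexpr arguments; the host-side behaviour shown above is unchanged.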
triton/_filecheck.py
ADDED
@@ -0,0 +1,97 @@
+import functools
+import os
+import inspect
+import subprocess
+import tempfile
+
+import triton
+from triton.compiler import ASTSource, make_backend
+from triton.backends.compiler import GPUTarget
+from triton.experimental.gluon._runtime import GluonASTSource
+from triton.runtime.jit import create_function_from_signature
+from triton._C.libtriton import ir
+
+# ===-----------------------------------------------------------------------===#
+# filecheck_test
+# ===-----------------------------------------------------------------------===#
+
+# Stub target for testing the frontend.
+stub_target = GPUTarget("cuda", 100, 32)
+
+triton_dir = os.path.dirname(__file__)
+filecheck_path = os.path.join(triton_dir, "FileCheck")
+
+
+class MatchError(ValueError):
+
+    def __init__(self, message, module_str):
+        super().__init__(message)
+        self.module_str = module_str
+
+    def __str__(self):
+        return f"{super().__str__()}\n{self.module_str}"
+
+
+def run_filecheck(name, module_str, check_template):
+    with tempfile.TemporaryDirectory() as tempdir:
+        temp_module = os.path.join(tempdir, "module")
+        with open(temp_module, "w") as temp:
+            temp.write(module_str)
+
+        temp_expected = os.path.join(tempdir, "expected")
+        with open(temp_expected, "w") as temp:
+            temp.write(check_template)
+
+        try:
+            subprocess.check_output(
+                [filecheck_path, temp_expected, "--input-file", temp_module, "--dump-input-context=50"],
+                stderr=subprocess.STDOUT)
+        except subprocess.CalledProcessError as error:
+            decoded = error.output.decode('unicode_escape')
+            raise ValueError(decoded)
+
+
+def run_parser(kernel_fn, args=(), kwargs={}, target=stub_target):
+    if "sanitize_overflow" not in kwargs:
+        kwargs = dict(kwargs)
+        kwargs["sanitize_overflow"] = False
+    backend = make_backend(target)
+    binder = create_function_from_signature(
+        kernel_fn.signature,
+        kernel_fn.params,
+        backend,
+    )
+
+    bound_args, specialization, options = binder(*args, **kwargs)
+    options, signature, constexprs, attrs = kernel_fn._pack_args(backend, kwargs, bound_args, specialization, options)
+    source_cls = GluonASTSource if kernel_fn.is_gluon() else ASTSource
+    src = source_cls(kernel_fn, signature, constexprs, attrs)
+
+    context = ir.context()
+    ir.load_dialects(context)
+    backend.load_dialects(context)
+
+    codegen_fns = backend.get_codegen_implementation(options)
+    module_map = backend.get_module_map()
+    module = src.make_ir(target, options, codegen_fns, module_map, context)
+    assert module.verify()
+    return module
+
+
+def run_filecheck_test(kernel_fn):
+    assert isinstance(kernel_fn, triton.runtime.JITFunction)
+    check_template = inspect.getsource(kernel_fn.fn)
+    if check_template is None:
+        raise ValueError("kernel function must have a docstring with FileCheck template")
+    mlir_module = run_parser(kernel_fn)
+
+    run_filecheck("placeholder", mlir_module.str_nodebug(), check_template)
+
+
+def filecheck_test(fn):
+
+    @functools.wraps(fn)
+    def test_fn():
+        run_filecheck_test(fn)
+
+    return test_fn
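The new _filecheck module parses a @triton.jit kernel to MLIR (run_parser) and feeds the kernel's own source to an LLVM FileCheck binary expected at triton/FileCheck, so CHECK directives are written as comments inside the kernel. A hedged sketch of how such a test might look; the kernel body and the CHECK pattern are illustrative (tl.arange lowers to tt.make_range in Triton IR):

import triton
import triton.language as tl
from triton._filecheck import filecheck_test


@filecheck_test
@triton.jit
def test_make_range():
    # CHECK-LABEL: test_make_range
    # CHECK: tt.make_range
    offs = tl.arange(0, 16)  # noqa: F841

Calling test_make_range() raises ValueError with FileCheck's diagnostic output when the pattern is not found in the generated module.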
triton/_internal_testing.py
CHANGED
@@ -4,11 +4,11 @@ import numpy as np
 import torch
 import triton
 import triton.language as tl
-from triton
+from triton import knobs
+from typing import Optional, Set, Union
 import pytest
 
 from numpy.random import RandomState
-from typing import Optional, Union
 from triton.runtime.jit import TensorWrapper, reinterpret, type_canonicalisation_dict
 
 int_dtypes = ['int8', 'int16', 'int32', 'int64']
@@ -20,6 +20,7 @@ dtypes = integral_dtypes + float_dtypes
 dtypes_with_bfloat16 = dtypes + ['bfloat16']
 torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2']
 torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
+tma_dtypes = sorted(set(dtypes_with_bfloat16) - {"int64", "uint64", "float64"})
 
 
 def is_interpreter():
@@ -37,38 +38,58 @@ def is_cuda():
     return False if target is None else target.backend == "cuda"
 
 
-def
+def is_ampere_or_newer():
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 8
+
+
+def is_blackwell():
+    return is_cuda() and torch.cuda.get_device_capability()[0] == 10
+
+
+def is_hopper_or_newer():
     return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
 
 
+def is_hopper():
+    return is_cuda() and torch.cuda.get_device_capability()[0] == 9
+
+
 def is_hip():
     target = get_current_target()
     return False if target is None else target.backend == "hip"
 
 
-def
+def is_hip_cdna2():
     target = get_current_target()
-
-    return False
-    return target.arch == 'gfx90a'
+    return target is not None and target.backend == 'hip' and target.arch == 'gfx90a'
 
 
-def
+def is_hip_cdna3():
     target = get_current_target()
-
-    return False
-    return target.arch in ('gfx940', 'gfx941', 'gfx942')
+    return target is not None and target.backend == 'hip' and target.arch == 'gfx942'
 
 
-def
+def is_hip_cdna4():
     target = get_current_target()
-
-
-
+    return target is not None and target.backend == 'hip' and target.arch == 'gfx950'
+
+
+def is_hip_gfx11():
+    target = get_current_target()
+    return target is not None and target.backend == 'hip' and 'gfx11' in target.arch
+
+
+def is_hip_gfx12():
+    target = get_current_target()
+    return target is not None and target.backend == 'hip' and 'gfx12' in target.arch
 
 
 def is_hip_cdna():
-    return
+    return is_hip_cdna2() or is_hip_cdna3() or is_hip_cdna4()
+
+
+def get_hip_lds_size():
+    return 163840 if is_hip_cdna4() else 65536
 
 
 def is_xpu():
@@ -131,7 +152,7 @@ def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torc
 
 
 def str_to_triton_dtype(x: str) -> tl.dtype:
-    return tl.str_to_ty(type_canonicalisation_dict[x])
+    return tl.str_to_ty(type_canonicalisation_dict[x], None)
 
 
 def torch_dtype_name(dtype) -> str:
@@ -161,7 +182,7 @@ def supports_tma(byval_only=False):
         return True
     if not is_cuda():
         return False
-
+    cuda_version = knobs.nvidia.ptxas.version
    min_cuda_version = (12, 0) if byval_only else (12, 3)
    cuda_version_tuple = tuple(map(int, cuda_version.split(".")))
    assert len(cuda_version_tuple) == 2, cuda_version_tuple
@@ -176,3 +197,59 @@ def tma_skip_msg(byval_only=False):
 
 
 requires_tma = pytest.mark.skipif(not supports_tma(), reason=tma_skip_msg())
+
+
+def default_alloc_fn(size: int, align: int, _):
+    return torch.empty(size, dtype=torch.int8, device="cuda")
+
+
+def unwrap_tensor(t: Union[torch.Tensor, triton.runtime.jit.TensorWrapper]) -> torch.Tensor:
+    if isinstance(t, triton.runtime.jit.TensorWrapper):
+        return t.base
+    return t
+
+
+def _fresh_knobs_impl(skipped_attr: Optional[Set[str]] = None):
+    from triton import knobs
+
+    if skipped_attr is None:
+        skipped_attr = set()
+
+    monkeypatch = pytest.MonkeyPatch()
+
+    knobs_map = {
+        name: knobset
+        for name, knobset in knobs.__dict__.items()
+        if isinstance(knobset, knobs.base_knobs) and knobset != knobs.base_knobs and name not in skipped_attr
+    }
+
+    # We store which variables we need to unset below in finally because
+    # monkeypatch doesn't appear to reset variables that were never set
+    # before the monkeypatch.delenv call below.
+    env_to_unset = []
+    prev_propagate_env = knobs.propagate_env
+
+    def fresh_function():
+        nonlocal env_to_unset
+        for name, knobset in knobs_map.items():
+            setattr(knobs, name, knobset.copy().reset())
+            for knob in knobset.knob_descriptors.values():
+                if knob.key in os.environ:
+                    monkeypatch.delenv(knob.key, raising=False)
+                else:
+                    env_to_unset.append(knob.key)
+        knobs.propagate_env = True
+        return knobs
+
+    def reset_function():
+        for name, knobset in knobs_map.items():
+            setattr(knobs, name, knobset)
+        # `undo` should be placed before `del os.environ`
+        # Otherwise, it may restore environment variables that monkeypatch deleted
+        monkeypatch.undo()
+        for k in env_to_unset:
+            if k in os.environ:
+                del os.environ[k]
+        knobs.propagate_env = prev_propagate_env
+
+    return fresh_function, reset_function
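Hardware detection in _internal_testing is now much finer grained (Ampere, Hopper and Blackwell on CUDA; CDNA2/3/4 and gfx11/gfx12 on HIP), and supports_tma reads the CUDA version from knobs.nvidia.ptxas.version. A short sketch of how a test suite might use these predicates, mirroring the requires_tma skip marker the module defines itself; the test bodies are illustrative:

import pytest
from triton._internal_testing import is_hopper_or_newer, is_hip_cdna4, get_hip_lds_size

# Built the same way the module builds its own requires_tma marker.
requires_hopper = pytest.mark.skipif(not is_hopper_or_newer(), reason="requires compute capability >= 9")


@requires_hopper
def test_hopper_only_feature():
    ...  # exercised only on Hopper-class GPUs or newer


def test_lds_budget_matches_arch():
    # CDNA4 (gfx950) reports 160 KiB of LDS, every other target 64 KiB.
    assert get_hip_lds_size() == (163840 if is_hip_cdna4() else 65536)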
triton/_utils.py
CHANGED
@@ -1,35 +1,126 @@
+from __future__ import annotations
+
 from functools import reduce
+from typing import Any, Callable, TYPE_CHECKING, Union, List, Dict
+
+if TYPE_CHECKING:
+    from .language import core
+    IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type]
+    ObjPath = tuple[int, ...]
 
+TRITON_MAX_TENSOR_NUMEL = 1048576
 
-def get_iterable_path(iterable, path):
-    return reduce(lambda a, idx: a[idx], path, iterable)
 
+def get_iterable_path(iterable: IterableType, path: ObjPath) -> Any:
+    return reduce(lambda a, idx: a[idx], path, iterable)  # type: ignore[index]
 
-
+
+def set_iterable_path(iterable: IterableType, path: tuple[int, ...], val: Any):
+    from .language import core
+    assert len(path) != 0
     prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1])
-    prev
+    assert isinstance(prev, core.tuple)
+    prev._setitem(path[-1], val)
 
 
-def find_paths_if(iterable, pred):
+def find_paths_if(iterable: Union[IterableType, Any], pred: Callable[[ObjPath, Any], bool]) -> list[ObjPath]:
     from .language import core
-    is_iterable = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
-
+    is_iterable: Callable[[Any], bool] = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
+    # We need to use dict so that ordering is maintained, while set doesn't guarantee order
+    ret: dict[ObjPath, None] = {}
 
-    def _impl(
-        path = (path[0], ) if len(path) == 1 else tuple(path)
+    def _impl(path: tuple[int, ...], current: Any):
         if is_iterable(current):
             for idx, item in enumerate(current):
-                _impl(
+                _impl((*path, idx), item)
         elif pred(path, current):
-
-
-
-
-
-    if is_iterable(iterable):
-        _impl(iterable, [])
-    elif pred(list(), iterable):
-        ret = {tuple(): None}
-    else:
-        ret = dict()
+            ret[path] = None
+
+    _impl((), iterable)
+
     return list(ret.keys())
+
+
+def is_power_of_two(x):
+    return (x & (x - 1)) == 0
+
+
+def validate_block_shape(shape: List[int]):
+    numel = 1
+    for i, d in enumerate(shape):
+        if not isinstance(d, int):
+            raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]")
+        if not is_power_of_two(d):
+            raise ValueError(f"Shape element {i} must be a power of 2")
+        numel *= d
+
+    if numel > TRITON_MAX_TENSOR_NUMEL:
+        raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
+    return numel
+
+
+type_canonicalisation_dict = {
+    # we canonicalise all bools to be unsigned:
+    "bool": "u1",
+    "int1": "u1",
+    "uint1": "u1",
+    "i1": "u1",
+    # floating-point dtypes:
+    "float8e4nv": "fp8e4nv",
+    "float8e5": "fp8e5",
+    "float8e4b15": "fp8e4b15",
+    "float8_e4m3fn": "fp8e4nv",
+    "float8e4b8": "fp8e4b8",
+    "float8_e4m3fnuz": "fp8e4b8",
+    "float8_e5m2": "fp8e5",
+    "float8e5b16": "fp8e5b16",
+    "float8_e5m2fnuz": "fp8e5b16",
+    "half": "fp16",
+    "float16": "fp16",
+    "bfloat16": "bf16",
+    "float": "fp32",
+    "float32": "fp32",
+    "double": "fp64",
+    "float64": "fp64",
+    # signed integers:
+    "int8": "i8",
+    "int16": "i16",
+    "int": "i32",
+    "int32": "i32",
+    "int64": "i64",
+    # unsigned integers:
+    "uint8": "u8",
+    "uint16": "u16",
+    "uint32": "u32",
+    "uint64": "u64",
+    "void": "void",
+}
+
+for v in list(type_canonicalisation_dict.values()):
+    type_canonicalisation_dict[v] = v
+
+
+def canonicalize_dtype(dtype):
+    dtype_str = str(dtype).split(".")[-1]
+    return type_canonicalisation_dict[dtype_str]
+
+
+BITWIDTH_DICT: Dict[str, int] = {
+    **{f"u{n}": n
+       for n in (1, 8, 16, 32, 64)},
+    **{f"i{n}": n
+       for n in (1, 8, 16, 32, 64)},
+    **{f"fp{n}": n
+       for n in (16, 32, 64)},
+    **{f"fp8{suffix}": 8
+       for suffix in ("e4nv", "e4b15", "e4b8", "e5", "e5b16")},
+    "bf16": 16,
+    "void": 0,
+}
+
+for k, v in type_canonicalisation_dict.items():
+    BITWIDTH_DICT[k] = BITWIDTH_DICT[v]
+
+
+def get_primitive_bitwidth(dtype: str) -> int:
+    return BITWIDTH_DICT[dtype]
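_utils.py now centralises block-shape validation plus the dtype canonicalisation and bit-width tables (the separate triton/language/_utils.py is removed in this release, per the file list above). A small usage sketch; the values follow directly from the tables and checks shown in the diff:

from triton._utils import (TRITON_MAX_TENSOR_NUMEL, canonicalize_dtype, get_primitive_bitwidth,
                           validate_block_shape)

print(validate_block_shape([64, 128]))       # 8192: every dimension must be a power of two
print(canonicalize_dtype("float8_e4m3fn"))   # 'fp8e4nv'
print(get_primitive_bitwidth("bfloat16"))    # 16: aliases resolve through the canonicalisation table
print(TRITON_MAX_TENSOR_NUMEL)               # 1048576: the cap enforced by validate_block_shape
# validate_block_shape([96])         -> ValueError (96 is not a power of two)
# validate_block_shape([1024, 2048]) -> ValueError (2097152 elements exceed the cap)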
triton/backends/__init__.py
CHANGED
@@ -1,20 +1,22 @@
-import
-import importlib.util
+import importlib
 import inspect
+import sys
 from dataclasses import dataclass
+from typing import Type, TypeVar, Union
+from types import ModuleType
 from .driver import DriverBase
 from .compiler import BaseBackend
 
+if sys.version_info >= (3, 10):
+    from importlib.metadata import entry_points
+else:
+    from importlib_metadata import entry_points
 
-
-    spec = importlib.util.spec_from_file_location(name, path)
-    module = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(module)
-    return module
+T = TypeVar("T", bound=Union[BaseBackend, DriverBase])
 
 
-def _find_concrete_subclasses(module, base_class):
-    ret = []
+def _find_concrete_subclasses(module: ModuleType, base_class: Type[T]) -> Type[T]:
+    ret: list[Type[T]] = []
     for attr_name in dir(module):
         attr = getattr(module, attr_name)
         if isinstance(attr, type) and issubclass(attr, base_class) and not inspect.isabstract(attr):
@@ -28,23 +30,18 @@ def _find_concrete_subclasses(module, base_class):
 
 @dataclass(frozen=True)
 class Backend:
-    compiler: BaseBackend
-    driver: DriverBase
+    compiler: Type[BaseBackend]
+    driver: Type[DriverBase]
 
 
-def _discover_backends():
+def _discover_backends() -> dict[str, Backend]:
     backends = dict()
-
-
-
-
-
-        continue
-        compiler = _load_module(name, os.path.join(root, name, 'compiler.py'))
-        driver = _load_module(name, os.path.join(root, name, 'driver.py'))
-        backends[name] = Backend(_find_concrete_subclasses(compiler, BaseBackend),
-                                 _find_concrete_subclasses(driver, DriverBase))
+    for ep in entry_points().select(group="triton.backends"):
+        compiler = importlib.import_module(f"{ep.value}.compiler")
+        driver = importlib.import_module(f"{ep.value}.driver")
+        backends[ep.name] = Backend(_find_concrete_subclasses(compiler, BaseBackend),  # type: ignore
+                                    _find_concrete_subclasses(driver, DriverBase))  # type: ignore
     return backends
 
 
-backends = _discover_backends()
+backends: dict[str, Backend] = _discover_backends()
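Backend discovery switches from scanning triton/backends/<name>/ on disk with _load_module to the "triton.backends" entry-point group, which is why the new wheel ships an entry_points.txt (see the file list). A hedged sketch of how an out-of-tree backend might register itself; the package and backend names are hypothetical, and only the entry-point group plus the compiler.py/driver.py expectation come from the code above:

# setup.py of a hypothetical "triton-mybackend" package
from setuptools import setup

setup(
    name="triton-mybackend",
    packages=["triton_mybackend"],
    # _discover_backends() imports "<value>.compiler" and "<value>.driver" for
    # every entry in this group and keeps their concrete BaseBackend / DriverBase
    # subclasses.
    entry_points={
        "triton.backends": ["mybackend = triton_mybackend"],
    },
)

Once such a package is installed, triton.backends.backends["mybackend"] would resolve to a Backend(compiler=..., driver=...) pair built by _discover_backends().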
|