triton-windows 3.3.1.post19__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/runtime/tcc/libtcc.dll
CHANGED
Binary file

triton/runtime/tcc/tcc.exe
CHANGED
Binary file
triton/testing.py
CHANGED

@@ -95,7 +95,11 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
     end_event.record()
     torch.cuda.synchronize()
     estimate_ms = start_event.elapsed_time(end_event) / 5
-    n_repeat = max(1, int(rep / estimate_ms))
+    # Rewrite to avoid possible division by 0 issues with fast benchmarks
+    if estimate_ms == 0:
+        n_repeat = 1000
+    else:
+        n_repeat = max(1, int(rep / estimate_ms))
     # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize
     # host overhead
     g = torch.cuda.CUDAGraph()
@@ -383,18 +387,18 @@ class Mark:
         has_single_bench = isinstance(self.benchmarks, Benchmark)
         benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
         result_dfs = []
-        if save_path:
-            # Create directory if it doesn't exist
-            os.makedirs(save_path, exist_ok=True)
-            html = open(os.path.join(save_path, "results.html"), "w")
-            html.write("<html><body>\n")
-        for bench in benchmarks:
-            result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
+        try:
+            for bench in benchmarks:
+                result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
+        finally:
             if save_path:
-
-
-
-
+                # Create directory if it doesn't exist
+                os.makedirs(save_path, exist_ok=True)
+                with open(os.path.join(save_path, "results.html"), "w") as html:
+                    html.write("<html><body>\n")
+                    for bench in benchmarks[:len(result_dfs)]:
+                        html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
+                    html.write("</body></html>\n")
         if return_df:
             if has_single_bench:
                 return result_dfs[0]
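A minimal sketch of how the patched benchmark helper above is typically driven (the tiny workload and sizes are illustrative, not from this package). With a very cheap function the initial 5-iteration estimate can round down to 0 ms, which is exactly the case the new "if estimate_ms == 0" guard handles by falling back to 1000 graph replays instead of dividing by zero.

import torch
import triton.testing

x = torch.randn(1024, device="cuda")
# do_bench_cudagraph replays the callable inside a CUDA graph and returns the mean time in ms
ms = triton.testing.do_bench_cudagraph(lambda: x + 1, rep=20)
print(f"{ms:.6f} ms per call")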
triton/tools/compile.py
CHANGED

@@ -3,12 +3,29 @@ import hashlib
 import importlib.util
 import sys
 from argparse import ArgumentParser
+from dataclasses import dataclass
 from pathlib import Path
 from typing import List

 import triton
 import triton.backends
-
+
+
+@dataclass
+class CompileArgs:
+    '''
+    A class to contain arguments from command-line parser.
+    '''
+    path: str = ''
+    kernel_name: str = ''
+    signature: str = ''
+    grid: str = ''
+    target: str | None = None
+    num_warps: int = 1
+    num_stages: int = 3
+    out_name: str | None = None
+    out_path: Path | None = None
+

 desc = """
 Triton ahead-of-time compiler:
@@ -36,14 +53,18 @@ NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed
 used to run this `compile.py` script
 """

-if __name__ == "__main__":

+def main():
     # command-line arguments
     parser = ArgumentParser(description=desc)
     parser.add_argument("path",
                         help="Path to Python source containing desired kernel in its scope. File will be executed.")
     parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile",
                         required=True)
+    parser.add_argument(
+        "--target", "-t", type=str, default=None,
+        help="The target to compile towards, in format of '<backend>:<arch>:<warp-size>'; "
+        "e.g., 'cuda:80:32', 'hip:gfx942:64'. Default to None, which means using current machine's GPU target")
     parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel")
     parser.add_argument("--num-stages", "-ns", type=int, default=3,
                         help="Number of stages (meta-parameter of the kernel)")
@@ -51,8 +72,12 @@ if __name__ == "__main__":
     parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename")
     parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
     parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True)
-
+    cli_args = parser.parse_args()
+    args = CompileArgs(**vars(cli_args))  # A sanity check to ensure class CompileArgs is updated as well.
+    compile_kernel(args)

+
+def compile_kernel(args: CompileArgs):
     out_name = args.out_name if args.out_name else args.kernel_name
     out_path = args.out_path if args.out_path else Path(out_name)

@@ -108,10 +133,18 @@ if __name__ == "__main__":
         assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}"
     attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16}
     src = triton.compiler.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs)
-
-
-
+
+    target = triton.backends.compiler.GPUTarget(*args.target.split(":")) \
+        if args.target else triton.runtime.driver.active.get_current_target()
+    backend = triton.compiler.make_backend(target)
+    kwargs = {"num_warps": args.num_warps, "num_stages": args.num_stages}
+    options = backend.parse_options(kwargs)
+    ccinfo = triton.compile(src, target=target, options=options.__dict__)
+
+    if getattr(ccinfo.metadata, "global_scratch_size", 0) > 0:
         raise RuntimeError("AOT compiling kernels with global scratch requirements is not yet implemented")
+    if ccinfo.metadata.profile_scratch_size > 0:
+        raise RuntimeError("AOT compiling kernels with profile scratch requirements is not yet implemented")

     arg_names = []
     arg_types = []
@@ -136,8 +169,12 @@ if __name__ == "__main__":
         if hints.get((i, ), None) == 16:
             suffix += 'd'
     func_name = '_'.join([out_name, sig_hash, suffix])
-    asm = ccinfo.asm[
+    asm = ccinfo.asm[backend.binary_ext]  # store binary data once
+
     hex_ = str(binascii.hexlify(asm))[2:-1]
+
+    ty_to_cpp = triton.runtime.driver.active.map_python_to_cpp_type
+
     params = {
         "kernel_name": func_name,
         "triton_kernel_name": args.kernel_name,
@@ -145,18 +182,29 @@ if __name__ == "__main__":
         "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]),
         "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]),
         "full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]),
-        "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"]),
-        "num_args": len(arg_names_not_1) +
+        "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"] + ["&profile_scratch"]),
+        "num_args": len(arg_names_not_1) + 2,  # +2 for global and profile scratch
         "kernel_docstring": doc_string,
         "shared": ccinfo.metadata.shared,
         "num_warps": args.num_warps,
-        "algo_info":
+        "algo_info": "_".join([const_sig, meta_sig]),
         "gridX": grid[0],
         "gridY": grid[1],
         "gridZ": grid[2],
         "_placeholder": "",
     }
-
-
-
-
+    output_files = []
+    backend_name = target.backend
+    template_dir = Path(__file__).parent / "extra" / backend_name
+    for template_path in template_dir.glob('compile.*'):
+        ext = template_path.suffix
+        output_file = out_path.with_suffix(f".{sig_hash}_{suffix}{ext}")
+        with output_file.open("w") as fp:
+            fp.write(template_path.read_text().format(**params))
+        output_files.append(output_file)
+
+    return func_name, output_files
+
+
+if __name__ == "__main__":
+    main()
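A hedged sketch of driving the AOT compiler programmatically through the new CompileArgs/compile_kernel entry points added above. The kernel file name, signature and grid strings here are hypothetical; on the command line the equivalent is the new --target flag, e.g. --target cuda:80:32.

from pathlib import Path
from triton.tools.compile import CompileArgs, compile_kernel

args = CompileArgs(
    path="add_kernel.py",            # hypothetical script defining a @triton.jit kernel
    kernel_name="add_kernel",
    signature="*fp32:16, *fp32:16, *fp32:16, i32, 1024",   # illustrative signature string
    grid="1024,1,1",
    target="cuda:80:32",             # backend:arch:warp-size; None means the current GPU
    num_warps=4,
    num_stages=3,
    out_path=Path("add_kernel"),
)
# Renders the C source/header templates from tools/extra/<backend>/compile.* next to out_path
func_name, output_files = compile_kernel(args)
print(func_name, [str(f) for f in output_files])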
triton/tools/disasm.py
CHANGED

@@ -75,14 +75,13 @@ def get_sass(cubin_asm, fun=None):
     return sass


-@functools.lru_cache()
 def path_to_cuobjdump():
-    from triton
-    return
+    from triton import knobs
+    return knobs.nvidia.cuobjdump.path


 def extract(file_path, fun):
-    cuobjdump
+    cuobjdump = path_to_cuobjdump()
     if fun is None:
         sass_str = subprocess.check_output([cuobjdump, "-sass", file_path])
     else:
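For context, a small sketch of the entry point that uses this helper (the cubin file name is hypothetical): extract shells out to the cuobjdump binary that Triton now locates through its knobs module instead of a hard-coded path.

from triton.tools.disasm import extract

sass = extract("add_kernel.cubin", None)   # fun=None dumps SASS for the whole cubin
print(sass)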
triton/tools/extra/cuda/compile.c
CHANGED

@@ -61,6 +61,7 @@ CUresult {kernel_name}(CUstream stream, {signature}) {{
   unsigned int gY = {gridY};
   unsigned int gZ = {gridZ};
   CUdeviceptr global_scratch = 0;
+  CUdeviceptr profile_scratch = 0;
   void *args[{num_args}] = {{ {arg_pointers} }};
   // TODO: shared memory
   if(gX * gY * gZ > 0)
triton/tools/extra/hip/compile.cpp
ADDED

@@ -0,0 +1,66 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+/* clang-format off */
+#include <stdio.h>
+#include <stdint.h>
+#include <inttypes.h>
+#include <string.h>
+#include <hip/hip_runtime.h>
+
+// helpers to check for hip errors
+#define HIP_CHECK(ans) {{\
+    gpuAssert((ans), __FILE__, __LINE__);\
+  }}\
+
+static inline void gpuAssert(hipError_t code, const char *file, int line) {{
+  if (code != hipSuccess) {{
+    const char *prefix = "Triton Error [HIP]: ";
+    const char *str;
+    hipDrvGetErrorString(code, &str);
+    char err[1024] = {{0}};
+    strcat(err, prefix);
+    strcat(err, str);
+    printf("%s\\n", err);
+    exit(code);
+  }}
+}}
+
+// globals
+#define HSACO_NAME {kernel_name}_hsaco
+hipModule_t {kernel_name}_mod = nullptr;
+hipFunction_t {kernel_name}_func = nullptr;
+unsigned char HSACO_NAME[{bin_size}] = {{ {bin_data} }};
+
+
+void unload_{kernel_name}(void) {{
+  HIP_CHECK(hipModuleUnload({kernel_name}_mod));
+}}
+
+
+void load_{kernel_name}() {{
+  int dev = 0;
+  void *bin = (void *)&HSACO_NAME;
+  int shared = {shared};
+  HIP_CHECK(hipModuleLoadData(&{kernel_name}_mod, bin));
+  HIP_CHECK(hipModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}"));
+}}
+
+/*
+{kernel_docstring}
+*/
+hipError_t {kernel_name}(hipStream_t stream, {signature}) {{
+  if ({kernel_name}_func == nullptr)
+    load_{kernel_name}();
+  unsigned int gX = {gridX};
+  unsigned int gY = {gridY};
+  unsigned int gZ = {gridZ};
+  hipDeviceptr_t global_scratch = 0;
+  hipDeviceptr_t profile_scratch = 0;
+  void *args[{num_args}] = {{ {arg_pointers} }};
+  // TODO: shared memory
+  if(gX * gY * gZ > 0)
+    return hipModuleLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * warpSize, 1, 1, {shared}, stream, args, nullptr);
+  else
+    return hipErrorInvalidValue;
+}}
triton/tools/extra/hip/compile.h
ADDED

@@ -0,0 +1,13 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <hip/hip_runtime.h>
+#include <inttypes.h>
+#include <stdint.h>
+#include <stdio.h>
+
+void unload_{kernel_name}(void);
+void load_{kernel_name}(void);
+hipError_t{_placeholder} {kernel_name}(hipStream_t stream, {signature});
triton/tools/ragged_tma.py
ADDED

@@ -0,0 +1,92 @@
+import triton
+import triton.language as tl
+from triton.tools.tensor_descriptor import TensorDescriptor
+
+# fmt: off
+
+
+def create_ragged_descriptor(T, block_shape, ragged_dim=0):
+    """
+    Given a 2- or 3-dimensional tensor T, this creates a 'ragged descriptor'
+    which behaves like a concatenation (along the first axis) of subarrays
+    of potentially unequal size.
+
+    The load_ragged and store_ragged device functions can be used to read
+    and write from subarrays T[batch_offset : batch_offset + batch_size]
+    with hardware bounds-checking preventing any sort of leakage outside
+    the subarray.
+    """
+
+    block_shape = list(block_shape)
+    tensor_shape = list(T.shape)
+    rank = len(tensor_shape)
+
+    if ragged_dim < 0:
+        ragged_dim += rank
+
+    assert 0 <= ragged_dim < rank - 1, "last dimension cannot be ragged"
+    assert rank <= 3, "read-write ragged descriptors must have at most 3 dimensions"
+
+    assert len(block_shape) == rank, "block shape must have same length as tensor shape"
+
+    max_int = 0x7fff0000
+    billion = 0x40000000  # == 2**30
+
+    assert tensor_shape[ragged_dim] <= billion, "number of rows may not exceed 2**30"
+    tensor_shape[ragged_dim] = billion
+    ragged_stride = T.stride(ragged_dim)
+
+    # we prepend an extra two dimensions and rely on the fact that pointers
+    # have 64-bit wraparound semantics:
+    tma_stride = [2**34 - ragged_stride, ragged_stride] + [T.stride(i) for i in range(rank)]
+    tma_shape = [max_int, max_int] + tensor_shape
+    box_shape = [1, 1] + block_shape
+
+    return TensorDescriptor(T, tma_shape, tma_stride, box_shape)
+
+
+@triton.jit
+def to_ragged_indices(batch_offset, batch_size, row):
+    """
+    Helper function for load_ragged and store_ragged.
+    """
+
+    billion = 0x40000000  # == 2**30
+    x = billion - batch_size + row
+    y = batch_offset + batch_size
+
+    return billion, y, x
+
+
+@triton.jit
+def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr = 0):
+    """
+    Read from a subarray T[batch_offset : batch_offset + batch_size] with
+    hardware bounds-checking, where reading outside the subarray gives zeros.
+
+    Coords should be an appropriately-sized list of integers, just like in
+    TMA.load().
+    """
+
+    tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor")
+
+    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
+    data = TMA.load([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:])
+    data = tl.reshape(data, data.shape[2:])
+    return data
+
+
+@triton.jit
+def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
+    """
+    Write to a subarray T[batch_offset : batch_offset + batch_size] with
+    hardware bounds-checking, where writes outside the subarray are masked
+    correctly.
+
+    Coords should be an appropriately-sized list of integers, just like in
+    TMA.store().
+    """
+
+    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
+    data = tl.reshape(data, [1, 1] + data.shape)
+    TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
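A rough, untested usage sketch for the new module above. It assumes a GPU and Triton build where host-side TMA tensor descriptors can be passed directly to @triton.jit kernels (Hopper-class hardware); the tensor shapes, kernel, and grid are made up for illustration.

import torch
import triton
import triton.language as tl
from triton.tools.ragged_tma import create_ragged_descriptor, load_ragged, store_ragged

@triton.jit
def copy_ragged(in_desc, out_desc, batch_offset, batch_size, BLOCK_M: tl.constexpr):
    pid = tl.program_id(0)
    # Rows outside [batch_offset, batch_offset + batch_size) read back as zeros
    # and are masked on store by the hardware bounds check.
    tile = load_ragged(in_desc, batch_offset, batch_size, [pid * BLOCK_M, 0])
    store_ragged(out_desc, batch_offset, batch_size, [pid * BLOCK_M, 0], tile)

x = torch.randn(1024, 64, device="cuda", dtype=torch.float16)
y = torch.zeros_like(x)
in_desc = create_ragged_descriptor(x, [64, 64])
out_desc = create_ragged_descriptor(y, [64, 64])
# copy the 100-row sub-batch starting at row 200; the second program is partially masked
copy_ragged[(2,)](in_desc, out_desc, 200, 100, BLOCK_M=64)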
triton/tools/tensor_descriptor.py
ADDED

@@ -0,0 +1,34 @@
+from dataclasses import dataclass
+from typing import List, Any
+from triton._utils import validate_block_shape
+
+
+@dataclass
+class TensorDescriptor:
+    base: Any
+    shape: List[int]
+    strides: List[int]
+    block_shape: List[int]
+    padding: str = "zero"
+
+    def __post_init__(self):
+        rank = len(self.shape)
+        assert len(self.strides) == rank, f"rank mismatch: {self}"
+        assert len(self.block_shape) == rank, f"rank mismatch: {self}"
+        assert rank > 0, "rank must not be zero"
+        assert rank <= 5, "rank cannot be more than 5"
+        ty = type(self.base)
+        if ty.__name__ not in ("FakeTensor", "FunctionalTensor"):
+            assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
+        validate_block_shape(self.block_shape)
+        elem_bytes = self.base.dtype.itemsize
+        for stride in self.strides[:-1]:
+            assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
+        assert self.strides[-1] == 1, "Last dimension must be contiguous"
+        assert self.padding == "zero" or self.padding == "nan", "Illegal value for padding"
+        if self.padding == "nan":
+            assert self.base.dtype.is_floating_point, "Padding option `nan` is only supported for floating point tensors"
+
+    @staticmethod
+    def from_tensor(tensor: Any, block_shape: List[int], padding="zero"):
+        return TensorDescriptor(tensor, tensor.shape, tensor.stride(), block_shape, padding)