triton-windows 3.4.0.post20__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +8 -2
- triton/_filecheck.py +24 -14
- triton/_internal_testing.py +70 -4
- triton/_utils.py +3 -1
- triton/backends/amd/compiler.py +68 -60
- triton/backends/amd/driver.c +113 -44
- triton/backends/amd/driver.py +133 -57
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/compiler.py +80 -22
- triton/backends/nvidia/driver.c +88 -15
- triton/backends/nvidia/driver.py +130 -123
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +270 -163
- triton/compiler/compiler.py +45 -62
- triton/experimental/gluon/__init__.py +3 -2
- triton/experimental/gluon/_runtime.py +9 -6
- triton/experimental/gluon/language/__init__.py +117 -16
- triton/experimental/gluon/language/_core.py +246 -68
- triton/experimental/gluon/language/_layouts.py +398 -45
- triton/experimental/gluon/language/_math.py +17 -9
- triton/experimental/gluon/language/_semantic.py +130 -37
- triton/experimental/gluon/language/_standard.py +55 -22
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
- triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
- triton/experimental/gluon/nvidia/hopper.py +6 -1
- triton/knobs.py +132 -67
- triton/language/__init__.py +16 -10
- triton/language/core.py +163 -83
- triton/language/extra/cuda/gdc.py +6 -6
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +7 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/semantic.py +76 -23
- triton/language/standard.py +14 -14
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +4 -5
- triton/runtime/build.py +11 -9
- triton/runtime/cache.py +44 -1
- triton/runtime/driver.py +16 -41
- triton/runtime/interpreter.py +31 -23
- triton/runtime/jit.py +318 -157
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/tools/compile.py +62 -14
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +7 -9
- triton/windows_utils.py +42 -79
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/runtime/tcc/libtcc.dll
CHANGED
|
Binary file
|
triton/runtime/tcc/tcc.exe
CHANGED
|
Binary file
|
triton/tools/compile.py
CHANGED
|
@@ -3,12 +3,29 @@ import hashlib
|
|
|
3
3
|
import importlib.util
|
|
4
4
|
import sys
|
|
5
5
|
from argparse import ArgumentParser
|
|
6
|
+
from dataclasses import dataclass
|
|
6
7
|
from pathlib import Path
|
|
7
8
|
from typing import List
|
|
8
9
|
|
|
9
10
|
import triton
|
|
10
11
|
import triton.backends
|
|
11
|
-
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
|
|
15
|
+
class CompileArgs:
|
|
16
|
+
'''
|
|
17
|
+
A class to contain arguments from command-line parser.
|
|
18
|
+
'''
|
|
19
|
+
path: str = ''
|
|
20
|
+
kernel_name: str = ''
|
|
21
|
+
signature: str = ''
|
|
22
|
+
grid: str = ''
|
|
23
|
+
target: str | None = None
|
|
24
|
+
num_warps: int = 1
|
|
25
|
+
num_stages: int = 3
|
|
26
|
+
out_name: str | None = None
|
|
27
|
+
out_path: Path | None = None
|
|
28
|
+
|
|
12
29
|
|
|
13
30
|
desc = """
|
|
14
31
|
Triton ahead-of-time compiler:
|
|
@@ -36,14 +53,18 @@ NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed
|
|
|
36
53
|
used to run this `compile.py` script
|
|
37
54
|
"""
|
|
38
55
|
|
|
39
|
-
if __name__ == "__main__":
|
|
40
56
|
|
|
57
|
+
def main():
|
|
41
58
|
# command-line arguments
|
|
42
59
|
parser = ArgumentParser(description=desc)
|
|
43
60
|
parser.add_argument("path",
|
|
44
61
|
help="Path to Python source containing desired kernel in its scope. File will be executed.")
|
|
45
62
|
parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile",
|
|
46
63
|
required=True)
|
|
64
|
+
parser.add_argument(
|
|
65
|
+
"--target", "-t", type=str, default=None,
|
|
66
|
+
help="The target to compile towards, in format of '<backend>:<arch>:<warp-size>'; "
|
|
67
|
+
"e.g., 'cuda:80:32', 'hip:gfx942:64'. Default to None, which means using current machine's GPU target")
|
|
47
68
|
parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel")
|
|
48
69
|
parser.add_argument("--num-stages", "-ns", type=int, default=3,
|
|
49
70
|
help="Number of stages (meta-parameter of the kernel)")
|
|
@@ -51,8 +72,12 @@ if __name__ == "__main__":
|
|
|
51
72
|
parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename")
|
|
52
73
|
parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
|
|
53
74
|
parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True)
|
|
54
|
-
|
|
75
|
+
cli_args = parser.parse_args()
|
|
76
|
+
args = CompileArgs(**vars(cli_args)) # A sanity check to ensure class CompileArgs is updated as well.
|
|
77
|
+
compile_kernel(args)
|
|
55
78
|
|
|
79
|
+
|
|
80
|
+
def compile_kernel(args: CompileArgs):
|
|
56
81
|
out_name = args.out_name if args.out_name else args.kernel_name
|
|
57
82
|
out_path = args.out_path if args.out_path else Path(out_name)
|
|
58
83
|
|
|
@@ -108,10 +133,18 @@ if __name__ == "__main__":
|
|
|
108
133
|
assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}"
|
|
109
134
|
attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16}
|
|
110
135
|
src = triton.compiler.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs)
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
136
|
+
|
|
137
|
+
target = triton.backends.compiler.GPUTarget(*args.target.split(":")) \
|
|
138
|
+
if args.target else triton.runtime.driver.active.get_current_target()
|
|
139
|
+
backend = triton.compiler.make_backend(target)
|
|
140
|
+
kwargs = {"num_warps": args.num_warps, "num_stages": args.num_stages}
|
|
141
|
+
options = backend.parse_options(kwargs)
|
|
142
|
+
ccinfo = triton.compile(src, target=target, options=options.__dict__)
|
|
143
|
+
|
|
144
|
+
if getattr(ccinfo.metadata, "global_scratch_size", 0) > 0:
|
|
114
145
|
raise RuntimeError("AOT compiling kernels with global scratch requirements is not yet implemented")
|
|
146
|
+
if ccinfo.metadata.profile_scratch_size > 0:
|
|
147
|
+
raise RuntimeError("AOT compiling kernels with profile scratch requirements is not yet implemented")
|
|
115
148
|
|
|
116
149
|
arg_names = []
|
|
117
150
|
arg_types = []
|
|
@@ -136,8 +169,12 @@ if __name__ == "__main__":
|
|
|
136
169
|
if hints.get((i, ), None) == 16:
|
|
137
170
|
suffix += 'd'
|
|
138
171
|
func_name = '_'.join([out_name, sig_hash, suffix])
|
|
139
|
-
asm = ccinfo.asm[
|
|
172
|
+
asm = ccinfo.asm[backend.binary_ext] # store binary data once
|
|
173
|
+
|
|
140
174
|
hex_ = str(binascii.hexlify(asm))[2:-1]
|
|
175
|
+
|
|
176
|
+
ty_to_cpp = triton.runtime.driver.active.map_python_to_cpp_type
|
|
177
|
+
|
|
141
178
|
params = {
|
|
142
179
|
"kernel_name": func_name,
|
|
143
180
|
"triton_kernel_name": args.kernel_name,
|
|
@@ -145,18 +182,29 @@ if __name__ == "__main__":
|
|
|
145
182
|
"bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]),
|
|
146
183
|
"signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]),
|
|
147
184
|
"full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]),
|
|
148
|
-
"arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"]),
|
|
149
|
-
"num_args": len(arg_names_not_1) +
|
|
185
|
+
"arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"] + ["&profile_scratch"]),
|
|
186
|
+
"num_args": len(arg_names_not_1) + 2, # +2 for global and profile scratch
|
|
150
187
|
"kernel_docstring": doc_string,
|
|
151
188
|
"shared": ccinfo.metadata.shared,
|
|
152
189
|
"num_warps": args.num_warps,
|
|
153
|
-
"algo_info":
|
|
190
|
+
"algo_info": "_".join([const_sig, meta_sig]),
|
|
154
191
|
"gridX": grid[0],
|
|
155
192
|
"gridY": grid[1],
|
|
156
193
|
"gridZ": grid[2],
|
|
157
194
|
"_placeholder": "",
|
|
158
195
|
}
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
196
|
+
output_files = []
|
|
197
|
+
backend_name = target.backend
|
|
198
|
+
template_dir = Path(__file__).parent / "extra" / backend_name
|
|
199
|
+
for template_path in template_dir.glob('compile.*'):
|
|
200
|
+
ext = template_path.suffix
|
|
201
|
+
output_file = out_path.with_suffix(f".{sig_hash}_{suffix}{ext}")
|
|
202
|
+
with output_file.open("w") as fp:
|
|
203
|
+
fp.write(template_path.read_text().format(**params))
|
|
204
|
+
output_files.append(output_file)
|
|
205
|
+
|
|
206
|
+
return func_name, output_files
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
if __name__ == "__main__":
|
|
210
|
+
main()
|
|
@@ -61,6 +61,7 @@ CUresult {kernel_name}(CUstream stream, {signature}) {{
|
|
|
61
61
|
unsigned int gY = {gridY};
|
|
62
62
|
unsigned int gZ = {gridZ};
|
|
63
63
|
CUdeviceptr global_scratch = 0;
|
|
64
|
+
CUdeviceptr profile_scratch = 0;
|
|
64
65
|
void *args[{num_args}] = {{ {arg_pointers} }};
|
|
65
66
|
// TODO: shared memory
|
|
66
67
|
if(gX * gY * gZ > 0)
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
// SPDX-License-Identifier: MIT
|
|
2
|
+
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
3
|
+
|
|
4
|
+
/* clang-format off */
|
|
5
|
+
#include <stdio.h>
|
|
6
|
+
#include <stdint.h>
|
|
7
|
+
#include <inttypes.h>
|
|
8
|
+
#include <string.h>
|
|
9
|
+
#include <hip/hip_runtime.h>
|
|
10
|
+
|
|
11
|
+
// helpers to check for hip errors
|
|
12
|
+
#define HIP_CHECK(ans) {{\
|
|
13
|
+
gpuAssert((ans), __FILE__, __LINE__);\
|
|
14
|
+
}}\
|
|
15
|
+
|
|
16
|
+
static inline void gpuAssert(hipError_t code, const char *file, int line) {{
|
|
17
|
+
if (code != hipSuccess) {{
|
|
18
|
+
const char *prefix = "Triton Error [HIP]: ";
|
|
19
|
+
const char *str;
|
|
20
|
+
hipDrvGetErrorString(code, &str);
|
|
21
|
+
char err[1024] = {{0}};
|
|
22
|
+
strcat(err, prefix);
|
|
23
|
+
strcat(err, str);
|
|
24
|
+
printf("%s\\n", err);
|
|
25
|
+
exit(code);
|
|
26
|
+
}}
|
|
27
|
+
}}
|
|
28
|
+
|
|
29
|
+
// globals
|
|
30
|
+
#define HSACO_NAME {kernel_name}_hsaco
|
|
31
|
+
hipModule_t {kernel_name}_mod = nullptr;
|
|
32
|
+
hipFunction_t {kernel_name}_func = nullptr;
|
|
33
|
+
unsigned char HSACO_NAME[{bin_size}] = {{ {bin_data} }};
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
void unload_{kernel_name}(void) {{
|
|
37
|
+
HIP_CHECK(hipModuleUnload({kernel_name}_mod));
|
|
38
|
+
}}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
void load_{kernel_name}() {{
|
|
42
|
+
int dev = 0;
|
|
43
|
+
void *bin = (void *)&HSACO_NAME;
|
|
44
|
+
int shared = {shared};
|
|
45
|
+
HIP_CHECK(hipModuleLoadData(&{kernel_name}_mod, bin));
|
|
46
|
+
HIP_CHECK(hipModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}"));
|
|
47
|
+
}}
|
|
48
|
+
|
|
49
|
+
/*
|
|
50
|
+
{kernel_docstring}
|
|
51
|
+
*/
|
|
52
|
+
hipError_t {kernel_name}(hipStream_t stream, {signature}) {{
|
|
53
|
+
if ({kernel_name}_func == nullptr)
|
|
54
|
+
load_{kernel_name}();
|
|
55
|
+
unsigned int gX = {gridX};
|
|
56
|
+
unsigned int gY = {gridY};
|
|
57
|
+
unsigned int gZ = {gridZ};
|
|
58
|
+
hipDeviceptr_t global_scratch = 0;
|
|
59
|
+
hipDeviceptr_t profile_scratch = 0;
|
|
60
|
+
void *args[{num_args}] = {{ {arg_pointers} }};
|
|
61
|
+
// TODO: shared memory
|
|
62
|
+
if(gX * gY * gZ > 0)
|
|
63
|
+
return hipModuleLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * warpSize, 1, 1, {shared}, stream, args, nullptr);
|
|
64
|
+
else
|
|
65
|
+
return hipErrorInvalidValue;
|
|
66
|
+
}}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
// SPDX-License-Identifier: MIT
|
|
2
|
+
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
3
|
+
|
|
4
|
+
#pragma once
|
|
5
|
+
|
|
6
|
+
#include <hip/hip_runtime.h>
|
|
7
|
+
#include <inttypes.h>
|
|
8
|
+
#include <stdint.h>
|
|
9
|
+
#include <stdio.h>
|
|
10
|
+
|
|
11
|
+
void unload_{kernel_name}(void);
|
|
12
|
+
void load_{kernel_name}(void);
|
|
13
|
+
hipError_t{_placeholder} {kernel_name}(hipStream_t stream, {signature});
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import triton
|
|
2
|
+
import triton.language as tl
|
|
3
|
+
from triton.tools.tensor_descriptor import TensorDescriptor
|
|
4
|
+
|
|
5
|
+
# fmt: off
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def create_ragged_descriptor(T, block_shape, ragged_dim=0):
|
|
9
|
+
"""
|
|
10
|
+
Given a 2- or 3-dimensional tensor T, this creates a 'ragged descriptor'
|
|
11
|
+
which behaves like a concatenation (along the first axis) of subarrays
|
|
12
|
+
of potentially unequal size.
|
|
13
|
+
|
|
14
|
+
The load_ragged and store_ragged device functions can be used to read
|
|
15
|
+
and write from subarrays T[batch_offset : batch_offset + batch_size]
|
|
16
|
+
with hardware bounds-checking preventing any sort of leakage outside
|
|
17
|
+
the subarray.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
block_shape = list(block_shape)
|
|
21
|
+
tensor_shape = list(T.shape)
|
|
22
|
+
rank = len(tensor_shape)
|
|
23
|
+
|
|
24
|
+
if ragged_dim < 0:
|
|
25
|
+
ragged_dim += rank
|
|
26
|
+
|
|
27
|
+
assert 0 <= ragged_dim < rank - 1, "last dimension cannot be ragged"
|
|
28
|
+
assert rank <= 3, "read-write ragged descriptors must have at most 3 dimensions"
|
|
29
|
+
|
|
30
|
+
assert len(block_shape) == rank, "block shape must have same length as tensor shape"
|
|
31
|
+
|
|
32
|
+
max_int = 0x7fff0000
|
|
33
|
+
billion = 0x40000000 # == 2**30
|
|
34
|
+
|
|
35
|
+
assert tensor_shape[ragged_dim] <= billion, "number of rows may not exceed 2**30"
|
|
36
|
+
tensor_shape[ragged_dim] = billion
|
|
37
|
+
ragged_stride = T.stride(ragged_dim)
|
|
38
|
+
|
|
39
|
+
# we prepend an extra two dimensions and rely on the fact that pointers
|
|
40
|
+
# have 64-bit wraparound semantics:
|
|
41
|
+
tma_stride = [2**34 - ragged_stride, ragged_stride] + [T.stride(i) for i in range(rank)]
|
|
42
|
+
tma_shape = [max_int, max_int] + tensor_shape
|
|
43
|
+
box_shape = [1, 1] + block_shape
|
|
44
|
+
|
|
45
|
+
return TensorDescriptor(T, tma_shape, tma_stride, box_shape)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@triton.jit
|
|
49
|
+
def to_ragged_indices(batch_offset, batch_size, row):
|
|
50
|
+
"""
|
|
51
|
+
Helper function for load_ragged and store_ragged.
|
|
52
|
+
"""
|
|
53
|
+
|
|
54
|
+
billion = 0x40000000 # == 2**30
|
|
55
|
+
x = billion - batch_size + row
|
|
56
|
+
y = batch_offset + batch_size
|
|
57
|
+
|
|
58
|
+
return billion, y, x
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@triton.jit
|
|
62
|
+
def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr = 0):
|
|
63
|
+
"""
|
|
64
|
+
Read from a subarray T[batch_offset : batch_offset + batch_size] with
|
|
65
|
+
hardware bounds-checking, where reading outside the subarray gives zeros.
|
|
66
|
+
|
|
67
|
+
Coords should be an appropriately-sized list of integers, just like in
|
|
68
|
+
TMA.load().
|
|
69
|
+
"""
|
|
70
|
+
|
|
71
|
+
tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor")
|
|
72
|
+
|
|
73
|
+
c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
|
|
74
|
+
data = TMA.load([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:])
|
|
75
|
+
data = tl.reshape(data, data.shape[2:])
|
|
76
|
+
return data
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@triton.jit
|
|
80
|
+
def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
|
|
81
|
+
"""
|
|
82
|
+
Write to a subarray T[batch_offset : batch_offset + batch_size] with
|
|
83
|
+
hardware bounds-checking, where writes outside the subarray are masked
|
|
84
|
+
correctly.
|
|
85
|
+
|
|
86
|
+
Coords should be an appropriately-sized list of integers, just like in
|
|
87
|
+
TMA.store().
|
|
88
|
+
"""
|
|
89
|
+
|
|
90
|
+
c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
|
|
91
|
+
data = tl.reshape(data, [1, 1] + data.shape)
|
|
92
|
+
TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
|
|
@@ -9,6 +9,7 @@ class TensorDescriptor:
|
|
|
9
9
|
shape: List[int]
|
|
10
10
|
strides: List[int]
|
|
11
11
|
block_shape: List[int]
|
|
12
|
+
padding: str = "zero"
|
|
12
13
|
|
|
13
14
|
def __post_init__(self):
|
|
14
15
|
rank = len(self.shape)
|
|
@@ -17,20 +18,17 @@ class TensorDescriptor:
|
|
|
17
18
|
assert rank > 0, "rank must not be zero"
|
|
18
19
|
assert rank <= 5, "rank cannot be more than 5"
|
|
19
20
|
ty = type(self.base)
|
|
20
|
-
|
|
21
|
-
if type_name not in ("torch.FakeTensor", "torch.FunctionalTensor"):
|
|
21
|
+
if ty.__name__ not in ("FakeTensor", "FunctionalTensor"):
|
|
22
22
|
assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
|
|
23
23
|
validate_block_shape(self.block_shape)
|
|
24
24
|
elem_bytes = self.base.dtype.itemsize
|
|
25
25
|
for stride in self.strides[:-1]:
|
|
26
26
|
assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
|
|
27
27
|
assert self.strides[-1] == 1, "Last dimension must be contiguous"
|
|
28
|
+
assert self.padding == "zero" or self.padding == "nan", "Illegal value for padding"
|
|
29
|
+
if self.padding == "nan":
|
|
30
|
+
assert self.base.dtype.is_floating_point, "Padding option `nan` is only supported for floating point tensors"
|
|
28
31
|
|
|
29
32
|
@staticmethod
|
|
30
|
-
def from_tensor(tensor: Any, block_shape: List[int]):
|
|
31
|
-
return TensorDescriptor(
|
|
32
|
-
tensor,
|
|
33
|
-
tensor.shape,
|
|
34
|
-
tensor.stride(),
|
|
35
|
-
block_shape,
|
|
36
|
-
)
|
|
33
|
+
def from_tensor(tensor: Any, block_shape: List[int], padding="zero"):
|
|
34
|
+
return TensorDescriptor(tensor, tensor.shape, tensor.stride(), block_shape, padding)
|
triton/windows_utils.py
CHANGED
|
@@ -54,14 +54,11 @@ def max_version(
|
|
|
54
54
|
|
|
55
55
|
|
|
56
56
|
def check_msvc(msvc_base_path: Path, version: str) -> bool:
|
|
57
|
-
return all(
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
msvc_base_path / version / "lib" / "x64" / "vcruntime.lib",
|
|
63
|
-
]
|
|
64
|
-
)
|
|
57
|
+
return all(x.exists() for x in [
|
|
58
|
+
msvc_base_path / version / "bin" / "Hostx64" / "x64" / "cl.exe",
|
|
59
|
+
msvc_base_path / version / "include" / "vcruntime.h",
|
|
60
|
+
msvc_base_path / version / "lib" / "x64" / "vcruntime.lib",
|
|
61
|
+
])
|
|
65
62
|
|
|
66
63
|
|
|
67
64
|
def find_msvc_env() -> tuple[Optional[Path], Optional[str]]:
|
|
@@ -72,20 +69,16 @@ def find_msvc_env() -> tuple[Optional[Path], Optional[str]]:
|
|
|
72
69
|
|
|
73
70
|
version = os.getenv("VCToolsVersion")
|
|
74
71
|
if not check_msvc(msvc_base_path, version):
|
|
75
|
-
warnings.warn(
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
"but this MSVC installation is incomplete."
|
|
79
|
-
)
|
|
72
|
+
warnings.warn(f"Environment variables VCINSTALLDIR = {os.getenv('VCINSTALLDIR')}, "
|
|
73
|
+
f"VCToolsVersion = {os.getenv('VCToolsVersion')} are set, "
|
|
74
|
+
"but this MSVC installation is incomplete.")
|
|
80
75
|
return None, None
|
|
81
76
|
|
|
82
77
|
return msvc_base_path, version
|
|
83
78
|
|
|
84
79
|
|
|
85
80
|
def find_msvc_vswhere() -> tuple[Optional[Path], Optional[str]]:
|
|
86
|
-
vswhere_path = find_in_program_files(
|
|
87
|
-
r"Microsoft Visual Studio\Installer\vswhere.exe"
|
|
88
|
-
)
|
|
81
|
+
vswhere_path = find_in_program_files(r"Microsoft Visual Studio\Installer\vswhere.exe")
|
|
89
82
|
if vswhere_path is None:
|
|
90
83
|
return None, None
|
|
91
84
|
|
|
@@ -111,9 +104,7 @@ def find_msvc_vswhere() -> tuple[Optional[Path], Optional[str]]:
|
|
|
111
104
|
if not msvc_base_path.exists():
|
|
112
105
|
return None, None
|
|
113
106
|
|
|
114
|
-
version = max_version(
|
|
115
|
-
os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path)
|
|
116
|
-
)
|
|
107
|
+
version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
|
|
117
108
|
if version is None:
|
|
118
109
|
return None, None
|
|
119
110
|
|
|
@@ -132,9 +123,7 @@ def find_msvc_envpath() -> tuple[Optional[Path], Optional[str]]:
|
|
|
132
123
|
if not msvc_base_path.exists():
|
|
133
124
|
continue
|
|
134
125
|
|
|
135
|
-
version = max_version(
|
|
136
|
-
os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path)
|
|
137
|
-
)
|
|
126
|
+
version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
|
|
138
127
|
if version is None:
|
|
139
128
|
continue
|
|
140
129
|
|
|
@@ -153,9 +142,7 @@ def find_msvc_hardcoded() -> tuple[Optional[Path], Optional[str]]:
|
|
|
153
142
|
paths = sorted(paths)[::-1]
|
|
154
143
|
for msvc_base_path in paths:
|
|
155
144
|
msvc_base_path = Path(msvc_base_path)
|
|
156
|
-
version = max_version(
|
|
157
|
-
os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path)
|
|
158
|
-
)
|
|
145
|
+
version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
|
|
159
146
|
if version is None:
|
|
160
147
|
continue
|
|
161
148
|
return msvc_base_path, version
|
|
@@ -188,13 +175,10 @@ def find_msvc(env_only: bool) -> tuple[Optional[str], list[str], list[str]]:
|
|
|
188
175
|
|
|
189
176
|
|
|
190
177
|
def check_winsdk(winsdk_base_path: Path, version: str) -> bool:
|
|
191
|
-
return all(
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
winsdk_base_path / "Lib" / version / "ucrt" / "x64" / "ucrt.lib",
|
|
196
|
-
]
|
|
197
|
-
)
|
|
178
|
+
return all(x.exists() for x in [
|
|
179
|
+
winsdk_base_path / "Include" / version / "ucrt" / "stdlib.h",
|
|
180
|
+
winsdk_base_path / "Lib" / version / "ucrt" / "x64" / "ucrt.lib",
|
|
181
|
+
])
|
|
198
182
|
|
|
199
183
|
|
|
200
184
|
def find_winsdk_env() -> tuple[Optional[Path], Optional[str]]:
|
|
@@ -205,18 +189,16 @@ def find_winsdk_env() -> tuple[Optional[Path], Optional[str]]:
|
|
|
205
189
|
|
|
206
190
|
version = os.getenv("WindowsSDKVersion")
|
|
207
191
|
if version is None:
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
192
|
+
version = os.getenv("WindowsSDKVer")
|
|
193
|
+
if version is None:
|
|
194
|
+
warnings.warn(f"Environment variable WindowsSdkDir = {winsdk_base_path}, "
|
|
195
|
+
"but WindowsSDKVersion (or WindowsSDKVer) is not set.")
|
|
212
196
|
return None, None
|
|
213
197
|
version = version.rstrip("\\")
|
|
214
198
|
if not check_winsdk(winsdk_base_path, version):
|
|
215
|
-
warnings.warn(
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
"but this Windows SDK installation is incomplete."
|
|
219
|
-
)
|
|
199
|
+
warnings.warn(f"Environment variables WindowsSdkDir = {winsdk_base_path}, "
|
|
200
|
+
f"WindowsSDKVersion (or WindowsSDKVer) = {version} are set, "
|
|
201
|
+
"but this Windows SDK installation is incomplete.")
|
|
220
202
|
return None, None
|
|
221
203
|
|
|
222
204
|
return winsdk_base_path, version
|
|
@@ -225,9 +207,7 @@ def find_winsdk_env() -> tuple[Optional[Path], Optional[str]]:
|
|
|
225
207
|
def find_winsdk_registry() -> tuple[Optional[Path], Optional[str]]:
|
|
226
208
|
try:
|
|
227
209
|
reg = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
|
|
228
|
-
key = winreg.OpenKeyEx(
|
|
229
|
-
reg, r"SOFTWARE\WOW6432Node\Microsoft\Microsoft SDKs\Windows\v10.0"
|
|
230
|
-
)
|
|
210
|
+
key = winreg.OpenKeyEx(reg, r"SOFTWARE\WOW6432Node\Microsoft\Microsoft SDKs\Windows\v10.0")
|
|
231
211
|
folder = winreg.QueryValueEx(key, "InstallationFolder")[0]
|
|
232
212
|
winreg.CloseKey(key)
|
|
233
213
|
except OSError:
|
|
@@ -294,9 +274,7 @@ def find_winsdk(env_only: bool) -> tuple[list[str], list[str]]:
|
|
|
294
274
|
|
|
295
275
|
|
|
296
276
|
@functools.lru_cache
|
|
297
|
-
def find_msvc_winsdk(
|
|
298
|
-
env_only: bool = False,
|
|
299
|
-
) -> tuple[Optional[str], list[str], list[str]]:
|
|
277
|
+
def find_msvc_winsdk(env_only: bool = False, ) -> tuple[Optional[str], list[str], list[str]]:
|
|
300
278
|
msvc_bin_path, msvc_inc_dirs, msvc_lib_dirs = find_msvc(env_only)
|
|
301
279
|
winsdk_inc_dirs, winsdk_lib_dirs = find_winsdk(env_only)
|
|
302
280
|
return (
|
|
@@ -312,9 +290,9 @@ def find_python() -> list[str]:
|
|
|
312
290
|
if sysconfig.get_config_var("Py_GIL_DISABLED"):
|
|
313
291
|
version += "t"
|
|
314
292
|
for python_base_path in [
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
293
|
+
sys.exec_prefix,
|
|
294
|
+
sys.base_exec_prefix,
|
|
295
|
+
os.path.dirname(sys.executable),
|
|
318
296
|
]:
|
|
319
297
|
python_lib_dir = Path(python_base_path) / "libs"
|
|
320
298
|
if (python_lib_dir / f"python{version}.lib").exists():
|
|
@@ -326,14 +304,11 @@ def find_python() -> list[str]:
|
|
|
326
304
|
|
|
327
305
|
def check_and_find_cuda(base_path: Path) -> tuple[Optional[str], list[str], list[str]]:
|
|
328
306
|
# pip
|
|
329
|
-
if all(
|
|
330
|
-
x.exists()
|
|
331
|
-
for x in [
|
|
307
|
+
if all(x.exists() for x in [
|
|
332
308
|
base_path / "cuda_nvcc" / "bin" / "ptxas.exe",
|
|
333
309
|
base_path / "cuda_runtime" / "include" / "cuda.h",
|
|
334
310
|
base_path / "cuda_runtime" / "lib" / "x64" / "cuda.lib",
|
|
335
|
-
|
|
336
|
-
):
|
|
311
|
+
]):
|
|
337
312
|
return (
|
|
338
313
|
str(base_path / "cuda_nvcc" / "bin"),
|
|
339
314
|
[str(base_path / "cuda_runtime" / "include")],
|
|
@@ -341,14 +316,11 @@ def check_and_find_cuda(base_path: Path) -> tuple[Optional[str], list[str], list
|
|
|
341
316
|
)
|
|
342
317
|
|
|
343
318
|
# conda
|
|
344
|
-
if all(
|
|
345
|
-
x.exists()
|
|
346
|
-
for x in [
|
|
319
|
+
if all(x.exists() for x in [
|
|
347
320
|
base_path / "bin" / "ptxas.exe",
|
|
348
321
|
base_path / "include" / "cuda.h",
|
|
349
322
|
base_path / "lib" / "cuda.lib",
|
|
350
|
-
|
|
351
|
-
):
|
|
323
|
+
]):
|
|
352
324
|
return (
|
|
353
325
|
str(base_path / "bin"),
|
|
354
326
|
[str(base_path / "include")],
|
|
@@ -356,14 +328,11 @@ def check_and_find_cuda(base_path: Path) -> tuple[Optional[str], list[str], list
|
|
|
356
328
|
)
|
|
357
329
|
|
|
358
330
|
# bundled or system-wide
|
|
359
|
-
if all(
|
|
360
|
-
x.exists()
|
|
361
|
-
for x in [
|
|
331
|
+
if all(x.exists() for x in [
|
|
362
332
|
base_path / "bin" / "ptxas.exe",
|
|
363
333
|
base_path / "include" / "cuda.h",
|
|
364
334
|
base_path / "lib" / "x64" / "cuda.lib",
|
|
365
|
-
|
|
366
|
-
):
|
|
335
|
+
]):
|
|
367
336
|
return (
|
|
368
337
|
str(base_path / "bin"),
|
|
369
338
|
[str(base_path / "include")],
|
|
@@ -380,9 +349,7 @@ def find_cuda_env() -> tuple[Optional[str], list[str], list[str]]:
|
|
|
380
349
|
continue
|
|
381
350
|
|
|
382
351
|
cuda_base_path = Path(cuda_base_path)
|
|
383
|
-
cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(
|
|
384
|
-
cuda_base_path
|
|
385
|
-
)
|
|
352
|
+
cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(cuda_base_path)
|
|
386
353
|
if cuda_bin_path:
|
|
387
354
|
return cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs
|
|
388
355
|
|
|
@@ -390,9 +357,7 @@ def find_cuda_env() -> tuple[Optional[str], list[str], list[str]]:
|
|
|
390
357
|
|
|
391
358
|
|
|
392
359
|
def find_cuda_bundled() -> tuple[Optional[str], list[str], list[str]]:
|
|
393
|
-
cuda_base_path = (
|
|
394
|
-
Path(sysconfig.get_paths()["platlib"]) / "triton" / "backends" / "nvidia"
|
|
395
|
-
)
|
|
360
|
+
cuda_base_path = (Path(sysconfig.get_paths()["platlib"]) / "triton" / "backends" / "nvidia")
|
|
396
361
|
return check_and_find_cuda(cuda_base_path)
|
|
397
362
|
|
|
398
363
|
|
|
@@ -416,9 +381,7 @@ def find_cuda_hardcoded() -> tuple[Optional[str], list[str], list[str]]:
|
|
|
416
381
|
paths = sorted(paths)[::-1]
|
|
417
382
|
for cuda_base_path in paths:
|
|
418
383
|
cuda_base_path = Path(cuda_base_path)
|
|
419
|
-
cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(
|
|
420
|
-
cuda_base_path
|
|
421
|
-
)
|
|
384
|
+
cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(cuda_base_path)
|
|
422
385
|
if cuda_bin_path:
|
|
423
386
|
return cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs
|
|
424
387
|
|
|
@@ -428,11 +391,11 @@ def find_cuda_hardcoded() -> tuple[Optional[str], list[str], list[str]]:
|
|
|
428
391
|
@functools.lru_cache
|
|
429
392
|
def find_cuda() -> tuple[Optional[str], list[str], list[str]]:
|
|
430
393
|
for f in [
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
394
|
+
find_cuda_env,
|
|
395
|
+
find_cuda_bundled,
|
|
396
|
+
find_cuda_pip,
|
|
397
|
+
find_cuda_conda,
|
|
398
|
+
find_cuda_hardcoded,
|
|
436
399
|
]:
|
|
437
400
|
cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = f()
|
|
438
401
|
if cuda_bin_path:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: triton-windows
|
|
3
|
-
Version: 3.4.0.post20
|
|
3
|
+
Version: 3.5.0.post21
|
|
4
4
|
Summary: A language and compiler for custom Deep Learning operations
|
|
5
5
|
Home-page: https://github.com/woct0rdho/triton-windows
|
|
6
6
|
Author: Philippe Tillet, Dian Wu
|
|
@@ -10,14 +10,13 @@ Classifier: Development Status :: 4 - Beta
|
|
|
10
10
|
Classifier: Intended Audience :: Developers
|
|
11
11
|
Classifier: Topic :: Software Development :: Build Tools
|
|
12
12
|
Classifier: License :: OSI Approved :: MIT License
|
|
13
|
-
Classifier: Programming Language :: Python :: 3.9
|
|
14
13
|
Classifier: Programming Language :: Python :: 3.10
|
|
15
14
|
Classifier: Programming Language :: Python :: 3.11
|
|
16
15
|
Classifier: Programming Language :: Python :: 3.12
|
|
17
16
|
Classifier: Programming Language :: Python :: 3.13
|
|
18
|
-
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
18
|
+
Requires-Python: >=3.10,<3.15
|
|
19
19
|
License-File: LICENSE
|
|
20
|
-
Requires-Dist: setuptools>=40.8.0
|
|
21
20
|
Requires-Dist: importlib-metadata; python_version < "3.10"
|
|
22
21
|
Provides-Extra: build
|
|
23
22
|
Requires-Dist: cmake<4.0,>=3.20; extra == "build"
|