triton-windows 3.4.0.post20__cp313-cp313-win_amd64.whl → 3.5.0.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.
Files changed (107)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +8 -2
  3. triton/_filecheck.py +24 -14
  4. triton/_internal_testing.py +70 -4
  5. triton/_utils.py +3 -1
  6. triton/backends/amd/compiler.py +68 -60
  7. triton/backends/amd/driver.c +113 -44
  8. triton/backends/amd/driver.py +133 -57
  9. triton/backends/driver.py +13 -0
  10. triton/backends/nvidia/compiler.py +80 -22
  11. triton/backends/nvidia/driver.c +88 -15
  12. triton/backends/nvidia/driver.py +130 -123
  13. triton/compiler/__init__.py +5 -2
  14. triton/compiler/code_generator.py +270 -163
  15. triton/compiler/compiler.py +45 -62
  16. triton/experimental/gluon/__init__.py +3 -2
  17. triton/experimental/gluon/_runtime.py +9 -6
  18. triton/experimental/gluon/language/__init__.py +117 -16
  19. triton/experimental/gluon/language/_core.py +246 -68
  20. triton/experimental/gluon/language/_layouts.py +398 -45
  21. triton/experimental/gluon/language/_math.py +17 -9
  22. triton/experimental/gluon/language/_semantic.py +130 -37
  23. triton/experimental/gluon/language/_standard.py +55 -22
  24. triton/experimental/gluon/language/amd/__init__.py +4 -0
  25. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  26. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  27. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  28. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  29. triton/experimental/gluon/language/extra/__init__.py +3 -0
  30. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  31. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  32. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  33. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
  34. triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
  35. triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
  36. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
  37. triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
  38. triton/experimental/gluon/nvidia/hopper.py +6 -1
  39. triton/knobs.py +132 -67
  40. triton/language/__init__.py +16 -10
  41. triton/language/core.py +163 -83
  42. triton/language/extra/cuda/gdc.py +6 -6
  43. triton/language/extra/hip/__init__.py +3 -1
  44. triton/language/extra/hip/libdevice.py +7 -0
  45. triton/language/extra/hip/utils.py +35 -0
  46. triton/language/extra/libdevice.py +4 -0
  47. triton/language/semantic.py +76 -23
  48. triton/language/standard.py +14 -14
  49. triton/language/target_info.py +54 -0
  50. triton/runtime/_allocation.py +15 -3
  51. triton/runtime/_async_compile.py +55 -0
  52. triton/runtime/autotuner.py +4 -5
  53. triton/runtime/build.py +11 -9
  54. triton/runtime/cache.py +44 -1
  55. triton/runtime/driver.py +16 -41
  56. triton/runtime/interpreter.py +31 -23
  57. triton/runtime/jit.py +318 -157
  58. triton/runtime/tcc/include/_mingw.h +8 -10
  59. triton/runtime/tcc/include/assert.h +5 -0
  60. triton/runtime/tcc/include/errno.h +1 -1
  61. triton/runtime/tcc/include/float.h +21 -3
  62. triton/runtime/tcc/include/iso646.h +36 -0
  63. triton/runtime/tcc/include/limits.h +5 -0
  64. triton/runtime/tcc/include/malloc.h +2 -2
  65. triton/runtime/tcc/include/math.h +21 -261
  66. triton/runtime/tcc/include/stdalign.h +16 -0
  67. triton/runtime/tcc/include/stdarg.h +5 -70
  68. triton/runtime/tcc/include/stdatomic.h +171 -0
  69. triton/runtime/tcc/include/stddef.h +7 -19
  70. triton/runtime/tcc/include/stdlib.h +15 -4
  71. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  72. triton/runtime/tcc/include/sys/stat.h +2 -2
  73. triton/runtime/tcc/include/sys/types.h +5 -0
  74. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  75. triton/runtime/tcc/include/tccdefs.h +342 -0
  76. triton/runtime/tcc/include/tgmath.h +89 -0
  77. triton/runtime/tcc/include/uchar.h +33 -0
  78. triton/runtime/tcc/include/unistd.h +1 -0
  79. triton/runtime/tcc/include/winapi/qos.h +72 -0
  80. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  81. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  82. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  83. triton/runtime/tcc/include/winapi/windows.h +1 -1
  84. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  85. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  86. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  87. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  88. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  89. triton/runtime/tcc/lib/libtcc1.a +0 -0
  90. triton/runtime/tcc/lib/python314.def +1800 -0
  91. triton/runtime/tcc/lib/python314t.def +1809 -0
  92. triton/runtime/tcc/libtcc.dll +0 -0
  93. triton/runtime/tcc/tcc.exe +0 -0
  94. triton/tools/compile.py +62 -14
  95. triton/tools/extra/cuda/compile.c +1 -0
  96. triton/tools/extra/hip/compile.cpp +66 -0
  97. triton/tools/extra/hip/compile.h +13 -0
  98. triton/tools/ragged_tma.py +92 -0
  99. triton/tools/tensor_descriptor.py +7 -9
  100. triton/windows_utils.py +42 -79
  101. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
  102. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
  103. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  104. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
  105. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
  106. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
  107. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/tools/compile.py CHANGED
@@ -3,12 +3,29 @@ import hashlib
 import importlib.util
 import sys
 from argparse import ArgumentParser
+from dataclasses import dataclass
 from pathlib import Path
 from typing import List
 
 import triton
 import triton.backends
-from triton.backends.nvidia.driver import ty_to_cpp
+
+
+@dataclass
+class CompileArgs:
+    '''
+    A class to contain arguments from command-line parser.
+    '''
+    path: str = ''
+    kernel_name: str = ''
+    signature: str = ''
+    grid: str = ''
+    target: str | None = None
+    num_warps: int = 1
+    num_stages: int = 3
+    out_name: str | None = None
+    out_path: Path | None = None
+
 
 desc = """
 Triton ahead-of-time compiler:
@@ -36,14 +53,18 @@ NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed
 used to run this `compile.py` script
 """
 
-if __name__ == "__main__":
 
+def main():
     # command-line arguments
     parser = ArgumentParser(description=desc)
     parser.add_argument("path",
                         help="Path to Python source containing desired kernel in its scope. File will be executed.")
     parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile",
                         required=True)
+    parser.add_argument(
+        "--target", "-t", type=str, default=None,
+        help="The target to compile towards, in format of '<backend>:<arch>:<warp-size>'; "
+        "e.g., 'cuda:80:32', 'hip:gfx942:64'. Default to None, which means using current machine's GPU target")
     parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel")
     parser.add_argument("--num-stages", "-ns", type=int, default=3,
                         help="Number of stages (meta-parameter of the kernel)")
@@ -51,8 +72,12 @@ if __name__ == "__main__":
     parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename")
     parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
     parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True)
-    args = parser.parse_args()
+    cli_args = parser.parse_args()
+    args = CompileArgs(**vars(cli_args))  # A sanity check to ensure class CompileArgs is updated as well.
+    compile_kernel(args)
 
+
+def compile_kernel(args: CompileArgs):
     out_name = args.out_name if args.out_name else args.kernel_name
     out_path = args.out_path if args.out_path else Path(out_name)
 
@@ -108,10 +133,18 @@ if __name__ == "__main__":
     assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}"
     attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16}
     src = triton.compiler.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs)
-    opts = {"num_warps": args.num_warps, "num_stages": args.num_stages}
-    ccinfo = triton.compile(src, options=opts)
-    if ccinfo.metadata.global_scratch_size > 0:
+
+    target = triton.backends.compiler.GPUTarget(*args.target.split(":")) \
+        if args.target else triton.runtime.driver.active.get_current_target()
+    backend = triton.compiler.make_backend(target)
+    kwargs = {"num_warps": args.num_warps, "num_stages": args.num_stages}
+    options = backend.parse_options(kwargs)
+    ccinfo = triton.compile(src, target=target, options=options.__dict__)
+
+    if getattr(ccinfo.metadata, "global_scratch_size", 0) > 0:
         raise RuntimeError("AOT compiling kernels with global scratch requirements is not yet implemented")
+    if ccinfo.metadata.profile_scratch_size > 0:
+        raise RuntimeError("AOT compiling kernels with profile scratch requirements is not yet implemented")
 
     arg_names = []
     arg_types = []
@@ -136,8 +169,12 @@ if __name__ == "__main__":
         if hints.get((i, ), None) == 16:
             suffix += 'd'
     func_name = '_'.join([out_name, sig_hash, suffix])
-    asm = ccinfo.asm["cubin"]  # store binary data once
+    asm = ccinfo.asm[backend.binary_ext]  # store binary data once
+
     hex_ = str(binascii.hexlify(asm))[2:-1]
+
+    ty_to_cpp = triton.runtime.driver.active.map_python_to_cpp_type
+
     params = {
         "kernel_name": func_name,
         "triton_kernel_name": args.kernel_name,
@@ -145,18 +182,29 @@ if __name__ == "__main__":
         "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]),
         "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]),
         "full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]),
-        "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"]),
-        "num_args": len(arg_names_not_1) + 1,
+        "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"] + ["&profile_scratch"]),
+        "num_args": len(arg_names_not_1) + 2,  # +2 for global and profile scratch
         "kernel_docstring": doc_string,
         "shared": ccinfo.metadata.shared,
        "num_warps": args.num_warps,
-        "algo_info": '_'.join([const_sig, meta_sig]),
+        "algo_info": "_".join([const_sig, meta_sig]),
         "gridX": grid[0],
         "gridY": grid[1],
         "gridZ": grid[2],
         "_placeholder": "",
     }
-    for ext in ['h', 'c']:
-        template_path = Path(__file__).parent / "extra" / "cuda" / f"compile.{ext}"
-        with out_path.with_suffix(f".{sig_hash}_{suffix}.{ext}").open("w") as fp:
-            fp.write(Path(template_path).read_text().format(**params))
+    output_files = []
+    backend_name = target.backend
+    template_dir = Path(__file__).parent / "extra" / backend_name
+    for template_path in template_dir.glob('compile.*'):
+        ext = template_path.suffix
+        output_file = out_path.with_suffix(f".{sig_hash}_{suffix}{ext}")
+        with output_file.open("w") as fp:
+            fp.write(template_path.read_text().format(**params))
+        output_files.append(output_file)
+
+    return func_name, output_files
+
+
+if __name__ == "__main__":
+    main()
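
With the `__main__` block factored into `main()` and argument parsing captured in the `CompileArgs` dataclass, the AOT compiler can now be driven from Python as well as from the command line, and `compile_kernel` returns the generated function name plus the emitted files. A minimal sketch of the programmatic path; the kernel module, kernel name, signature, and grid below are hypothetical placeholders:

from triton.tools.compile import CompileArgs, compile_kernel

# Hypothetical kernel module and signature, purely illustrative.
args = CompileArgs(
    path="add.py",                              # file that defines add_kernel
    kernel_name="add_kernel",
    signature="*fp32:16, *fp32:16, i32, 1024",  # pointers, a scalar, a constexpr
    grid="1024,1,1",
    target="cuda:80:32",                        # '<backend>:<arch>:<warp-size>'; None = current GPU
    num_warps=4,
)
func_name, output_files = compile_kernel(args)  # writes one file per backend template
print(func_name, [f.name for f in output_files])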
triton/tools/extra/cuda/compile.c CHANGED
@@ -61,6 +61,7 @@ CUresult {kernel_name}(CUstream stream, {signature}) {{
     unsigned int gY = {gridY};
     unsigned int gZ = {gridZ};
     CUdeviceptr global_scratch = 0;
+    CUdeviceptr profile_scratch = 0;
     void *args[{num_args}] = {{ {arg_pointers} }};
     // TODO: shared memory
     if(gX * gY * gZ > 0)
triton/tools/extra/hip/compile.cpp ADDED
@@ -0,0 +1,66 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+/* clang-format off */
+#include <stdio.h>
+#include <stdint.h>
+#include <inttypes.h>
+#include <string.h>
+#include <hip/hip_runtime.h>
+
+// helpers to check for hip errors
+#define HIP_CHECK(ans) {{\
+    gpuAssert((ans), __FILE__, __LINE__);\
+}}\
+
+static inline void gpuAssert(hipError_t code, const char *file, int line) {{
+    if (code != hipSuccess) {{
+        const char *prefix = "Triton Error [HIP]: ";
+        const char *str;
+        hipDrvGetErrorString(code, &str);
+        char err[1024] = {{0}};
+        strcat(err, prefix);
+        strcat(err, str);
+        printf("%s\\n", err);
+        exit(code);
+    }}
+}}
+
+// globals
+#define HSACO_NAME {kernel_name}_hsaco
+hipModule_t {kernel_name}_mod = nullptr;
+hipFunction_t {kernel_name}_func = nullptr;
+unsigned char HSACO_NAME[{bin_size}] = {{ {bin_data} }};
+
+
+void unload_{kernel_name}(void) {{
+    HIP_CHECK(hipModuleUnload({kernel_name}_mod));
+}}
+
+
+void load_{kernel_name}() {{
+    int dev = 0;
+    void *bin = (void *)&HSACO_NAME;
+    int shared = {shared};
+    HIP_CHECK(hipModuleLoadData(&{kernel_name}_mod, bin));
+    HIP_CHECK(hipModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}"));
+}}
+
+/*
+{kernel_docstring}
+*/
+hipError_t {kernel_name}(hipStream_t stream, {signature}) {{
+    if ({kernel_name}_func == nullptr)
+        load_{kernel_name}();
+    unsigned int gX = {gridX};
+    unsigned int gY = {gridY};
+    unsigned int gZ = {gridZ};
+    hipDeviceptr_t global_scratch = 0;
+    hipDeviceptr_t profile_scratch = 0;
+    void *args[{num_args}] = {{ {arg_pointers} }};
+    // TODO: shared memory
+    if(gX * gY * gZ > 0)
+        return hipModuleLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * warpSize, 1, 1, {shared}, stream, args, nullptr);
+    else
+        return hipErrorInvalidValue;
+}}
triton/tools/extra/hip/compile.h ADDED
@@ -0,0 +1,13 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <hip/hip_runtime.h>
+#include <inttypes.h>
+#include <stdint.h>
+#include <stdio.h>
+
+void unload_{kernel_name}(void);
+void load_{kernel_name}(void);
+hipError_t{_placeholder} {kernel_name}(hipStream_t stream, {signature});
triton/tools/ragged_tma.py ADDED
@@ -0,0 +1,92 @@
+import triton
+import triton.language as tl
+from triton.tools.tensor_descriptor import TensorDescriptor
+
+# fmt: off
+
+
+def create_ragged_descriptor(T, block_shape, ragged_dim=0):
+    """
+    Given a 2- or 3-dimensional tensor T, this creates a 'ragged descriptor'
+    which behaves like a concatenation (along the first axis) of subarrays
+    of potentially unequal size.
+
+    The load_ragged and store_ragged device functions can be used to read
+    and write from subarrays T[batch_offset : batch_offset + batch_size]
+    with hardware bounds-checking preventing any sort of leakage outside
+    the subarray.
+    """
+
+    block_shape = list(block_shape)
+    tensor_shape = list(T.shape)
+    rank = len(tensor_shape)
+
+    if ragged_dim < 0:
+        ragged_dim += rank
+
+    assert 0 <= ragged_dim < rank - 1, "last dimension cannot be ragged"
+    assert rank <= 3, "read-write ragged descriptors must have at most 3 dimensions"
+
+    assert len(block_shape) == rank, "block shape must have same length as tensor shape"
+
+    max_int = 0x7fff0000
+    billion = 0x40000000  # == 2**30
+
+    assert tensor_shape[ragged_dim] <= billion, "number of rows may not exceed 2**30"
+    tensor_shape[ragged_dim] = billion
+    ragged_stride = T.stride(ragged_dim)
+
+    # we prepend an extra two dimensions and rely on the fact that pointers
+    # have 64-bit wraparound semantics:
+    tma_stride = [2**34 - ragged_stride, ragged_stride] + [T.stride(i) for i in range(rank)]
+    tma_shape = [max_int, max_int] + tensor_shape
+    box_shape = [1, 1] + block_shape
+
+    return TensorDescriptor(T, tma_shape, tma_stride, box_shape)
+
+
+@triton.jit
+def to_ragged_indices(batch_offset, batch_size, row):
+    """
+    Helper function for load_ragged and store_ragged.
+    """
+
+    billion = 0x40000000  # == 2**30
+    x = billion - batch_size + row
+    y = batch_offset + batch_size
+
+    return billion, y, x
+
+
+@triton.jit
+def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr = 0):
+    """
+    Read from a subarray T[batch_offset : batch_offset + batch_size] with
+    hardware bounds-checking, where reading outside the subarray gives zeros.
+
+    Coords should be an appropriately-sized list of integers, just like in
+    TMA.load().
+    """
+
+    tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor")
+
+    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
+    data = TMA.load([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:])
+    data = tl.reshape(data, data.shape[2:])
+    return data
+
+
+@triton.jit
+def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
+    """
+    Write to a subarray T[batch_offset : batch_offset + batch_size] with
+    hardware bounds-checking, where writes outside the subarray are masked
+    correctly.
+
+    Coords should be an appropriately-sized list of integers, just like in
+    TMA.store().
+    """
+
+    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
+    data = tl.reshape(data, [1, 1] + data.shape)
+    TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
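
Each `load_ragged`/`store_ragged` call moves one block at the given coordinates inside the selected subarray, so a batch copy needs only a descriptor pair and a jitted kernel. A hedged usage sketch; the shapes and the `copy_batch` kernel are illustrative and assume a CUDA device with TMA support:

import torch
import triton
from triton.tools.ragged_tma import create_ragged_descriptor, load_ragged, store_ragged


@triton.jit
def copy_batch(in_desc, out_desc, batch_offset, batch_size):
    # Copy one 128x64 tile of the subarray; out-of-bounds rows load as
    # zeros and are masked on store.
    tile = load_ragged(in_desc, batch_offset, batch_size, [0, 0])
    store_ragged(out_desc, batch_offset, batch_size, [0, 0], tile)


T = torch.randn(4096, 64, device="cuda")
out = torch.zeros_like(T)
in_desc = create_ragged_descriptor(T, [128, 64])
out_desc = create_ragged_descriptor(out, [128, 64])
copy_batch[(1, )](in_desc, out_desc, 100, 512)  # rows 100:612 form the ragged batch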
triton/tools/tensor_descriptor.py CHANGED
@@ -9,6 +9,7 @@ class TensorDescriptor:
     shape: List[int]
     strides: List[int]
     block_shape: List[int]
+    padding: str = "zero"
 
     def __post_init__(self):
         rank = len(self.shape)
@@ -17,20 +18,17 @@ class TensorDescriptor:
         assert rank > 0, "rank must not be zero"
         assert rank <= 5, "rank cannot be more than 5"
         ty = type(self.base)
-        type_name = f"{ty.__module__}.{ty.__name__}"
-        if type_name not in ("torch.FakeTensor", "torch.FunctionalTensor"):
+        if ty.__name__ not in ("FakeTensor", "FunctionalTensor"):
             assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
         validate_block_shape(self.block_shape)
         elem_bytes = self.base.dtype.itemsize
         for stride in self.strides[:-1]:
             assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
         assert self.strides[-1] == 1, "Last dimension must be contiguous"
+        assert self.padding == "zero" or self.padding == "nan", "Illegal value for padding"
+        if self.padding == "nan":
+            assert self.base.dtype.is_floating_point, "Padding option `nan` is only supported for floating point tensors"
 
     @staticmethod
-    def from_tensor(tensor: Any, block_shape: List[int]):
-        return TensorDescriptor(
-            tensor,
-            tensor.shape,
-            tensor.stride(),
-            block_shape,
-        )
+    def from_tensor(tensor: Any, block_shape: List[int], padding="zero"):
+        return TensorDescriptor(tensor, tensor.shape, tensor.stride(), block_shape, padding)
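
The new `padding` field travels with the descriptor, so out-of-bounds elements can read back as NaN instead of zero. A short sketch, assuming a floating-point CUDA tensor:

import torch
from triton.tools.tensor_descriptor import TensorDescriptor

x = torch.randn(256, 128, device="cuda")
# "nan" padding is rejected for integer tensors by __post_init__.
desc = TensorDescriptor.from_tensor(x, block_shape=[64, 64], padding="nan")
assert desc.padding == "nan"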
triton/windows_utils.py CHANGED
@@ -54,14 +54,11 @@ def max_version(
 
 
 def check_msvc(msvc_base_path: Path, version: str) -> bool:
-    return all(
-        x.exists()
-        for x in [
-            msvc_base_path / version / "bin" / "Hostx64" / "x64" / "cl.exe",
-            msvc_base_path / version / "include" / "vcruntime.h",
-            msvc_base_path / version / "lib" / "x64" / "vcruntime.lib",
-        ]
-    )
+    return all(x.exists() for x in [
+        msvc_base_path / version / "bin" / "Hostx64" / "x64" / "cl.exe",
+        msvc_base_path / version / "include" / "vcruntime.h",
+        msvc_base_path / version / "lib" / "x64" / "vcruntime.lib",
+    ])
 
 
 def find_msvc_env() -> tuple[Optional[Path], Optional[str]]:
@@ -72,20 +69,16 @@ def find_msvc_env() -> tuple[Optional[Path], Optional[str]]:
 
     version = os.getenv("VCToolsVersion")
     if not check_msvc(msvc_base_path, version):
-        warnings.warn(
-            f"Environment variables VCINSTALLDIR = {os.getenv('VCINSTALLDIR')}, "
-            f"VCToolsVersion = {os.getenv('VCToolsVersion')} are set, "
-            "but this MSVC installation is incomplete."
-        )
+        warnings.warn(f"Environment variables VCINSTALLDIR = {os.getenv('VCINSTALLDIR')}, "
+                      f"VCToolsVersion = {os.getenv('VCToolsVersion')} are set, "
+                      "but this MSVC installation is incomplete.")
         return None, None
 
     return msvc_base_path, version
 
 
 def find_msvc_vswhere() -> tuple[Optional[Path], Optional[str]]:
-    vswhere_path = find_in_program_files(
-        r"Microsoft Visual Studio\Installer\vswhere.exe"
-    )
+    vswhere_path = find_in_program_files(r"Microsoft Visual Studio\Installer\vswhere.exe")
     if vswhere_path is None:
         return None, None
 
@@ -111,9 +104,7 @@ def find_msvc_vswhere() -> tuple[Optional[Path], Optional[str]]:
     if not msvc_base_path.exists():
         return None, None
 
-    version = max_version(
-        os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path)
-    )
+    version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
     if version is None:
         return None, None
 
@@ -132,9 +123,7 @@ def find_msvc_envpath() -> tuple[Optional[Path], Optional[str]]:
         if not msvc_base_path.exists():
             continue
 
-        version = max_version(
-            os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path)
-        )
+        version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
         if version is None:
             continue
 
@@ -153,9 +142,7 @@ def find_msvc_hardcoded() -> tuple[Optional[Path], Optional[str]]:
     paths = sorted(paths)[::-1]
     for msvc_base_path in paths:
         msvc_base_path = Path(msvc_base_path)
-        version = max_version(
-            os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path)
-        )
+        version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
         if version is None:
             continue
         return msvc_base_path, version
@@ -188,13 +175,10 @@ def find_msvc(env_only: bool) -> tuple[Optional[str], list[str], list[str]]:
 
 
 def check_winsdk(winsdk_base_path: Path, version: str) -> bool:
-    return all(
-        x.exists()
-        for x in [
-            winsdk_base_path / "Include" / version / "ucrt" / "stdlib.h",
-            winsdk_base_path / "Lib" / version / "ucrt" / "x64" / "ucrt.lib",
-        ]
-    )
+    return all(x.exists() for x in [
+        winsdk_base_path / "Include" / version / "ucrt" / "stdlib.h",
+        winsdk_base_path / "Lib" / version / "ucrt" / "x64" / "ucrt.lib",
+    ])
 
 
 def find_winsdk_env() -> tuple[Optional[Path], Optional[str]]:
@@ -205,18 +189,16 @@ def find_winsdk_env() -> tuple[Optional[Path], Optional[str]]:
 
     version = os.getenv("WindowsSDKVersion")
     if version is None:
-        warnings.warn(
-            f"Environment variable WindowsSdkDir = {os.getenv('WindowsSdkDir')}, "
-            "but WindowsSDKVersion is not set."
-        )
+        version = os.getenv("WindowsSDKVer")
+    if version is None:
+        warnings.warn(f"Environment variable WindowsSdkDir = {winsdk_base_path}, "
+                      "but WindowsSDKVersion (or WindowsSDKVer) is not set.")
         return None, None
     version = version.rstrip("\\")
     if not check_winsdk(winsdk_base_path, version):
-        warnings.warn(
-            f"Environment variables WindowsSdkDir = {os.getenv('WindowsSdkDir')}, "
-            f"WindowsSDKVersion = {os.getenv('WindowsSDKVersion')} are set, "
-            "but this Windows SDK installation is incomplete."
-        )
+        warnings.warn(f"Environment variables WindowsSdkDir = {winsdk_base_path}, "
+                      f"WindowsSDKVersion (or WindowsSDKVer) = {version} are set, "
+                      "but this Windows SDK installation is incomplete.")
         return None, None
 
     return winsdk_base_path, version
@@ -225,9 +207,7 @@ def find_winsdk_env() -> tuple[Optional[Path], Optional[str]]:
 def find_winsdk_registry() -> tuple[Optional[Path], Optional[str]]:
     try:
         reg = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
-        key = winreg.OpenKeyEx(
-            reg, r"SOFTWARE\WOW6432Node\Microsoft\Microsoft SDKs\Windows\v10.0"
-        )
+        key = winreg.OpenKeyEx(reg, r"SOFTWARE\WOW6432Node\Microsoft\Microsoft SDKs\Windows\v10.0")
         folder = winreg.QueryValueEx(key, "InstallationFolder")[0]
         winreg.CloseKey(key)
     except OSError:
@@ -294,9 +274,7 @@ def find_winsdk(env_only: bool) -> tuple[list[str], list[str]]:
 
 
 @functools.lru_cache
-def find_msvc_winsdk(
-    env_only: bool = False,
-) -> tuple[Optional[str], list[str], list[str]]:
+def find_msvc_winsdk(env_only: bool = False, ) -> tuple[Optional[str], list[str], list[str]]:
     msvc_bin_path, msvc_inc_dirs, msvc_lib_dirs = find_msvc(env_only)
     winsdk_inc_dirs, winsdk_lib_dirs = find_winsdk(env_only)
     return (
@@ -312,9 +290,9 @@ def find_python() -> list[str]:
     if sysconfig.get_config_var("Py_GIL_DISABLED"):
         version += "t"
     for python_base_path in [
-        sys.exec_prefix,
-        sys.base_exec_prefix,
-        os.path.dirname(sys.executable),
+            sys.exec_prefix,
+            sys.base_exec_prefix,
+            os.path.dirname(sys.executable),
     ]:
         python_lib_dir = Path(python_base_path) / "libs"
         if (python_lib_dir / f"python{version}.lib").exists():
@@ -326,14 +304,11 @@ def find_python() -> list[str]:
 
 
 def check_and_find_cuda(base_path: Path) -> tuple[Optional[str], list[str], list[str]]:
     # pip
-    if all(
-        x.exists()
-        for x in [
+    if all(x.exists() for x in [
             base_path / "cuda_nvcc" / "bin" / "ptxas.exe",
             base_path / "cuda_runtime" / "include" / "cuda.h",
             base_path / "cuda_runtime" / "lib" / "x64" / "cuda.lib",
-        ]
-    ):
+    ]):
         return (
             str(base_path / "cuda_nvcc" / "bin"),
             [str(base_path / "cuda_runtime" / "include")],
@@ -341,14 +316,11 @@ def check_and_find_cuda(base_path: Path) -> tuple[Optional[str], list[str], list
         )
 
     # conda
-    if all(
-        x.exists()
-        for x in [
+    if all(x.exists() for x in [
            base_path / "bin" / "ptxas.exe",
            base_path / "include" / "cuda.h",
            base_path / "lib" / "cuda.lib",
-        ]
-    ):
+    ]):
         return (
             str(base_path / "bin"),
             [str(base_path / "include")],
@@ -356,14 +328,11 @@ def check_and_find_cuda(base_path: Path) -> tuple[Optional[str], list[str], list
         )
 
     # bundled or system-wide
-    if all(
-        x.exists()
-        for x in [
+    if all(x.exists() for x in [
            base_path / "bin" / "ptxas.exe",
            base_path / "include" / "cuda.h",
            base_path / "lib" / "x64" / "cuda.lib",
-        ]
-    ):
+    ]):
         return (
             str(base_path / "bin"),
             [str(base_path / "include")],
@@ -380,9 +349,7 @@ def find_cuda_env() -> tuple[Optional[str], list[str], list[str]]:
             continue
 
         cuda_base_path = Path(cuda_base_path)
-        cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(
-            cuda_base_path
-        )
+        cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(cuda_base_path)
        if cuda_bin_path:
            return cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs
 
@@ -390,9 +357,7 @@ def find_cuda_env() -> tuple[Optional[str], list[str], list[str]]:
 
 
 def find_cuda_bundled() -> tuple[Optional[str], list[str], list[str]]:
-    cuda_base_path = (
-        Path(sysconfig.get_paths()["platlib"]) / "triton" / "backends" / "nvidia"
-    )
+    cuda_base_path = (Path(sysconfig.get_paths()["platlib"]) / "triton" / "backends" / "nvidia")
     return check_and_find_cuda(cuda_base_path)
 
@@ -416,9 +381,7 @@ def find_cuda_hardcoded() -> tuple[Optional[str], list[str], list[str]]:
     paths = sorted(paths)[::-1]
     for cuda_base_path in paths:
         cuda_base_path = Path(cuda_base_path)
-        cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(
-            cuda_base_path
-        )
+        cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(cuda_base_path)
         if cuda_bin_path:
             return cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs
 
@@ -428,11 +391,11 @@
 @functools.lru_cache
 def find_cuda() -> tuple[Optional[str], list[str], list[str]]:
     for f in [
-        find_cuda_env,
-        find_cuda_bundled,
-        find_cuda_pip,
-        find_cuda_conda,
-        find_cuda_hardcoded,
+            find_cuda_env,
+            find_cuda_bundled,
+            find_cuda_pip,
+            find_cuda_conda,
+            find_cuda_hardcoded,
     ]:
         cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = f()
         if cuda_bin_path:
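
`find_winsdk_env` now falls back to `WindowsSDKVer` when `WindowsSDKVersion` is unset, and its warnings report the resolved values. A hedged sketch of exercising that path on Windows; the SDK location and version below are illustrative placeholders:

import os

# Illustrative values; point these at a real Windows 10/11 SDK install.
os.environ["WindowsSdkDir"] = r"C:\Program Files (x86)\Windows Kits\10"
os.environ["WindowsSDKVer"] = "10.0.22621.0"

from triton.windows_utils import find_winsdk_env

base_path, version = find_winsdk_env()  # (None, None) if the install is incomplete
print(base_path, version)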
{triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: triton-windows
-Version: 3.4.0.post20
+Version: 3.5.0.post21
 Summary: A language and compiler for custom Deep Learning operations
 Home-page: https://github.com/woct0rdho/triton-windows
 Author: Philippe Tillet, Dian Wu
@@ -10,14 +10,13 @@ Classifier: Development Status :: 4 - Beta
 Classifier: Intended Audience :: Developers
 Classifier: Topic :: Software Development :: Build Tools
 Classifier: License :: OSI Approved :: MIT License
-Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
-Requires-Python: >=3.9,<3.14
+Classifier: Programming Language :: Python :: 3.14
+Requires-Python: >=3.10,<3.15
 License-File: LICENSE
-Requires-Dist: setuptools>=40.8.0
 Requires-Dist: importlib-metadata; python_version < "3.10"
 Provides-Extra: build
 Requires-Dist: cmake<4.0,>=3.20; extra == "build"