triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (217)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
triton/_C/libtriton.pyd ADDED
Binary file
triton/__init__.py ADDED
@@ -0,0 +1,82 @@
+ """isort:skip_file"""
+ __version__ = '3.5.1'
+
+ # ---------------------------------------
+ # Note: import order is significant here.
+
+ # submodules
+ from .runtime import (
+     autotune,
+     Config,
+     heuristics,
+     JITFunction,
+     KernelInterface,
+     reinterpret,
+     TensorWrapper,
+     OutOfResources,
+     InterpreterError,
+     MockTensor,
+ )
+ from .runtime.jit import constexpr_function, jit
+ from .runtime._async_compile import AsyncCompileMode, FutureKernel
+ from .compiler import compile, CompilationError
+ from .errors import TritonError
+ from .runtime._allocation import set_allocator
+
+ from . import language
+ from . import testing
+ from . import tools
+
+ must_use_result = language.core.must_use_result
+
+ __all__ = [
+     "AsyncCompileMode",
+     "autotune",
+     "cdiv",
+     "CompilationError",
+     "compile",
+     "Config",
+     "constexpr_function",
+     "FutureKernel",
+     "heuristics",
+     "InterpreterError",
+     "jit",
+     "JITFunction",
+     "KernelInterface",
+     "language",
+     "MockTensor",
+     "must_use_result",
+     "next_power_of_2",
+     "OutOfResources",
+     "reinterpret",
+     "runtime",
+     "set_allocator",
+     "TensorWrapper",
+     "TritonError",
+     "testing",
+     "tools",
+ ]
+
+ # -------------------------------------
+ # misc. utilities that don't fit well
+ # into any specific module
+ # -------------------------------------
+
+
+ @constexpr_function
+ def cdiv(x: int, y: int):
+     return (x + y - 1) // y
+
+
+ @constexpr_function
+ def next_power_of_2(n: int):
+     """Return the smallest power of 2 greater than or equal to n"""
+     n -= 1
+     n |= n >> 1
+     n |= n >> 2
+     n |= n >> 4
+     n |= n >> 8
+     n |= n >> 16
+     n |= n >> 32
+     n += 1
+     return n
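For orientation, a minimal sketch of how these two helpers are typically used from host code when computing a launch grid; the element count and block size below are illustrative, not part of the package:

    import triton

    N = 1000
    BLOCK = triton.next_power_of_2(100)  # -> 128
    grid = (triton.cdiv(N, BLOCK),)      # -> (8,): 8 blocks of 128 cover all 1000 elements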
triton/_filecheck.py ADDED
@@ -0,0 +1,97 @@
+ import functools
+ import os
+ import inspect
+ import subprocess
+ import tempfile
+
+ import triton
+ from triton.compiler import ASTSource, make_backend
+ from triton.backends.compiler import GPUTarget
+ from triton.experimental.gluon._runtime import GluonASTSource
+ from triton.runtime.jit import create_function_from_signature
+ from triton._C.libtriton import ir
+
+ # ===-----------------------------------------------------------------------===#
+ # filecheck_test
+ # ===-----------------------------------------------------------------------===#
+
+ # Stub target for testing the frontend.
+ stub_target = GPUTarget("cuda", 100, 32)
+
+ triton_dir = os.path.dirname(__file__)
+ filecheck_path = os.path.join(triton_dir, "FileCheck")
+
+
+ class MatchError(ValueError):
+
+     def __init__(self, message, module_str):
+         super().__init__(message)
+         self.module_str = module_str
+
+     def __str__(self):
+         return f"{super().__str__()}\n{self.module_str}"
+
+
+ def run_filecheck(name, module_str, check_template):
+     with tempfile.TemporaryDirectory() as tempdir:
+         temp_module = os.path.join(tempdir, "module")
+         with open(temp_module, "w") as temp:
+             temp.write(module_str)
+
+         temp_expected = os.path.join(tempdir, "expected")
+         with open(temp_expected, "w") as temp:
+             temp.write(check_template)
+
+         try:
+             subprocess.check_output(
+                 [filecheck_path, temp_expected, "--input-file", temp_module, "--dump-input-context=50"],
+                 stderr=subprocess.STDOUT)
+         except subprocess.CalledProcessError as error:
+             decoded = error.output.decode('unicode_escape')
+             raise ValueError(decoded)
+
+
+ def run_parser(kernel_fn, args=(), kwargs={}, target=stub_target):
+     if "sanitize_overflow" not in kwargs:
+         kwargs = dict(kwargs)
+         kwargs["sanitize_overflow"] = False
+     backend = make_backend(target)
+     binder = create_function_from_signature(
+         kernel_fn.signature,
+         kernel_fn.params,
+         backend,
+     )
+
+     bound_args, specialization, options = binder(*args, **kwargs)
+     options, signature, constexprs, attrs = kernel_fn._pack_args(backend, kwargs, bound_args, specialization, options)
+     source_cls = GluonASTSource if kernel_fn.is_gluon() else ASTSource
+     src = source_cls(kernel_fn, signature, constexprs, attrs)
+
+     context = ir.context()
+     ir.load_dialects(context)
+     backend.load_dialects(context)
+
+     codegen_fns = backend.get_codegen_implementation(options)
+     module_map = backend.get_module_map()
+     module = src.make_ir(target, options, codegen_fns, module_map, context)
+     assert module.verify()
+     return module
+
+
+ def run_filecheck_test(kernel_fn):
+     assert isinstance(kernel_fn, triton.runtime.JITFunction)
+     check_template = inspect.getsource(kernel_fn.fn)
+     if check_template is None:
+         raise ValueError("kernel function must have a docstring with FileCheck template")
+     mlir_module = run_parser(kernel_fn)
+
+     run_filecheck("placeholder", mlir_module.str_nodebug(), check_template)
+
+
+ def filecheck_test(fn):
+
+     @functools.wraps(fn)
+     def test_fn():
+         run_filecheck_test(fn)
+
+     return test_fn
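As a rough illustration of the intended workflow: the decorator stacks on top of a JIT kernel, parses it to MLIR via run_parser, and feeds the kernel's own source (with its CHECK comments) to the bundled FileCheck binary. A hypothetical test might look like the sketch below; the kernel body and the exact MLIR in the CHECK line are illustrative assumptions, and running it requires the FileCheck executable next to this module:

    import triton
    import triton.language as tl
    from triton._filecheck import filecheck_test

    @filecheck_test
    @triton.jit
    def test_zeros():
        # CHECK: arith.constant dense<0.000000e+00> : tensor<128xf32>
        x = tl.zeros([128], dtype=tl.float32)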
triton/_internal_testing.py ADDED
@@ -0,0 +1,255 @@
+ import os
+ import re
+ import numpy as np
+ import torch
+ import triton
+ import triton.language as tl
+ from triton import knobs
+ from typing import Optional, Set, Union
+ import pytest
+
+ from numpy.random import RandomState
+ from triton.runtime.jit import TensorWrapper, reinterpret, type_canonicalisation_dict
+
+ int_dtypes = ['int8', 'int16', 'int32', 'int64']
+ uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
+ integral_dtypes = int_dtypes + uint_dtypes
+ float_dtypes = ['float16', 'float32', 'float64']
+ float_dtypes_with_bfloat16 = float_dtypes + ['bfloat16']
+ dtypes = integral_dtypes + float_dtypes
+ dtypes_with_bfloat16 = dtypes + ['bfloat16']
+ torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2']
+ torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
+ tma_dtypes = sorted(set(dtypes_with_bfloat16) - {"int64", "uint64", "float64"})
+
+
+ def is_interpreter():
+     return os.environ.get('TRITON_INTERPRET', '0') == '1'
+
+
+ def get_current_target():
+     if is_interpreter():
+         return None
+     return triton.runtime.driver.active.get_current_target()
+
+
+ def is_cuda():
+     target = get_current_target()
+     return False if target is None else target.backend == "cuda"
+
+
+ def is_ampere_or_newer():
+     return is_cuda() and torch.cuda.get_device_capability()[0] >= 8
+
+
+ def is_blackwell():
+     return is_cuda() and torch.cuda.get_device_capability()[0] == 10
+
+
+ def is_hopper_or_newer():
+     return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
+
+
+ def is_hopper():
+     return is_cuda() and torch.cuda.get_device_capability()[0] == 9
+
+
+ def is_hip():
+     target = get_current_target()
+     return False if target is None else target.backend == "hip"
+
+
+ def is_hip_cdna2():
+     target = get_current_target()
+     return target is not None and target.backend == 'hip' and target.arch == 'gfx90a'
+
+
+ def is_hip_cdna3():
+     target = get_current_target()
+     return target is not None and target.backend == 'hip' and target.arch == 'gfx942'
+
+
+ def is_hip_cdna4():
+     target = get_current_target()
+     return target is not None and target.backend == 'hip' and target.arch == 'gfx950'
+
+
+ def is_hip_gfx11():
+     target = get_current_target()
+     return target is not None and target.backend == 'hip' and 'gfx11' in target.arch
+
+
+ def is_hip_gfx12():
+     target = get_current_target()
+     return target is not None and target.backend == 'hip' and 'gfx12' in target.arch
+
+
+ def is_hip_cdna():
+     return is_hip_cdna2() or is_hip_cdna3() or is_hip_cdna4()
+
+
+ def get_hip_lds_size():
+     return 163840 if is_hip_cdna4() else 65536
+
+
+ def is_xpu():
+     target = get_current_target()
+     return False if target is None else target.backend == "xpu"
+
+
+ def get_arch():
+     target = get_current_target()
+     return "" if target is None else str(target.arch)
+
+
+ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None):
+     """
+     Override `rs` if you're calling this function twice and don't want the same
+     result for both calls.
+     """
+     if isinstance(shape, int):
+         shape = (shape, )
+     if rs is None:
+         rs = RandomState(seed=17)
+     if dtype_str in int_dtypes + uint_dtypes:
+         iinfo = np.iinfo(getattr(np, dtype_str))
+         low = iinfo.min if low is None else max(low, iinfo.min)
+         high = iinfo.max if high is None else min(high, iinfo.max)
+         dtype = getattr(np, dtype_str)
+         x = rs.randint(low, high, shape, dtype=dtype)
+         x[x == 0] = 1  # Workaround. Never return zero so tests of division don't error out.
+         return x
+     elif dtype_str and 'float8' in dtype_str:
+         x = rs.randint(20, 40, shape, dtype=np.int8)
+         return x
+     elif dtype_str in float_dtypes:
+         return rs.normal(0, 1, shape).astype(dtype_str)
+     elif dtype_str == 'bfloat16':
+         return (rs.normal(0, 1, shape).astype('float32').view('uint32') & np.uint32(0xffff0000)).view('float32')
+     elif dtype_str in ['bool', 'int1', 'bool_']:
+         return rs.normal(0, 1, shape) > 0.0
+     else:
+         raise RuntimeError(f'Unknown dtype {dtype_str}')
+
+
+ def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torch.Tensor]:
+     '''
+     Note: We need dst_type because the type of x can be different from dst_type.
+     For example: x is of type `float32`, dst_type is `bfloat16`.
+     If dst_type is None, we infer dst_type from x.
+     '''
+     t = x.dtype.name
+     if t in uint_dtypes:
+         signed_type_name = t.lstrip('u')  # e.g. "uint16" -> "int16"
+         x_signed = x.astype(getattr(np, signed_type_name))
+         return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
+     else:
+         if dst_type and 'float8' in dst_type:
+             return reinterpret(torch.tensor(x, device=device), getattr(tl, dst_type))
+         if t == 'float32' and dst_type == 'bfloat16':
+             return torch.tensor(x, device=device).bfloat16()
+         return torch.tensor(x, device=device)
+
+
+ def str_to_triton_dtype(x: str) -> tl.dtype:
+     return tl.str_to_ty(type_canonicalisation_dict[x], None)
+
+
+ def torch_dtype_name(dtype) -> str:
+     if isinstance(dtype, triton.language.dtype):
+         return dtype.name
+     elif isinstance(dtype, torch.dtype):
+         # 'torch.int64' -> 'int64'
+         m = re.match(r'^torch\.(\w+)$', str(dtype))
+         return m.group(1)
+     else:
+         raise TypeError(f'not a triton or torch dtype: {type(dtype)}')
+
+
+ def to_numpy(x):
+     if isinstance(x, TensorWrapper):
+         return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
+     elif isinstance(x, torch.Tensor):
+         if x.dtype is torch.bfloat16:
+             return x.cpu().float().numpy()
+         return x.cpu().numpy()
+     else:
+         raise ValueError(f"Not a triton-compatible tensor: {x}")
+
+
+ def supports_tma(byval_only=False):
+     if is_interpreter():
+         return True
+     if not is_cuda():
+         return False
+     cuda_version = knobs.nvidia.ptxas.version
+     min_cuda_version = (12, 0) if byval_only else (12, 3)
+     cuda_version_tuple = tuple(map(int, cuda_version.split(".")))
+     assert len(cuda_version_tuple) == 2, cuda_version_tuple
+     return torch.cuda.get_device_capability()[0] >= 9 and cuda_version_tuple >= min_cuda_version
+
+
+ def tma_skip_msg(byval_only=False):
+     if byval_only:
+         return "Requires __grid_constant__ TMA support (NVIDIA Hopper or higher, CUDA 12.0 or higher)"
+     else:
+         return "Requires advanced TMA support (NVIDIA Hopper or higher, CUDA 12.3 or higher)"
+
+
+ requires_tma = pytest.mark.skipif(not supports_tma(), reason=tma_skip_msg())
+
+
+ def default_alloc_fn(size: int, align: int, _):
+     return torch.empty(size, dtype=torch.int8, device="cuda")
+
+
+ def unwrap_tensor(t: Union[torch.Tensor, triton.runtime.jit.TensorWrapper]) -> torch.Tensor:
+     if isinstance(t, triton.runtime.jit.TensorWrapper):
+         return t.base
+     return t
+
+
+ def _fresh_knobs_impl(skipped_attr: Optional[Set[str]] = None):
+     from triton import knobs
+
+     if skipped_attr is None:
+         skipped_attr = set()
+
+     monkeypatch = pytest.MonkeyPatch()
+
+     knobs_map = {
+         name: knobset
+         for name, knobset in knobs.__dict__.items()
+         if isinstance(knobset, knobs.base_knobs) and knobset != knobs.base_knobs and name not in skipped_attr
+     }
+
+     # We store which variables we need to unset below in finally because
+     # monkeypatch doesn't appear to reset variables that were never set
+     # before the monkeypatch.delenv call below.
+     env_to_unset = []
+     prev_propagate_env = knobs.propagate_env
+
+     def fresh_function():
+         nonlocal env_to_unset
+         for name, knobset in knobs_map.items():
+             setattr(knobs, name, knobset.copy().reset())
+             for knob in knobset.knob_descriptors.values():
+                 if knob.key in os.environ:
+                     monkeypatch.delenv(knob.key, raising=False)
+                 else:
+                     env_to_unset.append(knob.key)
+         knobs.propagate_env = True
+         return knobs
+
+     def reset_function():
+         for name, knobset in knobs_map.items():
+             setattr(knobs, name, knobset)
+         # `undo` should be placed before `del os.environ`
+         # Otherwise, it may restore environment variables that monkeypatch deleted
+         monkeypatch.undo()
+         for k in env_to_unset:
+             if k in os.environ:
+                 del os.environ[k]
+         knobs.propagate_env = prev_propagate_env
+
+     return fresh_function, reset_function
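To give a sense of how these helpers compose in a test, a minimal round-trip sketch, assuming a CUDA device is available (the shape and dtype are arbitrary):

    from triton._internal_testing import numpy_random, to_triton, to_numpy

    x = numpy_random((4, 4), 'bfloat16')                   # float32 array with bf16-truncated mantissa
    t = to_triton(x, device='cuda', dst_type='bfloat16')   # torch tensor converted to bfloat16
    y = to_numpy(t)                                        # back to a float32 numpy array on the CPU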
triton/_utils.py ADDED
@@ -0,0 +1,126 @@
+ from __future__ import annotations
+
+ from functools import reduce
+ from typing import Any, Callable, TYPE_CHECKING, Union, List, Dict
+
+ if TYPE_CHECKING:
+     from .language import core
+     IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type]
+     ObjPath = tuple[int, ...]
+
+ TRITON_MAX_TENSOR_NUMEL = 1048576
+
+
+ def get_iterable_path(iterable: IterableType, path: ObjPath) -> Any:
+     return reduce(lambda a, idx: a[idx], path, iterable)  # type: ignore[index]
+
+
+ def set_iterable_path(iterable: IterableType, path: tuple[int, ...], val: Any):
+     from .language import core
+     assert len(path) != 0
+     prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1])
+     assert isinstance(prev, core.tuple)
+     prev._setitem(path[-1], val)
+
+
+ def find_paths_if(iterable: Union[IterableType, Any], pred: Callable[[ObjPath, Any], bool]) -> list[ObjPath]:
+     from .language import core
+     is_iterable: Callable[[Any], bool] = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
+     # We need to use dict so that ordering is maintained, while set doesn't guarantee order
+     ret: dict[ObjPath, None] = {}
+
+     def _impl(path: tuple[int, ...], current: Any):
+         if is_iterable(current):
+             for idx, item in enumerate(current):
+                 _impl((*path, idx), item)
+         elif pred(path, current):
+             ret[path] = None
+
+     _impl((), iterable)
+
+     return list(ret.keys())
+
+
+ def is_power_of_two(x):
+     return (x & (x - 1)) == 0
+
+
+ def validate_block_shape(shape: List[int]):
+     numel = 1
+     for i, d in enumerate(shape):
+         if not isinstance(d, int):
+             raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]")
+         if not is_power_of_two(d):
+             raise ValueError(f"Shape element {i} must be a power of 2")
+         numel *= d
+
+     if numel > TRITON_MAX_TENSOR_NUMEL:
+         raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
+     return numel
+
+
+ type_canonicalisation_dict = {
+     # we canonicalise all bools to be unsigned:
+     "bool": "u1",
+     "int1": "u1",
+     "uint1": "u1",
+     "i1": "u1",
+     # floating-point dtypes:
+     "float8e4nv": "fp8e4nv",
+     "float8e5": "fp8e5",
+     "float8e4b15": "fp8e4b15",
+     "float8_e4m3fn": "fp8e4nv",
+     "float8e4b8": "fp8e4b8",
+     "float8_e4m3fnuz": "fp8e4b8",
+     "float8_e5m2": "fp8e5",
+     "float8e5b16": "fp8e5b16",
+     "float8_e5m2fnuz": "fp8e5b16",
+     "half": "fp16",
+     "float16": "fp16",
+     "bfloat16": "bf16",
+     "float": "fp32",
+     "float32": "fp32",
+     "double": "fp64",
+     "float64": "fp64",
+     # signed integers:
+     "int8": "i8",
+     "int16": "i16",
+     "int": "i32",
+     "int32": "i32",
+     "int64": "i64",
+     # unsigned integers:
+     "uint8": "u8",
+     "uint16": "u16",
+     "uint32": "u32",
+     "uint64": "u64",
+     "void": "void",
+ }
+
+ for v in list(type_canonicalisation_dict.values()):
+     type_canonicalisation_dict[v] = v
+
+
+ def canonicalize_dtype(dtype):
+     dtype_str = str(dtype).split(".")[-1]
+     return type_canonicalisation_dict[dtype_str]
+
+
+ BITWIDTH_DICT: Dict[str, int] = {
+     **{f"u{n}": n
+        for n in (1, 8, 16, 32, 64)},
+     **{f"i{n}": n
+        for n in (1, 8, 16, 32, 64)},
+     **{f"fp{n}": n
+        for n in (16, 32, 64)},
+     **{f"fp8{suffix}": 8
+        for suffix in ("e4nv", "e4b15", "e4b8", "e5", "e5b16")},
+     "bf16": 16,
+     "void": 0,
+ }
+
+ for k, v in type_canonicalisation_dict.items():
+     BITWIDTH_DICT[k] = BITWIDTH_DICT[v]
+
+
+ def get_primitive_bitwidth(dtype: str) -> int:
+     return BITWIDTH_DICT[dtype]
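For reference, the canonicalisation and bit-width helpers behave roughly as follows; the expected values in the comments are read off the tables above:

    from triton._utils import canonicalize_dtype, get_primitive_bitwidth, validate_block_shape

    canonicalize_dtype('bfloat16')     # -> 'bf16'
    get_primitive_bitwidth('float16')  # -> 16 (aliases resolve through the canonical names)
    validate_block_shape([64, 128])    # -> 8192; raises ValueError for non-power-of-2 dims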
triton/backends/__init__.py ADDED
@@ -0,0 +1,47 @@
+ import importlib
+ import inspect
+ import sys
+ from dataclasses import dataclass
+ from typing import Type, TypeVar, Union
+ from types import ModuleType
+ from .driver import DriverBase
+ from .compiler import BaseBackend
+
+ if sys.version_info >= (3, 10):
+     from importlib.metadata import entry_points
+ else:
+     from importlib_metadata import entry_points
+
+ T = TypeVar("T", bound=Union[BaseBackend, DriverBase])
+
+
+ def _find_concrete_subclasses(module: ModuleType, base_class: Type[T]) -> Type[T]:
+     ret: list[Type[T]] = []
+     for attr_name in dir(module):
+         attr = getattr(module, attr_name)
+         if isinstance(attr, type) and issubclass(attr, base_class) and not inspect.isabstract(attr):
+             ret.append(attr)
+     if len(ret) == 0:
+         raise RuntimeError(f"Found 0 concrete subclasses of {base_class} in {module}: {ret}")
+     if len(ret) > 1:
+         raise RuntimeError(f"Found >1 concrete subclasses of {base_class} in {module}: {ret}")
+     return ret[0]
+
+
+ @dataclass(frozen=True)
+ class Backend:
+     compiler: Type[BaseBackend]
+     driver: Type[DriverBase]
+
+
+ def _discover_backends() -> dict[str, Backend]:
+     backends = dict()
+     for ep in entry_points().select(group="triton.backends"):
+         compiler = importlib.import_module(f"{ep.value}.compiler")
+         driver = importlib.import_module(f"{ep.value}.driver")
+         backends[ep.name] = Backend(_find_concrete_subclasses(compiler, BaseBackend),  # type: ignore
+                                     _find_concrete_subclasses(driver, DriverBase))  # type: ignore
+     return backends
+
+
+ backends: dict[str, Backend] = _discover_backends()
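Backends register themselves under the "triton.backends" entry-point group and are discovered at import time. Downstream code can then look up a backend by its entry-point name; a minimal sketch, noting that which names exist depends on the backends installed with the wheel:

    from triton.backends import backends

    if "nvidia" in backends:
        nv = backends["nvidia"]
        compiler_cls = nv.compiler  # concrete BaseBackend subclass
        driver_cls = nv.driver      # concrete DriverBase subclass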
triton/backends/amd/__init__.py
File without changes