triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (217) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1414 @@
1
+ from __future__ import annotations
2
+ import ast
3
+ import textwrap
4
+ import inspect
5
+ from typing import Tuple, List, Dict, Callable
6
+
7
+ import math
8
+ import numpy as np
9
+
10
+ import triton
11
+ import triton.language as tl
12
+ import dataclasses
13
+ from dataclasses import dataclass
14
+
15
+ from triton.language.semantic import TritonSemantic
16
+ from triton.tools.tensor_descriptor import TensorDescriptor
17
+ from .errors import InterpreterError
18
+ from functools import partial
19
+ from .._C.libtriton import interpreter as _interpreter
20
+ from .._C.libtriton import ir as _ir
21
+
22
+
23
+ @dataclass
24
+ class TensorHandle:
25
+ '''
26
+ data: numpy array
27
+ dtype: triton type, either pointer_type or scalar_type.
28
+ we don't store block_type here because the shape information is already available in the data field
29
+ attr: a dictionary of attributes
30
+ '''
31
+ data: np.array
32
+ dtype: tl.dtype
33
+ attr: Dict = dataclasses.field(default_factory=dict)
34
+
35
+ def __bool__(self):
36
+ return bool(self.data.all())
37
+
38
+ def get_element_ty(self):
39
+ dtype = self.dtype
40
+ while hasattr(dtype, "element_ty"):
41
+ dtype = dtype.element_ty
42
+ return dtype
43
+
44
+ def clone(self):
45
+ return TensorHandle(self.data.copy(), self.dtype)
46
+
47
+ def set_attr(self, key, value):
48
+ self.attr[key] = value
49
+
50
+
51
+ class BlockPointerHandle:
52
+
53
+ def __init__(self, base, shape, strides, offsets, block_shape, order):
54
+ self.base = base
55
+ self.shape = shape
56
+ self.strides = strides
57
+ self.offsets = offsets
58
+ self.block_shape = block_shape
59
+ self.order = order
60
+
61
+ def materialize_pointers(self, boundary_check):
62
+ dtype_tt = self.base.get_element_ty()
63
+ n_bytes = dtype_tt.primitive_bitwidth // 8
64
+ ptrs = np.broadcast_to(self.base.data, self.block_shape)
65
+ masks = np.ones(self.block_shape, dtype=bool)
66
+ for dim in range(len(self.block_shape)):
67
+ bcast_dims = [1] * len(self.block_shape)
68
+ bcast_dims[dim] = self.block_shape[dim]
69
+ off = (self.offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
70
+ ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64)
71
+ if dim in boundary_check:
72
+ masks = masks & (off < self.shape[dim].data) & (off >= 0)
73
+ ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
74
+ return ptrs, masks
75
+
76
+
77
+ class TensorDescHandle:
78
+
79
+ def __init__(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
80
+ block_shape: List[int], padding):
81
+ self.base = base
82
+ self.ndim = len(shape)
83
+ self.shape = shape
84
+ self.strides = strides
85
+ self.block_shape = block_shape
86
+ self.padding = padding
87
+
88
+ def validate(self):
89
+ assert self.base.data.item() % 16 == 0, "base must be 16-byte aligned"
90
+ assert len(self.strides) == self.ndim
91
+ assert len(self.block_shape) == self.ndim
92
+ assert self.ndim >= 1, "descriptor cannot be 0 dimensional"
93
+
94
+ for stride in self.strides[:-1]:
95
+ assert stride.data.item() % 16 == 0, "stride must be 16-byte aligned"
96
+ assert self.strides[-1].data.item() == 1, "last dim must be contiguous"
97
+
98
+ def materialize_pointers(self, offsets: List[TensorHandle]):
99
+ assert len(offsets) == self.ndim
100
+ scalar_ty = self.base.dtype.element_ty
101
+ itemsize = scalar_ty.primitive_bitwidth // 8
102
+ assert (offsets[-1].data * itemsize) % 16 == 0, "block offset start must be 16-byte aligned"
103
+
104
+ ptrs = np.broadcast_to(self.base.data, self.block_shape)
105
+ masks = np.ones(self.block_shape, dtype=bool)
106
+ for dim in range(len(self.block_shape)):
107
+ bcast_dims = [1] * len(self.block_shape)
108
+ bcast_dims[dim] = self.block_shape[dim]
109
+ off = (offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
110
+ ptrs = ptrs + (itemsize * off * self.strides[dim].data).astype(np.uint64)
111
+ masks = masks & (0 <= off) & (off < self.shape[dim].data)
112
+ assert ptrs.dtype == np.uint64
113
+ ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
114
+ return ptrs, masks
115
+
116
+
117
+ @dataclass(frozen=True)
118
+ class InterpreterOptions:
119
+ extern_libs: dict = None
120
+ debug: bool = False
121
+ sanitize_overflow: bool = True
122
+ arch: str = None
123
+ supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15")
124
+ deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
125
+ default_dot_input_precision: str = "tf32"
126
+ allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
127
+ max_num_imprecise_acc_default: int = 0
128
+ backend_name: str = "interpreter"
129
+
130
+
131
+ def _get_signed_np_dtype(dtype):
132
+ if dtype == np.uint8:
133
+ return np.int8
134
+ if dtype == np.uint16:
135
+ return np.int16
136
+ if dtype == np.uint32:
137
+ return np.int32
138
+ if dtype == np.uint64:
139
+ return np.int64
140
+ return dtype
141
+
142
+
143
+ def _get_np_dtype(tt_dtype):
144
+ if isinstance(tt_dtype, tl.pointer_type):
145
+ return np.dtype(np.uint64)
146
+ np_types = {
147
+ tl.int1: np.dtype(bool),
148
+ tl.float16: np.dtype(np.float16),
149
+ tl.float32: np.dtype(np.float32),
150
+ tl.float64: np.dtype(np.float64),
151
+ tl.int8: np.dtype(np.int8),
152
+ tl.uint8: np.dtype(np.uint8),
153
+ tl.int16: np.dtype(np.int16),
154
+ tl.uint16: np.dtype(np.uint16),
155
+ tl.int32: np.dtype(np.int32),
156
+ tl.uint32: np.dtype(np.uint32),
157
+ tl.int64: np.dtype(np.int64),
158
+ tl.uint64: np.dtype(np.uint64),
159
+ # bfloat16 types are stored as uint16
160
+ tl.bfloat16: np.dtype(np.uint16),
161
+ # float8 types are stored as uint8
162
+ tl.float8e5: np.dtype(np.uint8),
163
+ tl.float8e5b16: np.dtype(np.uint8),
164
+ tl.float8e4nv: np.dtype(np.uint8),
165
+ tl.float8e4b8: np.dtype(np.uint8),
166
+ tl.float8e4b15: np.dtype(np.uint8),
167
+ }
168
+ if isinstance(tt_dtype, tl.block_type):
169
+ if isinstance(tt_dtype.element_ty, tl.pointer_type):
170
+ return np.dtype(np.uint64)
171
+ return np_types[tt_dtype.element_ty]
172
+ return np_types[tt_dtype]
173
+
174
+
175
+ def _convert_float(input, input_dtype, output_dtype, rounding_mode):
176
+ input_uint_dtype = getattr(np, f"uint{input_dtype.primitive_bitwidth}")
177
+ output_unint_dtype = getattr(np, f"uint{output_dtype.primitive_bitwidth}")
178
+ input_bin = np.frombuffer(input.tobytes(), dtype=input_uint_dtype)
179
+ sign = (input_bin >> (input_dtype.primitive_bitwidth - 1)) & 0x01
180
+ input_exponent_width = input_dtype.primitive_bitwidth - input_dtype.fp_mantissa_width - 1
181
+ output_exponent_width = output_dtype.primitive_bitwidth - output_dtype.fp_mantissa_width - 1
182
+ significand = input_bin & ((1 << input_dtype.fp_mantissa_width) - 1)
183
+ bias_input = input_dtype.exponent_bias
184
+ bias_output = output_dtype.exponent_bias
185
+ exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32)
186
+ subnormal_index = exponent == 0
187
+ if np.any(subnormal_index):
188
+ # Credit to Phil: phil@openai.com
189
+ # subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (2^(m0) + 2^(m1) + ... + 2^(mn))
190
+ # where m0, m1, ..., mn are the 1-bit of the mantissa
191
+ # convert it to normal repr: ((-1.0)**sign) * (2.0**(1 + m0 - exp_bias)) * (1 + 2^(m1 - m0) + ... + 2^(mn - m0))
192
+ bit_pos = np.zeros_like(input_bin, dtype=np.int32)
193
+ # Find the most significant bit of the mantissa in the significand
194
+ for i in range(input_dtype.fp_mantissa_width):
195
+ bit_index = ((significand >> i) & 0x01)
196
+ # pos should be >= 1
197
+ bit_pos[bit_index == 1] = input_dtype.fp_mantissa_width - i
198
+ zero_significand_index = significand == 0
199
+ exponent[subnormal_index] = 1 - bit_pos[subnormal_index]
200
+ # 0 significand and subnormal should be treated as 0
201
+ exponent[zero_significand_index & subnormal_index] = bias_input - bias_output
202
+ significand[subnormal_index] = (significand[subnormal_index] << bit_pos[subnormal_index]) & (
203
+ (1 << input_dtype.fp_mantissa_width) - 1)
204
+ # Prevent overflow and underflow
205
+ exponent_output = np.maximum(0, np.minimum((exponent - bias_input + bias_output), (1 << output_exponent_width) - 1))
206
+ exponent_output = exponent_output.astype(output_unint_dtype)
207
+ sign_output = sign.astype(output_unint_dtype)
208
+ if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth: # Downcast
209
+ significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & (
210
+ (1 << output_dtype.fp_mantissa_width) - 1)
211
+ if rounding_mode == _ir.ROUNDING_MODE.RTNE: # Round to nearst even
212
+ # find the cut-off bit
213
+ cut_off = significand & (1 << (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width - 1))
214
+ significand_output = significand_output + (cut_off > 0)
215
+ significand_output = significand_output.astype(output_unint_dtype)
216
+ else: # Upcast
217
+ significand_output = (significand.astype(output_unint_dtype) <<
218
+ (output_dtype.fp_mantissa_width - input_dtype.fp_mantissa_width)) & (
219
+ (1 << output_dtype.fp_mantissa_width) - 1)
220
+ subnormal_index = exponent_output == 0
221
+ if np.any(subnormal_index): # underflow
222
+ # normal repr: ((-1.0)**sign) * (2.0**(exp - exp_bias_input)) * (1 + 2^(m0) + 2^(m1) + ... + 2^(mn))
223
+ # where m0, m1, ..., mn are the 1-bit of the mantissa
224
+ # shift = (1 - exp_bias_output) - (exp - exp_bias_input)
225
+ # convert it to subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias_output)) * (2^(-shift) + 2^(m0 - shift) + 2^(m1 - shift) + ... + 2^(mn - shift))
226
+ exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32)
227
+ non_zero_exponent_index = exponent != 0
228
+ # If the original exponent is not zero, we still need to shift the significand and consider the 1.0 part in mantissa
229
+ subnormal_index = subnormal_index & non_zero_exponent_index
230
+ shift = np.zeros_like(input_bin, dtype=np.int32)
231
+ shift[subnormal_index] = (1 - bias_output) - (exponent[subnormal_index] - bias_input)
232
+ significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | (
233
+ 1 << (output_dtype.fp_mantissa_width - shift[subnormal_index]))
234
+ output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | (
235
+ exponent_output << output_dtype.fp_mantissa_width) | significand_output
236
+ return output.reshape(input.shape)
237
+
238
+
239
+ def _erf(x):
240
+ # Numpy does not support erf
241
+ return math.erf(x)
242
+
243
+
244
+ def _umulhi_64(a, b):
245
+ # Numpy does not support 128-bit multiplication
246
+ # So we have to implement it manually
247
+ return (int(a) * int(b)) >> 64
248
+
249
+
250
+ np_erf_fp32 = np.vectorize(_erf, otypes=[np.float32])
251
+ np_erf_fp64 = np.vectorize(_erf, otypes=[np.float64])
252
+ np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64])
253
+
254
+
255
+ class ExtraFunctions:
256
+
257
+ @staticmethod
258
+ def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _semantic):
259
+ return tl.tensor(_semantic.builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty)
260
+
261
+
262
+ class InterpreterBuilder:
263
+ ir_sem_to_interpreter_sem = {
264
+ _ir.MEM_SEMANTIC.ACQUIRE: _interpreter.MEM_SEMANTIC.ACQUIRE,
265
+ _ir.MEM_SEMANTIC.RELEASE: _interpreter.MEM_SEMANTIC.RELEASE,
266
+ _ir.MEM_SEMANTIC.RELAXED: _interpreter.MEM_SEMANTIC.RELAXED,
267
+ _ir.MEM_SEMANTIC.ACQUIRE_RELEASE: _interpreter.MEM_SEMANTIC.ACQUIRE_RELEASE,
268
+ }
269
+
270
+ ir_rmw_op_to_interpreter_rmw_op = {
271
+ _ir.ATOMIC_OP.ADD: _interpreter.RMW_OP.ADD,
272
+ _ir.ATOMIC_OP.FADD: _interpreter.RMW_OP.FADD,
273
+ _ir.ATOMIC_OP.MIN: _interpreter.RMW_OP.MIN,
274
+ _ir.ATOMIC_OP.UMIN: _interpreter.RMW_OP.UMIN,
275
+ _ir.ATOMIC_OP.MAX: _interpreter.RMW_OP.MAX,
276
+ _ir.ATOMIC_OP.UMAX: _interpreter.RMW_OP.UMAX,
277
+ _ir.ATOMIC_OP.AND: _interpreter.RMW_OP.AND,
278
+ _ir.ATOMIC_OP.OR: _interpreter.RMW_OP.OR,
279
+ _ir.ATOMIC_OP.XOR: _interpreter.RMW_OP.XOR,
280
+ _ir.ATOMIC_OP.XCHG: _interpreter.RMW_OP.XCHG,
281
+ }
282
+
283
+ def __init__(self) -> None:
284
+ self.arch = None
285
+ self.options = InterpreterOptions()
286
+ self.codegen_fns = {}
287
+ self.codegen_fns["convert_custom_types"] = ExtraFunctions._convert_custom_types
288
+ self.codegen_fns["min_dot_size"] = lambda lhsType, rhsType: (1, 1, 1)
289
+
290
+ def set_grid_idx(self, x, y, z):
291
+ if not x < self.grid_dim[0]:
292
+ raise ValueError("x >= grid_dim[0]")
293
+ if not y < self.grid_dim[1]:
294
+ raise ValueError("y >= grid_dim[1]")
295
+ if not z < self.grid_dim[2]:
296
+ raise ValueError("z >= grid_dim[2]")
297
+ self.grid_idx = (x, y, z)
298
+
299
+ def set_grid_dim(self, nx, ny, nz):
300
+ self.grid_dim = (nx, ny, nz)
301
+
302
+ # constants
303
+
304
+ def get_half_ty(self):
305
+ return tl.float16
306
+
307
+ def get_bf16_ty(self):
308
+ return tl.bfloat16
309
+
310
+ def get_float_ty(self):
311
+ return tl.float32
312
+
313
+ def get_double_ty(self):
314
+ return tl.float64
315
+
316
+ def get_int1_ty(self):
317
+ return tl.int1
318
+
319
+ def get_int8_ty(self):
320
+ return tl.int8
321
+
322
+ def get_uint8_ty(self):
323
+ return tl.uint8
324
+
325
+ def get_int16_ty(self):
326
+ return tl.int16
327
+
328
+ def get_uint16_ty(self):
329
+ return tl.uint16
330
+
331
+ def get_int32_ty(self):
332
+ return tl.int32
333
+
334
+ def get_uint32_ty(self):
335
+ return tl.uint32
336
+
337
+ def get_int64_ty(self):
338
+ return tl.int64
339
+
340
+ def get_uint64_ty(self):
341
+ return tl.uint64
342
+
343
+ def get_fp8e4nv_ty(self):
344
+ return tl.float8e4nv
345
+
346
+ def get_fp8e4b15_ty(self):
347
+ return tl.float8e4b15
348
+
349
+ def get_fp8e4b8_ty(self):
350
+ return tl.float8e4b8
351
+
352
+ def get_fp8e5_ty(self):
353
+ return tl.float8e5
354
+
355
+ def get_fp8e5b16_ty(self):
356
+ return tl.float8e5b16
357
+
358
+ def get_ptr_ty(self, elt_ty, addr_space):
359
+ return tl.pointer_type(elt_ty, addr_space)
360
+
361
+ def get_block_ty(self, dtype, shape):
362
+ return tl.block_type(dtype, shape)
363
+
364
+ def get_int1(self, value):
365
+ return TensorHandle(np.array([value], dtype=np.bool_), tl.int1)
366
+
367
+ def get_uint8(self, value):
368
+ return TensorHandle(np.array([value], dtype=np.uint8), tl.uint8)
369
+
370
+ def get_int8(self, value):
371
+ return TensorHandle(np.array([value], dtype=np.int8), tl.int8)
372
+
373
+ def get_uint16(self, value):
374
+ return TensorHandle(np.array([value], dtype=np.uint16), tl.uint16)
375
+
376
+ def get_int16(self, value):
377
+ return TensorHandle(np.array([value], dtype=np.int16), tl.int16)
378
+
379
+ def get_uint32(self, value):
380
+ return TensorHandle(np.array([value], dtype=np.uint32), tl.uint32)
381
+
382
+ def get_int32(self, value):
383
+ return TensorHandle(np.array([value], dtype=np.int32), tl.int32)
384
+
385
+ def get_uint64(self, value):
386
+ return TensorHandle(np.array([value], dtype=np.uint64), tl.uint64)
387
+
388
+ def get_int64(self, value):
389
+ return TensorHandle(np.array([value], dtype=np.int64), tl.int64)
390
+
391
+ def get_fp16(self, value):
392
+ return TensorHandle(np.array([value], dtype=np.float16), tl.float16)
393
+
394
+ def get_fp32(self, value):
395
+ return TensorHandle(np.array([value], dtype=np.float32), tl.float32)
396
+
397
+ def get_fp64(self, value):
398
+ return TensorHandle(np.array([value], dtype=np.float64), tl.float64)
399
+
400
+ def get_null_value(self, type):
401
+ return TensorHandle(np.array([0], dtype=_get_np_dtype(type)), type)
402
+
403
+ # programming model
404
+ def create_get_program_id(self, axis):
405
+ if self.grid_idx is None:
406
+ raise ValueError("grid_idx is None")
407
+ return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32)
408
+
409
+ def create_get_num_programs(self, axis):
410
+ return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32)
411
+
412
+ # memory ops
413
+ def create_load(self, ptr, _0, _1, is_volatile):
414
+ mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
415
+ other = None
416
+ return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile)
417
+
418
+ def create_store(self, ptr, val, _0, _1):
419
+ mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
420
+ return self.create_masked_store(ptr, val, mask, None, None)
421
+
422
+ def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile):
423
+ dtype_tt = ptrs.get_element_ty()
424
+ dtype_np = _get_np_dtype(dtype_tt)
425
+ if other is None:
426
+ other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt)
427
+ ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np)
428
+ return TensorHandle(ret, dtype_tt)
429
+
430
+ def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy):
431
+ return _interpreter.store(ptrs.data, value.data, mask.data)
432
+
433
+ # casting ops
434
+ def cast_impl(self, src, dst_type):
435
+ src_element_type = src.dtype.scalar
436
+ dst_element_type = dst_type.scalar
437
+ if (src_element_type == tl.bfloat16 and dst_element_type == tl.float32) or \
438
+ (src_element_type == tl.float32 and dst_element_type == tl.bfloat16):
439
+ data = _convert_float(src.data, src_element_type, dst_element_type, None).view(_get_np_dtype(dst_type))
440
+ return TensorHandle(data, dst_type.scalar)
441
+ else:
442
+ return TensorHandle(src.data.astype(_get_np_dtype(dst_type)), dst_type.scalar)
443
+
444
+ create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
445
+ create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
446
+ create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type)
447
+ create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type)
448
+ create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type)
449
+ create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type)
450
+ create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type)
451
+
452
+ def create_fp_to_fp(self, src, dst_type, rounding_mode):
453
+ src_element_type = src.dtype.scalar
454
+ dst_element_type = dst_type.scalar
455
+ data = _convert_float(src.data, src_element_type, dst_element_type, rounding_mode).view(_get_np_dtype(dst_type))
456
+ return TensorHandle(data, dst_type.scalar)
457
+
458
+ def create_bitcast(self, src, dst_type):
459
+ return TensorHandle(src.data.view(_get_np_dtype(dst_type)), dst_type.scalar)
460
+
461
+ # binary operators
462
+ def binary_op(self, lhs, rhs, op):
463
+ return TensorHandle(op(lhs.data, rhs.data), lhs.dtype.scalar)
464
+
465
+ create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
466
+ create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
467
+ create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
468
+ create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod)
469
+ create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
470
+ create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
471
+ create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
472
+ create_sdiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs)
473
+ create_udiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs)
474
+ # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
475
+ create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod)
476
+ create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod)
477
+ create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
478
+ create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
479
+ create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift)
480
+ create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift)
481
+ create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
482
+ create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
483
+ create_minimumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
484
+ create_minnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
485
+ create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
486
+ create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
487
+ create_maximumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
488
+ create_maxnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
489
+ create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
490
+ create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
491
+ create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
492
+ create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
493
+ create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
494
+ create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
495
+ create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
496
+ create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
497
+ create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
498
+ create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
499
+ create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
500
+ create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
501
+ create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
502
+ create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
503
+ create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
504
+ create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
505
+ create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
506
+ create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
507
+ create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
508
+ create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
509
+ create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
510
+ create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
511
+ create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and)
512
+ create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor)
513
+ create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or)
514
+ create_int_to_ptr = create_bitcast
515
+ create_ptr_to_int = create_bitcast
516
+
517
+ def create_idiv(self, lhs, rhs):
518
+ # Triton has IEEE, not numpy/torch, semantics for %, and those carry
519
+ # through to //, so we have to use a nonstandard expression to get a
520
+ # reference result for //.
521
+ return TensorHandle((lhs.data - np.fmod(lhs.data, rhs.data)) // rhs.data, lhs.dtype.scalar)
522
+
523
+ def create_ashr(self, lhs, rhs):
524
+ # Triton's rshift operator depends on the signedness of the left operand
525
+ lhs_dtype = _get_signed_np_dtype(lhs.data.dtype)
526
+ rhs_dtype = _get_signed_np_dtype(rhs.data.dtype)
527
+ lhs.data = lhs.data.astype(lhs_dtype)
528
+ rhs.data = rhs.data.astype(rhs_dtype)
529
+ return self.binary_op(lhs, rhs, np.right_shift)
530
+
531
+ def create_umulhi(self, lhs, rhs):
532
+ dtype = lhs.data.dtype
533
+ if dtype == np.int64 or dtype == np.uint64:
534
+ return TensorHandle(np_umulhi_u64(lhs.data, rhs.data), lhs.dtype.scalar)
535
+ else:
536
+ compute_dtype = getattr(np, f"uint{dtype.itemsize * 8 * 2}")
537
+ lhs_data = lhs.data.astype(compute_dtype)
538
+ rhs_data = rhs.data.astype(compute_dtype)
539
+ ret_data = np.multiply(lhs_data, rhs_data) >> (dtype.itemsize * 8)
540
+ return TensorHandle(ret_data.astype(dtype), lhs.dtype.scalar)
541
+
542
+ # ternary functions
543
+ def ternary_op(self, lhs, rhs, other, op):
544
+ return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype.scalar)
545
+
546
+ create_clampf = lambda self, arg, lo, hi, propagate_nans: self.ternary_op(arg, lo, hi, np.clip)
547
+ create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where)
548
+
549
+ def create_fma(self, x, y, z):
550
+ return TensorHandle(x.data * y.data + z.data, z.dtype.scalar)
551
+
552
+ # unary functions
553
+ def unary_op(self, arg, op):
554
+ return TensorHandle(op(arg.data), arg.dtype.scalar)
555
+
556
+ def create_fabs(self, arg):
557
+ # Mask out the sign bit based on the primitive length
558
+ dtype_tt = arg.dtype
559
+ mask_bitwidth = dtype_tt.primitive_bitwidth - 1
560
+ np_uint_dtype = getattr(np, f"uint{dtype_tt.primitive_bitwidth}")
561
+ data = arg.data.view(np_uint_dtype)
562
+ mask = (1 << mask_bitwidth) - 1
563
+ ret = (data & mask).view(_get_np_dtype(dtype_tt))
564
+ return TensorHandle(ret, arg.dtype.scalar)
565
+
566
+ create_cos = lambda self, arg: self.unary_op(arg, np.cos)
567
+ create_exp = lambda self, arg: self.unary_op(arg, np.exp)
568
+ create_exp2 = lambda self, arg: self.unary_op(arg, np.exp2)
569
+ create_iabs = lambda self, arg: self.unary_op(arg, np.abs)
570
+ create_floor = lambda self, arg: self.unary_op(arg, np.floor)
571
+ create_ceil = lambda self, arg: self.unary_op(arg, np.ceil)
572
+ create_log = lambda self, arg: self.unary_op(arg, np.log)
573
+ create_log2 = lambda self, arg: self.unary_op(arg, np.log2)
574
+ create_precise_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt)
575
+ create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt)
576
+ create_sin = lambda self, arg: self.unary_op(arg, np.sin)
577
+
578
+ def create_erf(self, arg):
579
+ ret = np_erf_fp32(arg.data) if arg.data.dtype == np.float32 else np_erf_fp64(arg.data)
580
+ return TensorHandle(ret, arg.dtype.scalar)
581
+
582
+ def create_rsqrt(self, arg):
583
+ return TensorHandle(1 / np.sqrt(arg.data), arg.dtype.scalar)
584
+
585
+ # tensor operators
586
+ create_reshape = lambda self, arg, shape, allow_reorder: TensorHandle(arg.data.reshape(shape), arg.dtype.scalar)
587
+
588
+ def create_trans(self, arg, perm):
589
+ return TensorHandle(np.transpose(arg.data, perm), arg.dtype.scalar)
590
+
591
+ def create_dot(self, a, b, d, input_precision, max_num_imprecise_acc):
592
+ a_data = a.data
593
+ b_data = b.data
594
+ if (a.dtype.primitive_bitwidth == 8 and a.dtype.is_floating()) or \
595
+ (b.dtype.primitive_bitwidth == 8 and b.dtype.is_floating()):
596
+ a_data = _convert_float(a_data, a.dtype, tl.float16, None).view(np.float16)
597
+ b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16)
598
+ return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar)
599
+
600
+ def create_make_range(self, ret_ty, start, stop):
601
+ return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32)
602
+
603
+ def create_histogram(self, data, bins, mask):
604
+ if mask is None:
605
+ mask = TensorHandle(np.ones_like(data.data, dtype=bool), tl.int1)
606
+ # force all masked elements to zero
607
+ data = np.where(mask.data, data.data, np.zeros_like(data.data))
608
+ histogram = np.histogram(data, bins=bins, range=(0, bins))[0]
609
+ # remove overcounted elements
610
+ histogram[0] -= np.logical_not(mask.data).sum()
611
+ return TensorHandle(histogram, tl.int32)
612
+
613
+ def create_gather(self, src, indices, axis):
614
+ return TensorHandle(np.take_along_axis(src.data, indices.data, axis=axis), src.dtype.scalar)
615
+
616
+ # pointer arithmetic
617
+
618
+ def create_addptr(self, ptr, offset):
619
+ dtype_tt = ptr.get_element_ty()
620
+ element_bitwidth = dtype_tt.primitive_bitwidth
621
+ # int1's bitwidth is 1, but we need to use 8 for pointer arithmetic
622
+ element_bytewidth = max(1, element_bitwidth // 8)
623
+ return TensorHandle(ptr.data + element_bytewidth * offset.data.astype(np.uint64), ptr.dtype)
624
+
625
+ def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy,
626
+ is_volatile):
627
+ ptrs, masks = ptr.materialize_pointers(boundary_check)
628
+ dtype_tt = ptrs.get_element_ty()
629
+ dtype_np = _get_np_dtype(dtype_tt)
630
+ if padding_option is None:
631
+ other = None
632
+ elif padding_option == _ir.PADDING_OPTION.PAD_ZERO:
633
+ other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt)
634
+ elif padding_option == _ir.PADDING_OPTION.PAD_NAN:
635
+ other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt)
636
+ else:
637
+ raise ValueError(f"unsupported padding option {padding_option}")
638
+ return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile)
639
+
640
+ def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy):
641
+ ptrs, masks = ptr.materialize_pointers(boundary_check)
642
+ return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy)
643
+
644
+ def create_expand_dims(self, arg, axis):
645
+ return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype.scalar)
646
+
647
+ def create_broadcast(self, arg, shape):
648
+ return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype.scalar)
649
+
650
+ def create_cat(self, lhs, rhs):
651
+ return TensorHandle(np.concatenate([lhs.data, rhs.data]), lhs.dtype.scalar)
652
+
653
+ def create_join(self, lhs, rhs):
654
+ # Triton only supports joining two original tensors into a new one along the last axis
655
+ return TensorHandle(np.stack([lhs.data, rhs.data], axis=-1), lhs.dtype.scalar)
656
+
657
+ def create_split(self, val):
658
+ # Triton only supports splitting the original tensor into two along the last axis
659
+ return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar))
660
+
661
+ def create_splat(self, ret_ty, arg):
662
+ shape = ret_ty.shape
663
+ if isinstance(arg.dtype, tl.block_type):
664
+ return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
665
+ else: # scalar
666
+ return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
667
+
668
+ def create_unsplat(self, arg):
669
+ return TensorHandle(np.full((1, ), arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
670
+
671
+ def create_atomic_cas(self, ptr, cmp, val, sem, scope):
672
+ if sem not in self.ir_sem_to_interpreter_sem:
673
+ raise ValueError(f"unsupported semantic {sem}")
674
+ sem = self.ir_sem_to_interpreter_sem[sem]
675
+ return TensorHandle(_interpreter.atomic_cas(ptr.data, cmp.data, val.data, sem), cmp.dtype.scalar)
676
+
677
+ def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem, scope):
678
+ if rmwOp not in self.ir_rmw_op_to_interpreter_rmw_op:
679
+ raise ValueError(f"unsupported rmwOp {rmwOp}")
680
+ if sem not in self.ir_sem_to_interpreter_sem:
681
+ raise ValueError(f"unsupported semantic {sem}")
682
+ rmwOp = self.ir_rmw_op_to_interpreter_rmw_op[rmwOp]
683
+ sem = self.ir_sem_to_interpreter_sem[sem]
684
+ return TensorHandle(_interpreter.atomic_rmw(rmwOp, ptr.data, val.data, mask.data, sem), val.dtype.scalar)
685
+
686
+ def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure):
687
+ raise NotImplementedError("extern_elementwise not supported in interpreter mode")
688
+
689
+ def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack):
690
+ raise NotImplementedError("inline_asm not supported in interpreter mode")
691
+
692
+ def create_print(self, prefix, hex, values, isSigned):
693
+ # NOTE: the `isSigned` variable is not really used here; because Signness is already known
694
+ # by `values` themselves in python interpreter, thus not really needed here;
695
+ # it is only used for triton PrintOpToLLVM to correctly construct the format specifier.
696
+ # Interpreter's device_print function has a different format than Triton's device_print
697
+ msg = f"({self.grid_idx[0]}, {self.grid_idx[1]}, {self.grid_idx[2]})"
698
+ if prefix:
699
+ msg += f" {prefix}"
700
+ if hex:
701
+ np.set_printoptions(formatter={'all': lambda x: f"0x{x:02x}"})
702
+ for value in values:
703
+ print(msg + f" {value.data}")
704
+ if hex:
705
+ np.set_printoptions(formatter=None)
706
+
707
+ def create_assert(self, condition, message):
708
+ # Interpreter's device_assert function has a different format than Triton's device_assert
709
+ assert condition, f"{message}"
710
+
711
+ def create_assume(self, condition):
712
+ assert condition, "Assume failed"
713
+
714
+ def create_barrier(self):
715
+ # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter
716
+ pass
717
+
718
+ def create_make_block_ptr(self, base, shape, strides, offsets, block_shape, order):
719
+ # Create new offsets to avoid modifying the original
720
+ new_offsets = [offset.clone() for offset in offsets]
721
+ return BlockPointerHandle(base, shape, strides, new_offsets, block_shape, order)
722
+
723
+ def create_advance(self, ptr, offsets):
724
+ if len(ptr.offsets) != len(offsets):
725
+ raise ValueError("len(ptr.offsets) != len(offsets)")
726
+ # Create new offsets to avoid modifying the original
727
+ new_offsets = [offset.clone() for offset in ptr.offsets]
728
+ ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.block_shape, ptr.order)
729
+ for i in range(len(offsets)):
730
+ ret.offsets[i].data += offsets[i].data
731
+ return ret
732
+
733
+ def create_make_tensor_descriptor(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
734
+ tensor_shape: List[int], is_signed: bool, padding: str = "zero"):
735
+ desc = TensorDescHandle(base, shape, strides, tensor_shape, padding)
736
+ desc.validate()
737
+ return desc
738
+
739
+ def create_descriptor_load(self, desc: TensorDescHandle, indices: List[TensorHandle], cache_modifier,
740
+ eviction_policy):
741
+ assert isinstance(desc, TensorDescHandle)
742
+ ptrs, mask = desc.materialize_pointers(indices)
743
+ dtype_tt = ptrs.get_element_ty()
744
+ dtype_np = _get_np_dtype(dtype_tt)
745
+ padding = desc.padding
746
+ if padding == _ir.PADDING_OPTION.PAD_ZERO:
747
+ other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt)
748
+ elif padding == _ir.PADDING_OPTION.PAD_NAN:
749
+ other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt)
750
+ else:
751
+ raise ValueError(f"unsupported padding {padding}")
752
+ return self.create_masked_load(ptrs, mask, other, cache_modifier=cache_modifier,
753
+ eviction_policy=eviction_policy, is_volatile=False)
754
+
755
+ def create_descriptor_store(self, desc: TensorDescHandle, value: TensorHandle, indices: List[TensorHandle]):
756
+ ptrs, mask = desc.materialize_pointers(indices)
757
+ return self.create_masked_store(ptrs, value, mask, None, None)
758
+
759
+ def create_descriptor_gather(self, desc: TensorDescHandle, x_offsets: TensorHandle, y_offset: TensorHandle, type):
760
+ dtype = desc.base.dtype.element_ty
761
+ np_dtype = _get_np_dtype(dtype)
762
+ result = np.zeros([x_offsets.data.shape[0], desc.block_shape[-1]], dtype=np_dtype)
763
+ cache_modifier = None
764
+ eviction_policy = None
765
+ for i, x_offset in enumerate(x_offsets.data):
766
+ indices = [TensorHandle(x_offset, tl.int32), y_offset]
767
+ result[i, :] = self.create_descriptor_load(desc, indices, cache_modifier, eviction_policy).data
768
+ return TensorHandle(result, dtype)
769
+
770
+ def create_descriptor_scatter(self, desc: TensorDescHandle, value: TensorHandle, x_offsets: TensorHandle,
771
+ y_offset: TensorHandle):
772
+ for i, x_offset in enumerate(x_offsets.data):
773
+ slice = TensorHandle(value.data[i], value.dtype)
774
+ indices = [TensorHandle(x_offset, tl.int32), y_offset]
775
+ self.create_descriptor_store(desc, slice, indices)
776
+
777
+ def get_all_ones_value(self, type):
778
+ np_type = _get_np_dtype(type)
779
+ if "int" in np_type.name:
780
+ return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar)
781
+ elif np_type == np.bool_:
782
+ return TensorHandle(np.full(1, True, dtype=np_type), type.scalar)
783
+ else:
784
+ raise TypeError(f"unsupported type {type}")
785
+
786
+
787
+ def _patch_attr(obj, name, member, builder):
788
+ semantic = TritonSemantic(builder)
789
+ new_member = lambda *args, member=member, **kwargs: (member(*args, **
790
+ {k: v
791
+ for k, v in kwargs.items()
792
+ if k != "_semantic"}, _semantic=semantic))
793
+ setattr(obj, name, new_member)
794
+
795
+
796
+ def _patch_builtin(pkg, builder):
797
+ for name, member in inspect.getmembers(pkg):
798
+ if tl.core.is_builtin(member):
799
+ _patch_attr(pkg, name, member, builder)
800
+
801
+
802
+ def _patch_lang_tensor(tensor):
803
+
804
+ def _get_bool(self):
805
+ data = self.handle.data
806
+ # in triton, only scalars can be converted to booleans
807
+ # here we need this hack because all scalars are tensors
808
+ return bool(data) if data.size == 1 else True
809
+
810
+ def _get_transpose(self):
811
+ handle = TensorHandle(np.transpose(self.handle.data), self.handle.dtype)
812
+ assert self.type.is_block()
813
+ block_shape = list(self.type.shape)
814
+ block_shape[-1], block_shape[-2] = block_shape[-2], block_shape[-1]
815
+ res_ty = tl.core.block_type(self.dtype, block_shape)
816
+ return tl.core.tensor(handle, res_ty)
817
+
818
+ tensor.__index__ = lambda self: int(self.handle.data)
819
+ tensor.__bool__ = lambda self: _get_bool(self)
820
+ tensor.__repr__ = lambda self: repr(self.handle.data)
821
+ tensor.__str__ = lambda self: str(self.handle.data)
822
+ tensor.T = property(_get_transpose)
823
+
824
+
825
+ class ReduceScanOpInterface:
826
+
827
+ def __init__(self, axis, combine_fn):
828
+ self.axis = axis
829
+ self.combine_fn = combine_fn
830
+
831
+ def check_axis(self, shape, axis):
832
+ if axis is not None and axis >= len(shape):
833
+ raise ValueError(f"axis {axis} out of bounds for shape {shape}")
834
+
835
+ def check_tensor(self, input):
836
+ for arg in input:
837
+ if not isinstance(arg, tl.core.tensor):
838
+ raise ValueError(f"input must be a tensor, got {type(arg)}")
839
+ self.check_axis(arg.shape, self.axis)
840
+
841
+ def to_tensor(self, ret, dtype):
842
+ np_dtype = _get_np_dtype(dtype)
843
+ if hasattr(ret, "shape") and ret.shape:
844
+ ret = ret.astype(np_dtype)
845
+ ret_type = tl.block_type(dtype, list(ret.shape))
846
+ else:
847
+ ret = np.array([ret], dtype=np_dtype)
848
+ ret_type = dtype
849
+ return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type)
850
+
851
+ def apply(self, input):
852
+ if not isinstance(input, tuple):
853
+ return self.apply((input, ))[0]
854
+ self.check_tensor(input)
855
+ ret = self.apply_impl(input)
856
+ return tuple(ret) if isinstance(ret, (list, tuple)) else (ret, )
857
+
858
+
859
+ class ReduceOps(ReduceScanOpInterface):
860
+
861
+ def __init__(self, axis, combine_fn, keep_dims):
862
+ super().__init__(axis, combine_fn)
863
+ self.keep_dims = keep_dims
864
+
865
+ def unravel(self, input, axis):
866
+ ret = []
867
+ for data in input:
868
+ if axis is not None:
869
+ ret.append(data)
870
+ else:
871
+ axis = 0
872
+ ret.append(self.to_tensor(data.handle.data.flatten(), data.dtype))
873
+ return tuple(ret), axis
874
+
875
+ def generic_reduce(self, input):
876
+ original_axis = self.axis
877
+ input, axis = self.unravel(input, self.axis)
878
+ input_data = []
879
+ output_data = []
880
+ input_shape = input[0].handle.data.shape
881
+ output_shape = input_shape[0:axis] + input_shape[axis + 1:]
882
+ for arg in input:
883
+ input_data.append(arg.handle.data)
884
+ output_data.append(np.zeros(output_shape, dtype=arg.handle.data.dtype))
885
+ # Reduce on axis
886
+ for i in range(input_data[0].size):
887
+ # Recover input_index from i using input_shape
888
+ input_index = np.unravel_index(i, input_shape)
889
+ output_index = input_index[0:axis] + input_index[axis + 1:]
890
+ input_tuple = tuple(self.to_tensor(d[input_index], input[ii].dtype) for ii, d in enumerate(input_data))
891
+ if input_index[axis] == 0:
892
+ # First element
893
+ for j in range(len(output_data)):
894
+ output_data[j][output_index] = input_tuple[j].handle.data.item()
895
+ else:
896
+ acc_tuple = tuple(self.to_tensor(o[output_index], input[oi].dtype) for oi, o in enumerate(output_data))
897
+ combine_fn_ret = self.combine_fn.fn(*acc_tuple, *input_tuple)
898
+ acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret
899
+ for j in range(len(output_data)):
900
+ output_data[j][output_index] = acc_tuple[j].handle.data.item() if isinstance(
901
+ acc_tuple[j], tl.core.tensor) else acc_tuple[j]
902
+ # Pack output
903
+ ret = []
904
+ for i, data in enumerate(output_data):
905
+ if self.keep_dims:
906
+ if original_axis is not None:
907
+ data = np.expand_dims(data, axis)
908
+ else:
909
+ for _ in range(len(input_shape)):
910
+ data = np.expand_dims(data, 0)
911
+
912
+ elif original_axis is None:
913
+ # Take a scalar
914
+ data = data.item()
915
+ ret.append(self.to_tensor(data, input[i].dtype))
916
+ return ret
917
+
918
+ def min_max(self, input, val_reduce_op, idx_reduce_op=None):
919
+ # If input is a tuple, it must be (val, index), and we only take val
920
+ input = input[0] if isinstance(input, tuple) else input
921
+ val = None
922
+ idx = None
923
+ if val_reduce_op:
924
+ val = self.to_tensor(val_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype)
925
+ if idx_reduce_op:
926
+ idx = self.to_tensor(idx_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), tl.int32)
927
+ if val is not None and idx is not None:
928
+ return val, idx
929
+ elif val is not None:
930
+ return val
931
+ elif idx is not None:
932
+ return idx
933
+ else:
934
+ raise ValueError("val_reduce_op and idx_reduce_op are both None")
935
+
936
+ def sum(self, input):
937
+ return self.to_tensor(np.sum(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype)
938
+
939
+ def apply_impl(self, input):
940
+ if self.combine_fn == tl.standard._argmin_combine_tie_break_left:
941
+ return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=np.argmin)
942
+ elif self.combine_fn == tl.standard._argmax_combine_tie_break_left:
943
+ return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax)
944
+ elif self.combine_fn == tl.standard._elementwise_max:
945
+ return self.min_max(input[0], val_reduce_op=np.nanmax, idx_reduce_op=None)
946
+ elif self.combine_fn == tl.standard._elementwise_min:
947
+ return self.min_max(input[0], val_reduce_op=np.nanmin, idx_reduce_op=None)
948
+ elif self.combine_fn == tl.standard._sum_combine:
949
+ return self.sum(input[0])
950
+ else:
951
+ # Fall back to the slow mode
952
+ return self.generic_reduce(input)
953
+
954
+
955
+ class ScanOps(ReduceScanOpInterface):
+
+     def __init__(self, axis, combine_fn, reverse):
+         super().__init__(axis, combine_fn)
+         self.reverse = reverse
+
+     def cumsum(self, input):
+         return [self.to_tensor(np.cumsum(input.handle.data, axis=self.axis), dtype=input.dtype)]
+
+     def cumprod(self, input):
+         return [self.to_tensor(np.cumprod(input.handle.data, axis=self.axis), dtype=input.dtype)]
+
+     def generic_scan(self, input):
+         input_data = []
+         output_data = []
+         shape = input[0].handle.data.shape
+         for arg in input:
+             input_data.append(arg.handle.data)
+             output_data.append(np.zeros(shape, dtype=arg.handle.data.dtype))
+         # Scan on axis
+         for i in range(input_data[0].size):
+             # Recover index from i using shape
+             index = np.unravel_index(i, shape)
+             data = tuple(self.to_tensor(d[index], input[ii].dtype) for ii, d in enumerate(input_data))
+             if index[self.axis] == 0:
+                 # First element
+                 for j in range(len(output_data)):
+                     output_data[j][index] = data[j].handle.data.item()
+             else:
+                 prev_index = tuple(index[i] - 1 if i == self.axis else index[i] for i in range(len(index)))
+                 acc_tuple = tuple(self.to_tensor(o[prev_index], input[oi].dtype) for oi, o in enumerate(output_data))
+                 combine_fn_ret = self.combine_fn.fn(*acc_tuple, *data)
+                 acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret
+                 for j in range(len(output_data)):
+                     output_data[j][index] = acc_tuple[j].handle.data.item() if isinstance(
+                         acc_tuple[j], tl.core.tensor) else acc_tuple[j]
+         # Pack output
+         ret = []
+         for i, data in enumerate(output_data):
+             ret.append(self.to_tensor(data, input[i].dtype))
+         return ret
+
+     def apply_impl(self, input):
+         new_input = []
+         if self.reverse:
+             for arg in input:
+                 new_input.append(self.to_tensor(np.flip(arg.handle.data, axis=self.axis), arg.dtype))
+         else:
+             new_input = input
+         if self.combine_fn == tl.standard._sum_combine:
+             ret = self.cumsum(new_input[0])
+         elif self.combine_fn == tl.standard._prod_combine:
+             ret = self.cumprod(new_input[0])
+         else:
+             # Fall back to the slow mode
+             ret = self.generic_scan(new_input)
+         if self.reverse:
+             for arg in ret:
+                 arg.handle.data = np.flip(arg.handle.data, axis=self.axis)
+         return ret
+
+
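generic_scan visits every element in flat order, recovers its multi-index with np.unravel_index, and combines it with the value already written at the previous position along the scan axis, which yields an inclusive scan for any associative combine function. A standalone NumPy sketch of the same walk (hypothetical helper, not package code):

    import numpy as np

    def inclusive_scan(x, axis, combine):
        # Inclusive scan along `axis` with an arbitrary associative `combine`,
        # mirroring the element-by-element walk in generic_scan.
        out = np.zeros_like(x)
        for i in range(x.size):
            index = np.unravel_index(i, x.shape)
            if index[axis] == 0:
                out[index] = x[index]
            else:
                prev = tuple(v - 1 if d == axis else v for d, v in enumerate(index))
                out[index] = combine(out[prev], x[index])
        return out

    a = np.array([[1, 2, 3], [4, 5, 6]])
    print(inclusive_scan(a, axis=1, combine=lambda acc, cur: acc + cur))
    # [[ 1  3  6]
    #  [ 4  9 15]]
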
+ def _patch_reduce_scan():
+     # Because the interpreter doesn't support region_builder_fn, we cannot patch the builder
+     # to use the new reduce and scan functions.
+     # Instead, we need to patch the reduce and scan functions in tl and tl.core.
+     def _new_reduce(input, axis, combine_fn, keep_dims=False, **kwargs):
+         return ReduceOps(axis, combine_fn, keep_dims).apply(input)
+
+     def _new_scan(input, axis, combine_fn, reverse=False, **kwargs):
+         return ScanOps(axis, combine_fn, reverse).apply(input)
+
+     tl.reduce = _new_reduce
+     tl.associative_scan = _new_scan
+     tl.core.reduce = _new_reduce
+     tl.core.associative_scan = _new_scan
+
+
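With these assignments in place, tl.reduce and tl.associative_scan calls inside an interpreted kernel dispatch to ReduceOps and ScanOps, and a user-supplied combine function takes the generic element-by-element path. A hedged sketch of the kind of kernel that exercises it (names and shapes are made up; this uses the public tl.reduce API, not code from this file):

    import triton
    import triton.language as tl

    @triton.jit
    def _max_abs(a, b):
        # Custom associative combine function; under the interpreter this is
        # evaluated by generic_reduce rather than a NumPy fast path.
        return tl.maximum(tl.abs(a), tl.abs(b))

    @triton.jit
    def max_abs_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(x_ptr + offs)
        tl.store(out_ptr, tl.reduce(x, 0, _max_abs))
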
+ def _patch_lang_core(lang):
+
+     def _new_to_ir(self, builder):
+         # We need to specify signedness for integer types in the numpy mode
+         if self.name == 'void':
+             return builder.get_void_ty()
+         elif self.name == 'int1':
+             return builder.get_int1_ty()
+         elif self.name == 'int8':
+             return builder.get_int8_ty()
+         elif self.name == 'uint8':
+             return builder.get_uint8_ty()
+         elif self.name == 'int16':
+             return builder.get_int16_ty()
+         elif self.name == 'uint16':
+             return builder.get_uint16_ty()
+         elif self.name == 'int32':
+             return builder.get_int32_ty()
+         elif self.name == 'uint32':
+             return builder.get_uint32_ty()
+         elif self.name == 'int64':
+             return builder.get_int64_ty()
+         elif self.name == 'uint64':
+             return builder.get_uint64_ty()
+         elif self.name == 'fp8e5':
+             return builder.get_fp8e5_ty()
+         elif self.name == 'fp8e4nv':
+             return builder.get_fp8e4nv_ty()
+         elif self.name == 'fp8e4b15':
+             return builder.get_fp8e4b15_ty()
+         elif self.name == 'fp16':
+             return builder.get_half_ty()
+         elif self.name == 'bf16':
+             return builder.get_bf16_ty()
+         elif self.name == 'fp32':
+             return builder.get_float_ty()
+         elif self.name == 'fp64':
+             return builder.get_double_ty()
+         raise ValueError(f'fail to convert {self} to ir type')
+
+     # can't just map lang.static_range to `range`, because `tl.static_range`
+     # can get `step` passed by keyword
+     def _new_range(arg1, arg2=None, step=None, **kwargs):
+         if step is None:
+             step = 1
+         if arg2 is None:
+             start, end = 0, arg1
+         else:
+             start, end = arg1, arg2
+         return range(start, end, step)
+
+     def _new_static_assert(cond, msg=""):
+         assert cond, msg
+
+     def _set_attr(input, values, name):
+         # Skip non-tensor types. This may happen for induction variables.
+         if not isinstance(input, tl.tensor):
+             return input
+         # Unwrap constexpr
+         values = [values] if not isinstance(values, (list, tuple)) else values
+         values = [v.value if isinstance(v, tl.constexpr) else v for v in values]
+         if len(values) != max(1, len(input.shape)):
+             raise ValueError(f"len(values) != len(input.shape) for {name}")
+         input.handle.set_attr(name, values)
+         return input
+
+     lang.range = _new_range
+     lang.static_range = _new_range
+     lang.static_assert = _new_static_assert
+     lang.static_print = print
+     lang.dtype.to_ir = _new_to_ir
+     lang.multiple_of = partial(_set_attr, name="tt.divisibility")
+     lang.max_contiguous = partial(_set_attr, name="tt.contiguity")
+     lang.max_constancy = partial(_set_attr, name="tt.constancy")
+
+     _patch_reduce_scan()
+
+
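The comment above _new_range is the crux: tl.static_range accepts step as a keyword argument, so it cannot simply be aliased to range. A standalone check of the behavior the shim provides (plain Python, not package code):

    def _new_range(arg1, arg2=None, step=None, **kwargs):
        # Same shim as above, repeated so the snippet runs on its own.
        step = 1 if step is None else step
        start, end = (0, arg1) if arg2 is None else (arg1, arg2)
        return range(start, end, step)

    print(list(_new_range(4)))              # [0, 1, 2, 3]
    print(list(_new_range(2, 10, step=3)))  # [2, 5, 8]
    try:
        range(2, 10, step=3)                # plain range() rejects keyword arguments
    except TypeError as e:
        print("range() needs the shim:", e)
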
+ def _patch_lang(fn):
+     langs = [value for _, value in fn.__globals__.items() if inspect.ismodule(value) and value in [tl, tl.core]]
+     assert len(langs) >= 1, "triton.language must be visible from within jit'd function"
+     for lang in langs:
+         _patch_builtin(lang, interpreter_builder)
+         _patch_builtin(lang.tensor, interpreter_builder)
+         if lang == tl:
+             _patch_builtin(lang.math, interpreter_builder)
+         _patch_lang_tensor(lang.tensor)
+         _patch_lang_core(lang)
+     _patch_builtin(tl.core.tensor_descriptor_base, interpreter_builder)
+
+
+ def _tuple_create(arg, contents):
+     # NamedTuples and tuples have different construction semantics. NamedTuple
+     # has a constructor that takes individual arguments, while tuple takes an
+     # iterable. Both have type "tuple" making it difficult to distinguish
+     # between them, but only NamedTuple has "_fields" and apparently this is how
+     # everyone does the check.
+     return type(arg)(*contents) if hasattr(arg, "_fields") else type(arg)(contents)
+
+
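The `_fields` check is the usual way to tell a NamedTuple apart from a plain tuple at runtime: both report a tuple type, but only NamedTuple rebuilds from positional arguments. A small standalone demonstration (hypothetical Point type, not package code):

    from collections import namedtuple

    Point = namedtuple("Point", ["x", "y"])

    def _tuple_create(arg, contents):
        # Same dispatch as above: NamedTuple is rebuilt from positional args,
        # a plain tuple from a single iterable.
        return type(arg)(*contents) if hasattr(arg, "_fields") else type(arg)(contents)

    print(_tuple_create(Point(1, 2), [10, 20]))     # Point(x=10, y=20)
    print(_tuple_create((1, 2, 3), [10, 20, 30]))   # (10, 20, 30)
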
+ # TODO: wrap everything in triton tensors
+ def _implicit_cvt(arg):
+     if isinstance(arg, int):
+         ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None)
+         dtype = np.int32
+         if -2**31 <= arg < 2**31:
+             dtype = np.int32
+         elif 2**31 <= arg < 2**32:
+             dtype = np.uint32
+         elif -2**63 <= arg < 2**63:
+             dtype = np.int64
+         elif 2**63 <= arg < 2**64:
+             dtype = np.uint64
+         else:
+             raise ValueError(f"Unsupported integer value {arg}")
+         handle = TensorHandle(np.array([arg], dtype=dtype), ty)
+         return tl.tensor(handle, ty)
+     if hasattr(arg, "data_ptr"):
+         ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None)
+         handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
+         return tl.tensor(handle, ty)
+     elif isinstance(arg, tuple):
+         return _tuple_create(arg, map(_implicit_cvt, arg))
+     elif isinstance(arg, TensorDescriptor):
+         strides = [_implicit_cvt(s) for s in arg.strides]
+         assert arg.strides[-1] == 1
+         strides[-1] = tl.constexpr(1)
+         semantic = TritonSemantic(InterpreterBuilder())
+         return semantic.make_tensor_descriptor(base=_implicit_cvt(arg.base),
+                                                shape=[_implicit_cvt(s) for s in arg.shape], strides=strides,
+                                                block_shape=[tl.constexpr(b)
+                                                             for b in arg.block_shape], padding_option=arg.padding)
+     return arg
+
+
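For plain Python ints, the conversion picks the narrowest NumPy dtype whose range contains the value, preferring signed types and only falling back to the unsigned variant past the signed maximum. A standalone sketch of that bucketing (hypothetical helper, not package code):

    import numpy as np

    def pick_int_dtype(arg):
        # Same range checks as _implicit_cvt above.
        if -2**31 <= arg < 2**31:
            return np.int32
        elif 2**31 <= arg < 2**32:
            return np.uint32
        elif -2**63 <= arg < 2**63:
            return np.int64
        elif 2**63 <= arg < 2**64:
            return np.uint64
        raise ValueError(f"Unsupported integer value {arg}")

    print(pick_int_dtype(7))          # -> numpy.int32
    print(pick_int_dtype(2**31))      # -> numpy.uint32
    print(pick_int_dtype(-2**40))     # -> numpy.int64
    print(pick_int_dtype(2**63 + 1))  # -> numpy.uint64
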
+ interpreter_builder = InterpreterBuilder()
+ interpreter_semantic = TritonSemantic(interpreter_builder)
+
+
+ def _unwrap_tensor(t):
+     if isinstance(t, triton.runtime.jit.TensorWrapper):
+         return t.base
+     return t
+
+
+ def _rewrap_tensor(t, original_tensor):
+     if isinstance(original_tensor, triton.runtime.jit.TensorWrapper):
+         return triton.runtime.jit.TensorWrapper(t, original_tensor.dtype)
+     return t
+
+
+ class GridExecutor:
+
+     def __init__(self, fn, arg_names, grid):
+         from .jit import _normalize_ty  # TODO: modularize
+
+         self.fn = fn
+         self.arg_names = arg_names
+         self.grid = grid
+         __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
+         self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]
+
+     def _init_args_hst(self, args_dev, kwargs):
+         storages = {}
+
+         def _to_cpu(arg):
+             if isinstance(arg, tuple):
+                 return _tuple_create(arg, map(_to_cpu, arg))
+             elif isinstance(arg, TensorDescriptor):
+                 return TensorDescriptor(
+                     _to_cpu(arg.base),
+                     arg.shape,
+                     arg.strides,
+                     arg.block_shape,
+                     arg.padding,
+                 )
+             elif not hasattr(arg, "data_ptr"):
+                 return arg
+
+             unwrapped_arg = _unwrap_tensor(arg)
+             if unwrapped_arg.untyped_storage().data_ptr() not in storages:
+                 storage = unwrapped_arg.untyped_storage()
+                 storages[storage.data_ptr()] = storage.cpu()
+
+             storage = storages[unwrapped_arg.untyped_storage().data_ptr()]
+             cpu_arg = unwrapped_arg.new_empty(0, device='cpu')
+             cpu_arg.set_(storage, unwrapped_arg.storage_offset(), unwrapped_arg.size(), unwrapped_arg.stride())
+             cpu_arg = _rewrap_tensor(cpu_arg, original_tensor=arg)
+             return cpu_arg
+
+         args_hst = [_to_cpu(arg) for arg in args_dev]
+
+         # Process keyword arguments
+         kwargs_hst = {}
+         for key, value in kwargs.items():
+             kwargs_hst[key] = _to_cpu(value)
+         return args_hst, kwargs_hst
+
+     def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst):
+         storages = {}
+
+         def _from_cpu(arg_dev, arg_hst):
+             if hasattr(arg_dev, "data_ptr"):
+                 # No need to rewrap because this just modifies the internal storage in place
+                 arg_dev, arg_hst = _unwrap_tensor(arg_dev), _unwrap_tensor(arg_hst)
+                 storages[arg_dev.untyped_storage().data_ptr()] = (arg_dev.untyped_storage(), arg_hst.untyped_storage())
+             elif isinstance(arg_dev, tuple):
+                 for (arg_dev, arg_hst) in zip(arg_dev, arg_hst):
+                     _from_cpu(arg_dev, arg_hst)
+             elif isinstance(arg_dev, TensorDescriptor):
+                 _from_cpu(arg_dev.base, arg_hst.base)
+
+         for arg_dev, arg_hst in zip(args_dev, args_hst):
+             _from_cpu(arg_dev, arg_hst)
+
+         # Restore keyword arguments
+         for key, kwarg_dev in kwargs.items():
+             kwarg_hst = kwargs_hst[key]
+             _from_cpu(kwarg_dev, kwarg_hst)
+
+         for (arg_dev, arg_hst) in storages.values():
+             arg_dev.copy_(arg_hst)
+
+     def __call__(self, *args_dev, **kwargs):
+         if kwargs.pop("warmup", False):
+             return
+         # Remove unused reserved keywords from kwargs.
+         # Triton doesn't support keyword-only, variable positional, or variable keyword arguments,
+         # so it's safe to inspect only positional-or-keyword arguments (i.e., argspec.args).
+         argspec = inspect.getfullargspec(self.fn)
+         kwargs = {k: v for k, v in kwargs.items() if k in argspec.args}
+         # copy arguments to the host
+         args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
+         # remaps core language functions to interpreted ones
+         _patch_lang(self.fn)
+         # we need to copy arguments to the host for the interpreter
+         # implicitly convert tensor arguments to their base pointers
+         args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst)
+         args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()}
+         # iterate through the grid
+         grid = self.grid(args) if callable(self.grid) else self.grid
+         assert len(grid) <= 3, "grid must have at most 3 dimensions"
+         grid = grid + (1, ) * (3 - len(grid))
+         interpreter_builder.set_grid_dim(*grid)
+         try:
+             for x in range(grid[0]):
+                 for y in range(grid[1]):
+                     for z in range(grid[2]):
+                         interpreter_builder.set_grid_idx(x, y, z)
+                         self.fn(**args)
+         except Exception as e:
+             if triton.knobs.compilation.front_end_debugging:
+                 raise
+             raise InterpreterError(repr(e)) from e
+         # copy arguments back to propagate side-effects
+         self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)
+
+
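GridExecutor is what a subscripted interpreted kernel returns: calling it copies device tensors into host storage, evaluates the grid callable, pads the grid to three dimensions, runs the kernel body once per program id on the CPU, and copies results back. A hedged usage sketch under the interpreter (kernel, sizes, and block size are made up; TRITON_INTERPRET=1 is the documented switch for interpreted mode):

    import os
    os.environ["TRITON_INTERPRET"] = "1"  # set before importing triton

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask)
        y = tl.load(y_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x + y, mask=mask)

    n = 1000
    x, y = torch.randn(n), torch.randn(n)
    out = torch.empty_like(x)
    # The grid callable receives the argument dict and returns up to 3 dims;
    # GridExecutor pads it with ones and loops over every program id.
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK"]),)
    add_kernel[grid](x, y, out, n, BLOCK=128)
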
+ class ASTTransformer(ast.NodeTransformer):
+
+     def visit_Assign(self, node):
+         names = []
+         for target in node.targets:
+             names += [self.visit(target)]
+         if len(names) > 1:
+             raise ValueError("Multiple assignments are not supported")
+         # Modify the assignment `x = value` to
+         # `x = interpreter_semantic.to_tensor(value, False)`
+         node.value = ast.Call(
+             func=ast.Attribute(value=ast.Name(id="interpreter_semantic", ctx=ast.Load()), attr="to_tensor",
+                                ctx=ast.Load()), args=[node.value, ast.Constant(value=False)], keywords=[])
+         return node
+
+
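The transformer's effect is easiest to see on a one-line assignment. A standalone sketch using ast.unparse (the class below repeats the rewrite only so the snippet runs on its own; interpreter_semantic is just a name in the generated source here):

    import ast

    class AssignToTensor(ast.NodeTransformer):
        # Minimal copy of the rewrite above, for demonstration only.
        def visit_Assign(self, node):
            node.value = ast.Call(
                func=ast.Attribute(value=ast.Name(id="interpreter_semantic", ctx=ast.Load()), attr="to_tensor",
                                   ctx=ast.Load()), args=[node.value, ast.Constant(value=False)], keywords=[])
            return node

    tree = ast.fix_missing_locations(AssignToTensor().visit(ast.parse("x = a + b")))
    print(ast.unparse(tree))
    # x = interpreter_semantic.to_tensor(a + b, False)
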
+ class FunctionRewriter:
+     ast_transformer = ASTTransformer()
+
+     def __init__(self, fn, **kwargs):
+         self.fn = fn
+         self.kwargs = kwargs
+         self.filename: str = ""
+         # Absolute line number in the file
+         self.def_file_lineno: int = 0
+
+     def rewrite_ast(self):
+         # If an exception is raised, the function does not have source code available
+         # (e.g., dynamically generated functions); we cannot rewrite it, so just return the original function
+         try:
+             lines, _ = inspect.getsourcelines(self.fn)
+         except Exception:
+             return self.fn
+
+         # truncate lines before def
+         # @triton.autotune(...)
+         # ...
+         # @triton.jit
+         # ...
+         # def foo(...): <- this line is the function definition
+         self.filename, self.def_file_lineno = self._get_jit_fn_file_line()
+         self.def_lineno = self._find_def(lines)
+         src = self._prepare_source(lines)
+         transformed_ast = self._transform_ast(src)
+         return self._compile_and_exec(transformed_ast)
+
+     def _get_jit_fn_file_line(self):
+         from .jit import get_jit_fn_file_line, JITFunction
+         return get_jit_fn_file_line(JITFunction(self.fn))
+
+     def _find_def(self, lines):
+         def_lineno = 0
+         # Line numbers start from 1
+         for i, line in enumerate(lines):
+             if line.strip().startswith("def "):
+                 def_lineno = i + 1
+         return def_lineno
+
+     def _prepare_source(self, lines):
+         lines = lines[self.def_lineno - 1:]
+         src = ''.join(lines)
+         return textwrap.dedent(src)
+
+     def _transform_ast(self, src):
+         # src is like:
+         # 1: def foo(...):
+         # 2: ...
+         parsed_ast = ast.parse(src)
+         transformed_ast = self.ast_transformer.visit(parsed_ast)
+         ast.fix_missing_locations(transformed_ast)
+         inc_lineno = self.def_file_lineno - 1
+         ast.increment_lineno(transformed_ast, inc_lineno)
+         return transformed_ast
+
+     def _compile_and_exec(self, transformed_ast):
+         compiled_code = compile(transformed_ast, filename=self.filename, mode='exec')
+         local_namespace = {**self.kwargs}
+         fn_globals = self.fn.__globals__
+         for key, value in globals().items():
+             if key not in fn_globals:
+                 fn_globals[key] = value
+         exec(compiled_code, fn_globals, local_namespace)
+         return local_namespace[self.fn.__name__]
+
+
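The line-number bookkeeping in _transform_ast (ast.increment_lineno by def_file_lineno - 1) is what keeps tracebacks from the recompiled function pointing at the user's original file and line. A standalone sketch of that mechanism (file name and offset are made up):

    import ast
    import traceback

    src = "def foo():\n    return 1 / 0\n"
    tree = ast.parse(src)
    # Pretend foo() was originally defined starting at line 42 of kernels.py.
    ast.increment_lineno(tree, 41)
    code = compile(tree, filename="kernels.py", mode="exec")
    ns = {}
    exec(code, ns)
    try:
        ns["foo"]()
    except ZeroDivisionError:
        traceback.print_exc()  # reports: File "kernels.py", line 43, in foo
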
+ class InterpretedFunction:
+     # Cache all rewritten functions
+     rewritten_fn: Dict[Callable, Callable] = {}
+
+     def __init__(self, fn, **kwargs) -> None:
+         self.fn = fn
+         self.rewriter = FunctionRewriter(fn, **kwargs)
+         self.kwargs = kwargs
+
+         def run(*args, **kwargs):
+             grid = kwargs["grid"]
+             fn = self.rewrite()
+             return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs)
+
+         self.run = run
+         signature = inspect.signature(fn)
+         self.arg_names = [v.name for v in signature.parameters.values()]
+
+     def rewrite(self):
+         if self.fn not in self.rewritten_fn:
+             self.rewritten_fn[self.fn] = self.rewriter.rewrite_ast()
+         return self.rewritten_fn[self.fn]
+
+     @property
+     def __name__(self):
+         return self.fn.__name__
+
+     def __getitem__(self, grid):
+         fn = self.rewrite()
+         return GridExecutor(fn, self.arg_names, grid)
+
+     def __call__(self, *args, **kwargs):
+         # This is a device function call
+         _patch_lang(self.fn)
+         fn = self.rewrite()
+         try:
+             return fn(*args, **kwargs)
+         except Exception as e:
+             raise InterpreterError(repr(e)) from e
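InterpretedFunction is what @triton.jit produces when TRITON_INTERPRET=1 is set: subscripting with a grid goes through __getitem__ and returns a GridExecutor, while a bare call from another jitted function takes the __call__ path above (a device-function call). A hedged sketch of the two entry points (kernel names and sizes are made up):

    import os
    os.environ["TRITON_INTERPRET"] = "1"  # makes @triton.jit build InterpretedFunction objects

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def square(x):
        # Called without a grid from another jitted function -> __call__ path
        return x * x

    @triton.jit
    def square_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        tl.store(out_ptr + offs, square(tl.load(x_ptr + offs)))

    x = torch.arange(16, dtype=torch.float32)
    out = torch.empty_like(x)
    square_kernel[(1, )](x, out, BLOCK=16)  # subscripting -> __getitem__ -> GridExecutor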