triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (217):
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1939 @@
1
+ from __future__ import annotations # remove after python 3.11
2
+ import warnings
3
+
4
+ from typing import List, Optional, Sequence, Tuple, TypeVar, Generic, Type
5
+ import numbers
6
+
7
+ from triton.runtime import driver
8
+
9
+ from .._C.libtriton import ir
10
+ from . import core as tl
11
+
12
# Generic type variable.
T = TypeVar("T")
# Placeholder for the concrete tensor class a semantic object produces
# (TritonSemantic's ``tensor`` class attribute defaults to ``tl.tensor``).
TensorTy = TypeVar("TensorTy")
14
+
15
+
16
class IncompatibleTypeErrorImpl(Exception):
    """Raised when the operand types of a binary operation cannot be combined.

    Attributes:
        type_a: type of the first operand.
        type_b: type of the second operand.
        message: human-readable description, also used as the exception text.
    """

    def __init__(self, type_a, type_b):
        self.type_a = type_a
        self.type_b = type_b
        self.message = f"invalid operands of type {type_a!r} and {type_b!r}"
        super().__init__(self.message)
23
+
24
+
25
+ class TritonSemantic(Generic[TensorTy]):
26
+ tensor: Type[TensorTy] = tl.tensor
27
+ lang = tl
28
+
29
+ builder: ir.builder
30
+
31
def __init__(self, builder):
    """Bind this semantic helper to ``builder``, through which all IR ops are emitted."""
    self.builder = builder
33
+
34
+ # ===----------------------------------------------------------------------===##
35
+ # Programming Model
36
+ # ===----------------------------------------------------------------------===##
37
+
38
def program_id(self, axis: int) -> TensorTy:
    """Emit a get-program-id op for ``axis`` and wrap it as an int32 tensor.

    Raises:
        ValueError: if ``axis`` is not one of 0, 1, 2.
    """
    if axis in (0, 1, 2):
        pid_handle = self.builder.create_get_program_id(axis)
        return self.tensor(pid_handle, tl.int32)
    raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}")
42
+
43
def num_programs(self, axis: int) -> TensorTy:
    """Emit a get-num-programs op for ``axis`` and wrap it as an int32 tensor.

    Raises:
        ValueError: if ``axis`` is not one of 0, 1, 2.
    """
    if axis in (0, 1, 2):
        count_handle = self.builder.create_get_num_programs(axis)
        return self.tensor(count_handle, tl.int32)
    raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}")
47
+
48
+ # ===----------------------------------------------------------------------===//
49
+ # Implicit Casting Utilities
50
+ # ===----------------------------------------------------------------------===//
51
+
52
def integer_promote_impl(self, a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype:
    """Return the common integer type of ``a_ty`` and ``b_ty``.

    Signedness rules follow C's "usual arithmetic conversions"
    (https://en.cppreference.com/w/c/language/conversion):

    * same signedness: the wider type wins; on a tie ``b_ty`` is returned;
    * mixed signedness: the unsigned type wins unless the signed one is strictly wider.

    Raises:
        TypeError: if the signedness pair matches none of the cases above.
    """
    a_bits, b_bits = a_ty.int_bitwidth, b_ty.int_bitwidth
    a_sign, b_sign = a_ty.int_signedness, b_ty.int_signedness
    if a_sign == b_sign:
        return b_ty if b_bits >= a_bits else a_ty
    if a_sign == tl.dtype.SIGNEDNESS.UNSIGNED:
        return b_ty if b_bits > a_bits else a_ty
    if b_sign == tl.dtype.SIGNEDNESS.UNSIGNED:
        return a_ty if a_bits > b_bits else b_ty
    raise TypeError(f"unexpected signedness {a_sign} and {b_sign}")
66
+
67
def computation_type_impl(self, a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_is_scalar: bool,
                          div_or_mod: bool) -> tl.dtype:
    """Return the element type in which a binary arithmetic op on ``a_ty`` and ``b_ty`` is computed.

    Args:
        a_ty: scalar element type of the left operand.
        a_is_scalar: True when the left operand was a Python number rather than a tensor.
        b_ty: scalar element type of the right operand.
        b_is_scalar: True when the right operand was a Python number rather than a tensor.
        div_or_mod: True for ``/``, ``//`` and ``%``. These compute fp16/bf16 operands in fp32
            and reject integer operands of different signedness.

    Returns:
        The promoted element type.

    Raises:
        TypeError: if, after the floating-point rules, the operands are not both integers,
            or for division/modulo between integers of different signedness.
    """
    # 0) For scalars we follow semantics similar to PyTorch, namely:
    #    - If the scalar is of a lower or equal kind (bool < uint < int < fp),
    #      it doesn't participate in the promotion
    if a_is_scalar != b_is_scalar:
        scalar_ty, tensor_ty = (a_ty, b_ty) if a_is_scalar else (b_ty, a_ty)
        if scalar_ty.kind().value <= tensor_ty.kind().value:
            # Upcast because of 3) and 4) below!
            if div_or_mod and (tensor_ty in (tl.float16, tl.bfloat16)):
                return tl.float32
            return tensor_ty

    # 1) if one operand is double, the other is implicitly converted to double
    if a_ty.is_fp64() or b_ty.is_fp64():
        return tl.float64
    # 2) if one operand is float, the other is implicitly converted to float
    if a_ty.is_fp32() or b_ty.is_fp32():
        return tl.float32
    # 3) if one operand is half, the other is implicitly converted to half
    #    unless we're doing / or %, which do not exist natively in PTX for fp16.
    #    Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp
    if a_ty.is_fp16() or b_ty.is_fp16():
        return tl.float32 if div_or_mod else tl.float16
    # 4) return bf16 only if both operands are of bf16; any other bf16 mix uses fp32
    if a_ty.is_bf16() and b_ty.is_bf16():
        return tl.float32 if div_or_mod else tl.bfloat16
    if a_ty.is_bf16() or b_ty.is_bf16():
        return tl.float32
    # 5) two fp8 operands: keep the type if identical, otherwise use fp16
    if a_ty.is_fp8() and b_ty.is_fp8():
        return a_ty if a_ty == b_ty else tl.float16
    if not a_ty.is_int() or not b_ty.is_int():
        raise TypeError(f"unexpected type {a_ty} and {b_ty}")
    # 6) both operands are integer and undergo integer promotion.
    #    `//` (floordiv) also reaches this check with div_or_mod set.
    if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
        raise TypeError(f"Cannot use /, //, or % with {a_ty!r} and {b_ty!r} because they have different signedness; "
                        "this is unlikely to result in a useful answer. Cast them to the same signedness.")
    return self.integer_promote_impl(a_ty, b_ty)
116
+
117
def to_tensor(self, x, check_type: bool = True):
    """Convert ``x`` to a tensor, materialising Python constants as scalar constants.

    Conversion rules:
      * ``bool``  -> ``tl.int1``
      * ``int``   -> the first of int32, uint32, int64, uint64 whose range holds ``x``
      * ``float`` -> ``tl.float32`` for 0, +/-inf, NaN and magnitudes within
        [2**-126, (2 - 2**-23) * 2**127]; ``tl.float64`` otherwise
      * ``tl.constexpr`` -> converted via its ``.value``
      * an instance of this semantic's tensor class -> returned unchanged

    Args:
        x: value to convert.
        check_type: when False, values of any other type are returned unchanged
            instead of raising.

    Raises:
        ValueError: for integers outside [-2**63, 2**64).
        TypeError: for unsupported values when ``check_type`` is True.
    """
    # `__builtins__` is a CPython implementation detail (a dict in imported modules but
    # the builtins module itself in __main__), so indexing it is fragile; use `builtins`.
    import builtins

    if isinstance(x, bool):
        return self.tensor(self.builder.get_int1(x), tl.int1)
    elif isinstance(x, int):
        # Signed types are preferred: an unsigned type is chosen only when the value
        # overflows the signed type of that width but fits the unsigned one.
        if -2**31 <= x < 2**31:
            dtype = tl.int32
        elif 2**31 <= x < 2**32:
            dtype = tl.uint32
        elif -2**63 <= x < 2**63:
            dtype = tl.int64
        elif 2**63 <= x < 2**64:
            dtype = tl.uint64
        else:
            raise ValueError(f'Nonrepresentable integer {x}.')
        return self.scalar_constant(x, dtype=dtype)
    elif isinstance(x, float):
        min_float32 = 2**-126
        max_float32 = (2 - 2**-23) * 2**127
        abs_x = builtins.abs(x)
        if (abs_x == float("inf")
                or abs_x == 0.0
                or x != x  # NaN
                or min_float32 <= abs_x <= max_float32):
            dtype = tl.float32
        else:
            dtype = tl.float64
        return self.scalar_constant(x, dtype=dtype)
    elif isinstance(x, tl.constexpr):
        # NOTE(review): check_type is not forwarded, so an unconvertible constexpr
        # value always raises TypeError here.
        return self.to_tensor(x.value)
    elif isinstance(x, self.tensor):
        return x
    if check_type:
        raise TypeError(f"cannot convert {x} of type {type(x)} to tensor")
    return x
153
+
154
+ # ===----------------------------------------------------------------------===//
155
+ # Binary Operators
156
+ # ===----------------------------------------------------------------------===//
157
+
158
def check_ptr_type_impl(self, type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None:
    """Validate pointer arithmetic when ``type_a`` is a pointer type.

    A non-pointer ``type_a`` is always accepted. For a pointer ``type_a`` this raises
    ``IncompatibleTypeErrorImpl(type_a, type_b)`` when pointers are not allowed on this
    side, when ``type_b`` is a different pointer type (T* + U*), or when ``type_b`` is
    floating point (T* + float).
    """
    if not type_a.is_ptr():
        return
    if (not allow_ptr_a
            or (type_b.is_ptr() and type_a != type_b)
            or type_b.is_floating()):
        raise IncompatibleTypeErrorImpl(type_a, type_b)
168
+
169
def binary_op_type_checking_impl(self, lhs: TensorTy | numbers.Number, rhs: TensorTy | numbers.Number,
                                 allow_lhs_ptr=False, allow_rhs_ptr=False, arithmetic_check=True,
                                 div_or_mod=False) -> Tuple[TensorTy, TensorTy]:
    """Prepare the operands of a binary op: convert, type-check, promote, then broadcast.

    Python numbers are converted with ``to_tensor``, and pointer operands are validated
    with ``check_ptr_type_impl``. When ``arithmetic_check`` is set and neither side is a
    pointer, both sides are brought to the type chosen by ``computation_type_impl``:
    Python numbers are re-materialised directly in that type, and tensors are cast to it.
    Finally the two sides are broadcast against each other.

    Args:
        lhs: left operand (tensor or Python number).
        rhs: right operand (tensor or Python number).
        allow_lhs_ptr: whether a pointer is accepted on the left.
        allow_rhs_ptr: whether a pointer is accepted on the right.
        arithmetic_check: whether to apply arithmetic type promotion.
        div_or_mod: forwarded to ``computation_type_impl``.

    Returns:
        The ``(lhs, rhs)`` tensors after promotion and broadcasting.

    Raises:
        ValueError: if a negative Python number meets an unsigned result type, or a
            Python number does not fit the promoted integer type.
    """
    lhs_is_scalar = isinstance(lhs, numbers.Number)
    rhs_is_scalar = isinstance(rhs, numbers.Number)
    # Keep the original Python values; they are re-materialised below in the promoted type.
    lhs_scalar = lhs if lhs_is_scalar else None
    rhs_scalar = rhs if rhs_is_scalar else None
    if lhs_is_scalar:
        lhs = self.to_tensor(lhs)
    if rhs_is_scalar:
        rhs = self.to_tensor(rhs)

    # implicit typecasting
    lhs_sca_ty = lhs.type.scalar
    rhs_sca_ty = rhs.type.scalar
    self.check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr)
    self.check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr)
    if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr():
        ret_sca_ty = self.computation_type_impl(lhs_sca_ty, lhs_is_scalar, rhs_sca_ty, rhs_is_scalar, div_or_mod)
        python_scalars = [value for is_scalar, value in ((lhs_is_scalar, lhs_scalar), (rhs_is_scalar, rhs_scalar))
                          if is_scalar]
        if any(value < 0 for value in python_scalars) and ret_sca_ty.is_int_unsigned():
            raise ValueError("Cannot perform a binary operation between an unsigned tensor and a negative scalar. "
                             "Perform an explicit cast on one of them.")
        if ret_sca_ty.is_int():
            min_value = ret_sca_ty.get_int_min_value()
            max_value = ret_sca_ty.get_int_max_value()
            for value in python_scalars:  # left operand is reported first, as before
                if not min_value <= value <= max_value:
                    raise ValueError(f"Scalar {value} is out of range for type {ret_sca_ty}")
        lhs = self.scalar_constant(lhs_scalar, dtype=ret_sca_ty) if lhs_is_scalar else self.cast(lhs, ret_sca_ty)
        rhs = self.scalar_constant(rhs_scalar, dtype=ret_sca_ty) if rhs_is_scalar else self.cast(rhs, ret_sca_ty)

    # implicit broadcasting
    return self.broadcast_impl_value(lhs, rhs)
205
+
206
def binary_op_sanitize_overflow_impl(self, lhs: TensorTy, rhs: TensorTy, binary_op: callable):
    """Emit a device assertion that ``binary_op(lhs, rhs)`` stays in the operands' integer range.

    Nothing is emitted when the operand type is 64 bits wide or wider, or when
    ``self.builder.options.sanitize_overflow`` is false. Otherwise the steps are:

    1. Cast both operands to int64.
    2. Re-emit ``binary_op`` on the casts with ``False`` as its third argument. For
       ``add``/``sub``/``mul`` that argument is ``sanitize_overflow``, so the check does
       not recurse.
    3. Assert that the wide result lies within ``[min, max]`` of the original type.
    """
    narrow_ty = lhs.type.scalar
    if narrow_ty.int_bitwidth >= 64 or not self.builder.options.sanitize_overflow:
        return
    assert narrow_ty == rhs.type.scalar
    assert narrow_ty.is_int()
    wide_result = binary_op(self.cast(lhs, tl.int64), self.cast(rhs, tl.int64), False)
    upper = self.scalar_constant(narrow_ty.get_int_max_value(), tl.int64)
    lower = self.scalar_constant(narrow_ty.get_int_min_value(), tl.int64)
    in_range = self.and_(self.less_equal(wide_result, upper), self.greater_equal(wide_result, lower))
    self.device_assert(in_range, f"int{narrow_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}",
                       None)
223
+
224
def add(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number,
        sanitize_overflow: bool) -> TensorTy:
    """Emit ``input + other``.

    Supported combinations:
    * pointer + integer offset, in either order;
    * float + float;
    * int + int, preceded by an overflow check when ``sanitize_overflow`` is set.

    Unsigned offsets narrower than 64 bits are widened to int64 with a non-signed int cast
    before ``addptr``, because ``addptr`` treats its offset as signed.

    Raises:
        TypeError: when adding two pointers, or for any other unsupported type.
    """
    input, other = self.binary_op_type_checking_impl(input, other, True, True)
    if input.type.scalar.is_ptr() and other.type.scalar.is_ptr():
        raise TypeError("cannot add pointers together")

    # Canonicalise "offset + ptr" into "ptr + offset".
    if other.type.scalar.is_ptr() and not input.type.scalar.is_ptr():
        input, other = other, input
    lhs_ty = input.type.scalar

    if lhs_ty.is_ptr():
        offset = other.handle
        if other.dtype.is_int_unsigned() and other.dtype.int_bitwidth < 64:
            # addptr treats offset as signed. Zero-extend unsigned offsets to ensure they're positive
            i64_ty = other.type.with_element_ty(tl.int64).to_ir(self.builder)
            offset = self.builder.create_int_cast(other.handle, i64_ty, False)
        return self.tensor(self.builder.create_addptr(input.handle, offset), input.type)
    if lhs_ty.is_floating():
        return self.tensor(self.builder.create_fadd(input.handle, other.handle), input.type)
    if lhs_ty.is_int():
        if sanitize_overflow:
            self.binary_op_sanitize_overflow_impl(input, other, self.add)
        return self.tensor(self.builder.create_add(input.handle, other.handle), input.type)
    raise TypeError(f"unexpected type {lhs_ty}")
254
+
255
def sub(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number,
        sanitize_overflow: bool) -> TensorTy:
    """Emit ``input - other``.

    A pointer is only accepted on the left; ``ptr - offset`` is lowered to
    ``add(ptr, minus(offset))``. Floats use ``create_fsub``. Integers use ``create_sub``,
    preceded by an overflow check when ``sanitize_overflow`` is set.

    Raises:
        TypeError: if the promoted type is neither pointer, floating point nor integer.
    """
    input, other = self.binary_op_type_checking_impl(input, other, True, False)
    elem_ty = input.type.scalar
    if elem_ty.is_ptr():
        return self.add(input, self.minus(other), sanitize_overflow=False)
    if elem_ty.is_floating():
        difference = self.builder.create_fsub(input.handle, other.handle)
    elif elem_ty.is_int():
        if sanitize_overflow:
            self.binary_op_sanitize_overflow_impl(input, other, self.sub)
        difference = self.builder.create_sub(input.handle, other.handle)
    else:
        raise TypeError(f"unexpected type {elem_ty}")
    return self.tensor(difference, input.type)
271
+
272
def mul(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number,
        sanitize_overflow: bool) -> TensorTy:
    """Emit ``input * other`` for floating-point or integer operands.

    Integer products are preceded by an overflow check when ``sanitize_overflow`` is set.

    Raises:
        TypeError: if the promoted type is neither floating point nor integer.
    """
    input, other = self.binary_op_type_checking_impl(input, other)
    elem_ty = input.type.scalar
    if elem_ty.is_floating():
        product = self.builder.create_fmul(input.handle, other.handle)
    elif elem_ty.is_int():
        if sanitize_overflow:
            self.binary_op_sanitize_overflow_impl(input, other, self.mul)
        product = self.builder.create_mul(input.handle, other.handle)
    else:
        raise TypeError(f"unexpected type {elem_ty}")
    return self.tensor(product, input.type)
285
+
286
def truediv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy:
    """Emit true division ``input / other``; the result is always floating point.

    After the usual promotion (with ``div_or_mod`` set), operands are converted as follows:

    * float / int or int / float: the integer side is cast to the float side's type;
    * int / int: both sides are cast to ``tl.float32``;
    * float / float: the operand whose type has the narrower mantissa is cast to the other
      operand's type (on a tie the left side is cast).

    Raises:
        TypeError: for any other operand-type combination.
    """
    input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True)
    lhs_ty, rhs_ty = input.type.scalar, other.type.scalar
    if lhs_ty.is_floating() and rhs_ty.is_int():
        other = self.cast(other, lhs_ty)
    elif lhs_ty.is_int() and rhs_ty.is_floating():
        input = self.cast(input, rhs_ty)
    elif lhs_ty.is_int() and rhs_ty.is_int():
        input = self.cast(input, tl.float32)
        other = self.cast(other, tl.float32)
    elif lhs_ty.is_floating() and rhs_ty.is_floating():
        if lhs_ty.fp_mantissa_width > rhs_ty.fp_mantissa_width:
            other = self.cast(other, lhs_ty)
        else:
            input = self.cast(input, rhs_ty)
    else:
        # unreachable for well-formed numeric operands
        raise TypeError(f"unexpected type {lhs_ty}")
    quotient = self.builder.create_fdiv(input.handle, other.handle)
    return self.tensor(quotient, input.type)
310
+
311
def floordiv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy:
    """Emit integer division for ``input // other``.

    Both operands are cast to their common integer type. Signed types are lowered with
    ``create_sdiv`` and unsigned types with ``create_udiv``. NOTE(review): the rounding
    direction for negative operands is whatever those builder ops implement — confirm
    against the dialect before documenting it as floor division.

    Raises:
        TypeError: if either operand is not an integer.
    """
    input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True)
    lhs_ty, rhs_ty = input.type.scalar, other.type.scalar
    if not (lhs_ty.is_int() and rhs_ty.is_int()):
        raise TypeError(f"unexpected type {lhs_ty}")
    common_ty = self.integer_promote_impl(lhs_ty, rhs_ty)
    input = self.cast(input, common_ty)
    other = self.cast(other, common_ty)
    emit_div = self.builder.create_sdiv if common_ty.is_int_signed() else self.builder.create_udiv
    return self.tensor(emit_div(input.handle, other.handle), input.type)
324
+
325
def fdiv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number, ieee_rounding: bool) -> TensorTy:
    """Emit floating-point division ``input / other`` without type promotion.

    Both operands must already have a floating-point scalar type. They are passed through
    ``binary_op_type_checking_impl`` with ``arithmetic_check=False``, so no cast is
    inserted, and the quotient takes ``input``'s type.

    NOTE(review): ``ieee_rounding`` is accepted but not used by this implementation.
    NOTE(review): ``.type`` is read before any conversion, so despite the annotation Python
    numbers are presumably converted by the caller — confirm.

    Raises:
        TypeError: if either operand's scalar type is not floating point.
    """
    lhs_ty = input.type.scalar
    rhs_ty = other.type.scalar
    if not (lhs_ty.is_floating() and rhs_ty.is_floating()):
        raise TypeError("both operands of fdiv must have floating scalar type")
    input, other = self.binary_op_type_checking_impl(input, other, False, False, False, True)
    return self.tensor(self.builder.create_fdiv(input.handle, other.handle), input.type)
333
+
334
def mod(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy:
    """Emit the remainder ``input % other``.

    Floating-point operands use ``create_frem``. Signed integers use ``create_srem`` and
    unsigned integers ``create_urem``.

    Raises:
        TypeError: for integer operands of different signedness, or if the promoted type
            is neither floating point nor integer.
    """
    input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True)
    scalar_ty = input.type.scalar
    other_scalar_ty = other.type.scalar
    # float % float
    if scalar_ty.is_floating():
        return self.tensor(self.builder.create_frem(input.handle, other.handle), input.type)
    # int % int
    elif scalar_ty.is_int():
        if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
            # The two sentences were previously glued together ("signedness;this").
            raise TypeError(f"Cannot mod {scalar_ty!r} by {other_scalar_ty!r} because they have different signedness; "
                            "this is unlikely to result in a useful answer. Cast them to the same signedness.")
        if scalar_ty.is_int_signed():
            return self.tensor(self.builder.create_srem(input.handle, other.handle), input.type)
        else:
            return self.tensor(self.builder.create_urem(input.handle, other.handle), input.type)
    raise TypeError(f"unexpected type {scalar_ty}")
352
+
353
+ ##############
354
+ # other arithmetic ops
355
+ ##############
356
+
357
def minimum(self, x: TensorTy, y: TensorTy, propagate_nan: tl.PropagateNan):
    """Emit the element-wise minimum of ``x`` and ``y`` after the usual promotion.

    Builder op selection:
    * floats with ``tl.PropagateNan.ALL``: ``create_minimumf``;
    * floats with ``tl.PropagateNan.NONE``: ``create_minnumf``;
    * signed integers: ``create_minsi``;
    * unsigned integers: ``create_minui``.

    Raises:
        ValueError: for any other ``propagate_nan`` value with floating-point operands.
        TypeError: for non-numeric dtypes.
    """
    x, y = self.binary_op_type_checking_impl(x, y)
    dtype = x.dtype
    builder = self.builder
    if dtype.is_floating():
        if propagate_nan == tl.PropagateNan.ALL:
            emit = builder.create_minimumf
        elif propagate_nan == tl.PropagateNan.NONE:
            emit = builder.create_minnumf
        else:
            raise ValueError(f"Unexpected propagate_nan {propagate_nan}")
    elif dtype.is_int_signed():
        emit = builder.create_minsi
    elif dtype.is_int_unsigned():
        emit = builder.create_minui
    else:
        raise TypeError(f"Unexpected dtype {dtype}")
    return self.tensor(emit(x.handle, y.handle), x.type)
373
+
374
def maximum(self, x: TensorTy, y: TensorTy, propagate_nan: tl.PropagateNan):
    """Emit the element-wise maximum of ``x`` and ``y`` after the usual promotion.

    Builder op selection:
    * floats with ``tl.PropagateNan.ALL``: ``create_maximumf``;
    * floats with ``tl.PropagateNan.NONE``: ``create_maxnumf``;
    * signed integers: ``create_maxsi``;
    * unsigned integers: ``create_maxui``.

    Raises:
        ValueError: for any other ``propagate_nan`` value with floating-point operands.
        TypeError: for non-numeric dtypes.
    """
    x, y = self.binary_op_type_checking_impl(x, y)
    dtype = x.dtype
    builder = self.builder
    if dtype.is_floating():
        if propagate_nan == tl.PropagateNan.ALL:
            emit = builder.create_maximumf
        elif propagate_nan == tl.PropagateNan.NONE:
            emit = builder.create_maxnumf
        else:
            raise ValueError(f"Unexpected propagate_nan {propagate_nan}")
    elif dtype.is_int_signed():
        emit = builder.create_maxsi
    elif dtype.is_int_unsigned():
        emit = builder.create_maxui
    else:
        raise TypeError(f"Unexpected dtype {dtype}")
    return self.tensor(emit(x.handle, y.handle), x.type)
390
+
391
def clamp(self, x: TensorTy, min: TensorTy, max: TensorTy, propagate_nan: tl.PropagateNan):
    """Emit ``clampf(x, min, max)``; only floating-point ``x`` is supported.

    Operands are type-checked in pairs, in this order: ``min`` with ``max``, then ``x``
    with ``min``, then ``x`` with ``max``. ``propagate_nan`` is forwarded unchanged to
    ``create_clampf``.

    Raises:
        TypeError: if the promoted ``x`` is not floating point.
    """
    check_pair = self.binary_op_type_checking_impl
    min, max = check_pair(min, max)
    x, min = check_pair(x, min)
    x, max = check_pair(x, max)
    if not x.dtype.is_floating():
        raise TypeError(f"Unexpected dtype {x.dtype}. Only floating point clamp is supported")
    clamped = self.builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan)
    return self.tensor(clamped, x.type)
401
+
402
+ ##############
403
+ # bitwise ops
404
+ ##############
405
+
406
def bitwise_op_type_checking_impl(self, input: TensorTy, other: TensorTy) -> Tuple[TensorTy, TensorTy]:
    """Prepare integer operands for a bitwise or shift operation.

    Runs the generic binary type checking and requires both scalar types to be integers.
    Each side whose type differs from their promoted integer type is cast to it.

    Raises:
        IncompatibleTypeErrorImpl: if either operand is not an integer.
    """
    input, other = self.binary_op_type_checking_impl(input, other)
    lhs_ty, rhs_ty = input.type.scalar, other.type.scalar
    if not (lhs_ty.is_int() and rhs_ty.is_int()):
        raise IncompatibleTypeErrorImpl(lhs_ty, rhs_ty)
    common_ty = self.integer_promote_impl(lhs_ty, rhs_ty)
    if common_ty != lhs_ty:
        input = self.cast(input, common_ty)
    if common_ty != rhs_ty:
        other = self.cast(other, common_ty)
    return input, other
418
+
419
def and_(self, input: TensorTy, other: TensorTy) -> TensorTy:
    """Emit bitwise AND (``builder.create_and``) of two integer operands after promotion."""
    lhs, rhs = self.bitwise_op_type_checking_impl(input, other)
    result = self.builder.create_and(lhs.handle, rhs.handle)
    return self.tensor(result, lhs.type)
422
+
423
def or_(self, input: TensorTy, other: TensorTy) -> TensorTy:
    """Emit bitwise OR (``builder.create_or``) of two integer operands after promotion."""
    lhs, rhs = self.bitwise_op_type_checking_impl(input, other)
    result = self.builder.create_or(lhs.handle, rhs.handle)
    return self.tensor(result, lhs.type)
426
+
427
def xor_(self, input: TensorTy, other: TensorTy) -> TensorTy:
    """Emit bitwise XOR (``builder.create_xor``) of two integer operands after promotion."""
    lhs, rhs = self.bitwise_op_type_checking_impl(input, other)
    result = self.builder.create_xor(lhs.handle, rhs.handle)
    return self.tensor(result, lhs.type)
430
+
431
def logical_and(self, input: TensorTy, other: TensorTy) -> TensorTy:
    """Emit logical AND as ``and_`` on int1 operands.

    Each operand whose type is not int1 is first bitcast to ``tl.int1``.
    NOTE(review): ``bitcast`` presumably requires matching bit widths, so wider integer
    inputs may be rejected there rather than converted — confirm.
    """
    lhs = input if input.type.is_int1() else self.bitcast(input, tl.int1)
    rhs = other if other.type.is_int1() else self.bitcast(other, tl.int1)
    return self.and_(lhs, rhs)
437
+
438
+ def logical_or(self, input: TensorTy, other: TensorTy) -> TensorTy:
439
+ if not input.type.is_int1():
440
+ input = self.bitcast(input, tl.int1)
441
+ if not other.type.is_int1():
442
+ other = self.bitcast(other, tl.int1)
443
+ return self.or_(input, other)
444
+
445
+ def not_(self, input: TensorTy):
446
+ if not input.type.is_int1():
447
+ input = self.bitcast(input, tl.int1)
448
+ return self.invert(input)
449
+
450
+ def lshr(self, input: TensorTy, other: TensorTy) -> TensorTy:
451
+ input, other = self.bitwise_op_type_checking_impl(input, other)
452
+ return self.tensor(self.builder.create_lshr(input.handle, other.handle), input.type)
453
+
454
+ def ashr(self, input: TensorTy, other: TensorTy) -> TensorTy:
455
+ input, other = self.bitwise_op_type_checking_impl(input, other)
456
+ return self.tensor(self.builder.create_ashr(input.handle, other.handle), input.type)
457
+
458
+ def shl(self, input: TensorTy, other: TensorTy) -> TensorTy:
459
+ input, other = self.bitwise_op_type_checking_impl(input, other)
460
+ return self.tensor(self.builder.create_shl(input.handle, other.handle), input.type)
461
+
462
+ # ===----------------------------------------------------------------------===//
463
+ # Unary Operators
464
+ # ===----------------------------------------------------------------------===//
465
+
466
+ def plus(self, input: TensorTy) -> TensorTy:
467
+ return input
468
+
469
+ def minus(self, input: TensorTy) -> TensorTy:
470
+ input_sca_ty = input.type.scalar
471
+ if input_sca_ty.is_ptr():
472
+ raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
473
+ _0 = self.tensor(self.builder.get_null_value(input_sca_ty.to_ir(self.builder)), input_sca_ty)
474
+ return self.sub(_0, input, True)
475
+
476
+ def invert(self, input: TensorTy) -> TensorTy:
477
+ input_sca_ty = input.type.scalar
478
+ if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
479
+ raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
480
+ _1 = self.tensor(self.builder.get_all_ones_value(input_sca_ty.to_ir(self.builder)), input_sca_ty)
481
+ return self.xor_(input, _1)
482
+
483
+ # ===----------------------------------------------------------------------===//
484
+ # Comparison Operators
485
+ # ===----------------------------------------------------------------------===//
486
+
487
+ def _bool_like(self, v: TensorTy) -> tl.block_type:
488
+ return v.type.with_element_ty(tl.int1)
489
+
490
+ def greater_than(self, input: TensorTy, other: TensorTy) -> TensorTy:
491
+ input, other = self.binary_op_type_checking_impl(input, other)
492
+ scalar_ty = input.type.scalar
493
+ # float > float
494
+ if scalar_ty.is_floating():
495
+ return self.tensor(self.builder.create_fcmpOGT(input.handle, other.handle), self._bool_like(input))
496
+ # > int
497
+ elif scalar_ty.is_int():
498
+ if scalar_ty.is_int_signed():
499
+ return self.tensor(self.builder.create_icmpSGT(input.handle, other.handle), self._bool_like(input))
500
+ else:
501
+ return self.tensor(self.builder.create_icmpUGT(input.handle, other.handle), self._bool_like(input))
502
+ raise TypeError(f"unexpected type {scalar_ty}")
503
+
504
+ def greater_equal(self, input: TensorTy, other: TensorTy) -> TensorTy:
505
+ input, other = self.binary_op_type_checking_impl(input, other)
506
+ scalar_ty = input.type.scalar
507
+ # float >= float
508
+ if scalar_ty.is_floating():
509
+ return self.tensor(self.builder.create_fcmpOGE(input.handle, other.handle), self._bool_like(input))
510
+ # >= int
511
+ elif scalar_ty.is_int():
512
+ if scalar_ty.is_int_signed():
513
+ return self.tensor(self.builder.create_icmpSGE(input.handle, other.handle), self._bool_like(input))
514
+ else:
515
+ return self.tensor(self.builder.create_icmpUGE(input.handle, other.handle), self._bool_like(input))
516
+ raise TypeError(f"unexpected type {scalar_ty}")
517
+
518
+ def less_than(self, input: TensorTy, other: TensorTy) -> TensorTy:
519
+ input, other = self.binary_op_type_checking_impl(input, other)
520
+ scalar_ty = input.type.scalar
521
+ # float < float
522
+ if scalar_ty.is_floating():
523
+ return self.tensor(self.builder.create_fcmpOLT(input.handle, other.handle), self._bool_like(input))
524
+ # < int
525
+ elif scalar_ty.is_int():
526
+ if scalar_ty.is_int_signed():
527
+ return self.tensor(self.builder.create_icmpSLT(input.handle, other.handle), self._bool_like(input))
528
+ else:
529
+ return self.tensor(self.builder.create_icmpULT(input.handle, other.handle), self._bool_like(input))
530
+ raise TypeError(f"unexpected type {scalar_ty}")
531
+
532
+ def less_equal(self, input: TensorTy, other: TensorTy) -> TensorTy:
533
+ input, other = self.binary_op_type_checking_impl(input, other)
534
+ scalar_ty = input.type.scalar
535
+ # float < float
536
+ if scalar_ty.is_floating():
537
+ return self.tensor(self.builder.create_fcmpOLE(input.handle, other.handle), self._bool_like(input))
538
+ # < int
539
+ elif scalar_ty.is_int():
540
+ if scalar_ty.is_int_signed():
541
+ return self.tensor(self.builder.create_icmpSLE(input.handle, other.handle), self._bool_like(input))
542
+ else:
543
+ return self.tensor(self.builder.create_icmpULE(input.handle, other.handle), self._bool_like(input))
544
+ raise TypeError(f"unexpected type {scalar_ty}")
545
+
546
+ def equal(self, input: TensorTy, other: TensorTy) -> TensorTy:
547
+ input, other = self.binary_op_type_checking_impl(input, other)
548
+ scalar_ty = input.type.scalar
549
+ # float == float
550
+ if scalar_ty.is_floating():
551
+ return self.tensor(self.builder.create_fcmpOEQ(input.handle, other.handle), self._bool_like(input))
552
+ # == int
553
+ elif scalar_ty.is_int():
554
+ return self.tensor(self.builder.create_icmpEQ(input.handle, other.handle), self._bool_like(input))
555
+ raise TypeError(f"unexpected type {scalar_ty}")
556
+
557
+ def not_equal(self, input: TensorTy, other: TensorTy) -> TensorTy:
558
+ input, other = self.binary_op_type_checking_impl(input, other)
559
+ scalar_ty = input.type.scalar
560
+ # float == float
561
+ if scalar_ty.is_floating():
562
+ return self.tensor(self.builder.create_fcmpUNE(input.handle, other.handle), self._bool_like(input))
563
+ # == int
564
+ elif scalar_ty.is_int():
565
+ return self.tensor(self.builder.create_icmpNE(input.handle, other.handle), self._bool_like(input))
566
+ raise TypeError(f"unexpected type {scalar_ty}")
567
+
568
+ # ===----------------------------------------------------------------------===//
569
+ # Block Creation
570
+ # ===----------------------------------------------------------------------===//
571
+
572
+ def arange(self, start: int, end: int, *, ret_ty: tl.block_type = None) -> TensorTy:
573
+ if not isinstance(start, int) or not isinstance(end, int):
574
+ raise ValueError("arange's arguments must be of type tl.constexpr")
575
+ is_start_int64 = bool(start >> 32)
576
+ is_end_int64 = bool(end >> 32)
577
+ if is_start_int64 or is_end_int64:
578
+ raise ValueError("arange must fit in int32")
579
+ if end <= start:
580
+ raise ValueError("arange's end argument must be greater than the start argument")
581
+ range = end - start
582
+ if (range & (range - 1)) != 0:
583
+ raise ValueError("arange's range must be a power of 2")
584
+ shape = [range]
585
+ if ret_ty is None:
586
+ ret_ty = tl.block_type(tl.int32, shape)
587
+ ret_ty_ir = ret_ty.to_ir(self.builder)
588
+ return self.tensor(self.builder.create_make_range(ret_ty_ir, start, end), ret_ty)
589
+
590
+ def scalar_constant(self, value, dtype: tl.dtype) -> TensorTy:
591
+ # scalar
592
+ if dtype is None:
593
+ raise ValueError("dtype must be specified when value is not a tensor")
594
+ if value == 0:
595
+ value = self.builder.get_null_value(dtype.to_ir(self.builder))
596
+ else:
597
+ get_value_fn = getattr(self.builder, f"get_{dtype.name}")
598
+ value = get_value_fn(value)
599
+ return self.tensor(value, dtype)
600
+
601
+ def make_scalar(self, value, dtype: tl.dtype) -> TensorTy:
602
+ if isinstance(value, tl.tensor):
603
+ assert value.numel.value == 1, "only accepts size-1 tensor"
604
+ return self.cast(value, dtype)
605
+ # scalar
606
+ return self.scalar_constant(value, dtype)
607
+
608
+ def full(self, shape: List[int], value, dtype: tl.dtype) -> TensorTy:
609
+ return self.splat(self.make_scalar(value, dtype), shape)
610
+
611
+ # ===----------------------------------------------------------------------===//
612
+ # Shape Manipulation
613
+ # ===----------------------------------------------------------------------===//
614
+
615
+ def splat(self, value: TensorTy, shape: List[int]) -> TensorTy:
616
+ assert not value.type.is_block(), "Cannot splat a block tensor"
617
+ if len(shape) == 0:
618
+ return value
619
+ ret_ty = tl.block_type(value.dtype, shape)
620
+ return self.tensor(self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle), ret_ty)
621
+
622
+ def unsplat(self, value: TensorTy) -> TensorTy:
623
+ return self.tensor(self.builder.create_unsplat(value.handle), value.dtype)
624
+
625
+ def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool) -> TensorTy:
626
+ numel = 1
627
+ for s in dst_shape:
628
+ numel *= s
629
+ if input.type.numel != numel:
630
+ raise ValueError("reshape() cannot change total number of elements in tensor")
631
+ ret_ty = tl.block_type(input.type.scalar, dst_shape)
632
+ return self.tensor(self.builder.create_reshape(input.handle, dst_shape, can_reorder), ret_ty)
633
+
634
+ def expand_dims(self, input: TensorTy, axis: int) -> TensorTy:
635
+ dst_shape = [tl._unwrap_if_constexpr(x) for x in input.shape]
636
+ dst_shape.insert(axis, 1)
637
+
638
+ if not input.type.is_block():
639
+ return self.splat(input, shape=dst_shape)
640
+
641
+ ret_ty = tl.block_type(input.type.scalar, dst_shape)
642
+ return self.tensor(self.builder.create_expand_dims(input.handle, axis), ret_ty)
643
+
644
+ def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool) -> TensorTy:
645
+ assert can_reorder, "current implementation of `cat` always may reorder elements"
646
+ assert len(lhs.shape) == 1
647
+ ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]])
648
+ return self.tensor(self.builder.create_cat(lhs.handle, rhs.handle), ret_type)
649
+
650
+ def join(self, a: TensorTy, b: TensorTy) -> TensorTy:
651
+ a, b = self.broadcast_impl_value(a, b)
652
+
653
+ # The IR can't handle joining two scalars, so upcast them to 1D tensors,
654
+ # then downcast the result.
655
+ was_rank_1 = a.shape == []
656
+ if was_rank_1:
657
+ a = self.expand_dims(a, 0)
658
+ b = self.expand_dims(b, 0)
659
+
660
+ if isinstance(a.shape[-1], tl.constexpr):
661
+ two = tl.constexpr(2)
662
+ else:
663
+ two = 2
664
+ new_shape = a.shape + [two]
665
+
666
+ ret_type = tl.block_type(a.type.scalar, new_shape)
667
+ ret = self.tensor(self.builder.create_join(a.handle, b.handle), ret_type)
668
+
669
+ if was_rank_1:
670
+ ret = self.reshape(ret, [2], can_reorder=False)
671
+
672
+ return ret
673
+
674
+ def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]:
675
+ assert (len(a.shape) > 0)
676
+ assert (tl._unwrap_if_constexpr(a.shape[-1]) == 2)
677
+
678
+ new_shape = a.shape[:-1]
679
+ ret_type = tl.block_type(a.type.scalar, new_shape)
680
+ outLHS, outRHS = self.builder.create_split(a.handle)
681
+ return (
682
+ self.tensor(outLHS, ret_type),
683
+ self.tensor(outRHS, ret_type),
684
+ )
685
+
686
+ def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy:
687
+ if len(input.shape) != len(dims):
688
+ raise ValueError("permute dims must have the same length as input shape")
689
+ if sorted(tl._unwrap_if_constexpr(d) for d in dims) != list(range(len(dims))):
690
+ raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}")
691
+
692
+ ret_type = tl.block_type(input.type.scalar, [input.shape[d] for d in dims])
693
+ return self.tensor(self.builder.create_trans(input.handle, dims), ret_type)
694
+
695
+ def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy:
696
+ if not input.type.is_block():
697
+ return self.splat(input, shape)
698
+ src_shape = input.type.get_block_shapes()
699
+ if len(src_shape) != len(shape):
700
+ raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
701
+ if shape == src_shape:
702
+ return input
703
+ for i, item in enumerate(src_shape):
704
+ if shape[i] != item and item != 1:
705
+ raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
706
+ f" must match the existing size ({item}) at non-singleton dimension"
707
+ f" {i}: {src_shape}, {shape}")
708
+ ret_ty = tl.block_type(input.type.scalar, shape)
709
+ return self.tensor(self.builder.create_broadcast(input.handle, shape), ret_ty)
710
+
711
+ def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy:
712
+ lhs_ty = lhs.type
713
+ rhs_ty = rhs.type
714
+
715
+ # make_shape_compatible(block, scalar)
716
+ if lhs_ty.is_block() and not rhs_ty.is_block():
717
+ rhs_ty = lhs_ty.with_element_ty(rhs_ty.scalar)
718
+ rhs = self.tensor(self.builder.create_splat(rhs_ty.to_ir(self.builder), rhs.handle), rhs_ty)
719
+ # make_shape_compatible(scalar, block)
720
+ elif not lhs_ty.is_block() and rhs_ty.is_block():
721
+ lhs_ty = rhs_ty.with_element_ty(lhs_ty.scalar)
722
+ lhs = self.tensor(self.builder.create_splat(lhs_ty.to_ir(self.builder), lhs.handle), lhs_ty)
723
+ # make_shape_compatible(block, block)
724
+ elif lhs_ty.is_block() and rhs_ty.is_block():
725
+ lhs_shape = lhs_ty.get_block_shapes()
726
+ rhs_shape = rhs_ty.get_block_shapes()
727
+
728
+ if len(lhs_shape) < len(rhs_shape):
729
+ # Add new axes to lhs
730
+ for _ in range(len(lhs_shape), len(rhs_shape)):
731
+ lhs = self.tensor(self.builder.create_expand_dims(lhs.handle, 0),
732
+ tl.block_type(lhs_ty.scalar, [1] + lhs_shape.values))
733
+ lhs_ty = lhs.type
734
+ lhs_shape = lhs_ty.get_block_shapes()
735
+ elif len(rhs_shape) < len(lhs_shape):
736
+ # Add new axes to rhs
737
+ for _ in range(len(rhs_shape), len(lhs_shape)):
738
+ rhs = self.tensor(self.builder.create_expand_dims(rhs.handle, 0),
739
+ tl.block_type(rhs_ty.scalar, [1] + rhs_shape.values))
740
+ rhs_ty = rhs.type
741
+ rhs_shape = rhs_ty.get_block_shapes()
742
+ assert len(rhs_shape) == len(lhs_shape)
743
+
744
+ ret_shape = []
745
+ for i, left in enumerate(lhs_shape):
746
+ right = rhs_shape[i]
747
+ if left == 1:
748
+ ret_shape.append(right)
749
+ elif (right == 1) or (right == left):
750
+ ret_shape.append(left)
751
+ else:
752
+ raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
753
+ "at index " + str(i) + ": " + str(left) + " and " + str(right))
754
+ if lhs_shape != ret_shape:
755
+ ret_ty = tl.block_type(lhs_ty.scalar, ret_shape)
756
+ lhs = self.tensor(self.builder.create_broadcast(lhs.handle, ret_shape), ret_ty)
757
+ if rhs_shape != ret_shape:
758
+ ret_ty = tl.block_type(rhs_ty.scalar, ret_shape)
759
+ rhs = self.tensor(self.builder.create_broadcast(rhs.handle, ret_shape), ret_ty)
760
+ # (scalar, scalar) => returns original blocks
761
+ return lhs, rhs
762
+
763
+ #######
764
+ # cast
765
+ #######
766
+
767
+ def _str_to_rounding_mode(self, rounding_mode: Optional[str]):
768
+ if rounding_mode is None:
769
+ return None
770
+ if rounding_mode == 'rtne':
771
+ return ir.ROUNDING_MODE.RTNE
772
+ if rounding_mode == 'rtz':
773
+ return ir.ROUNDING_MODE.RTZ
774
+ raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz'.")
775
+
776
+ def bitcast(self, input: TensorTy, dst_ty: tl.dtype) -> TensorTy:
777
+ src_ty = input.type
778
+ if src_ty.is_block():
779
+ dst_ty = src_ty.with_element_ty(dst_ty.scalar)
780
+ if src_ty == dst_ty:
781
+ return input
782
+ src_sca_ty = src_ty.scalar
783
+ dst_sca_ty = dst_ty.scalar
784
+ if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr():
785
+ return self.cast(input, dst_ty)
786
+ # Bitcast
787
+ src_bits = src_sca_ty.primitive_bitwidth
788
+ dst_bits = dst_sca_ty.primitive_bitwidth
789
+ if src_bits != dst_bits:
790
+ raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to "
791
+ "data-type of size " + str(dst_bits))
792
+ return self.tensor(self.builder.create_bitcast(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
793
+
794
+ def cast(self, input: TensorTy, dst_ty: tl.dtype, fp_downcast_rounding: Optional[str] = None) -> TensorTy:
795
+ src_ty = input.type
796
+ src_sca_ty = src_ty.scalar
797
+ dst_sca_ty = dst_ty.scalar
798
+ if src_sca_ty == dst_sca_ty:
799
+ return input
800
+ if src_ty.is_block():
801
+ dst_ty = src_ty.with_element_ty(dst_sca_ty)
802
+
803
+ # For fp downcasting default rounding mode should be RTNE, for all other conversions it should
804
+ # not be set
805
+ fp_downcast_rounding = self._str_to_rounding_mode(fp_downcast_rounding)
806
+ use_custom_rounding = False
807
+ if dst_sca_ty.is_floating() and src_sca_ty.is_floating(
808
+ ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth:
809
+ if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE
810
+ elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True
811
+ else:
812
+ if fp_downcast_rounding is not None:
813
+ raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. "
814
+ "Source scalar type is " + str(src_sca_ty) + " and destination type is " +
815
+ str(dst_sca_ty))
816
+
817
+ if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()):
818
+ assert self.builder.codegen_fns.get(
819
+ "convert_custom_types") is not None, "target doesn't provide conversion for this type."
820
+ return self.builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _semantic=self)
821
+ # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
822
+ # and non-default rounding modes for downcasting
823
+ if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \
824
+ (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \
825
+ use_custom_rounding:
826
+ return self.tensor(
827
+ self.builder.create_fp_to_fp(input.handle, dst_ty.to_ir(self.builder), fp_downcast_rounding), dst_ty)
828
+
829
+ # bf16 <=> (not fp32)
830
+ if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \
831
+ (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()):
832
+ return self.cast(self.cast(input, tl.float32), dst_sca_ty)
833
+
834
+ # Standard floating types' casting: truncation
835
+ # fp64 => fp32, fp16, bf16
836
+ # fp32 => fp16, bf16
837
+ truncate_fp = src_sca_ty.is_floating() and \
838
+ dst_sca_ty.is_floating() and \
839
+ src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
840
+ if truncate_fp:
841
+ return self.tensor(self.builder.create_fp_trunc(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
842
+
843
+ # Standard floating types' casting: extension
844
+ # fp32 => fp64
845
+ # fp16 => fp32, fp64
846
+ # bf16 => fp32, fp64
847
+ ext_fp = src_sca_ty.is_floating() and \
848
+ dst_sca_ty.is_floating() and \
849
+ src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
850
+ if ext_fp:
851
+ return self.tensor(self.builder.create_fp_ext(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
852
+
853
+ # Casting between integer types
854
+ if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
855
+ (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
856
+ sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
857
+ if dst_sca_ty.is_bool():
858
+ ty = input.dtype.to_ir(self.builder)
859
+ _0 = self.tensor(self.builder.get_null_value(ty), input.dtype)
860
+ return self.not_equal(input, _0)
861
+ else:
862
+ return self.tensor(self.builder.create_int_cast(input.handle, dst_ty.to_ir(self.builder), sign_extend),
863
+ dst_ty)
864
+
865
+ # Casting standard floating types to integer types
866
+ if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int():
867
+ if dst_sca_ty.is_bool():
868
+ ty = input.dtype.to_ir(self.builder)
869
+ _0 = self.tensor(self.builder.get_null_value(ty), input.dtype)
870
+ return self.not_equal(input, _0)
871
+ elif dst_sca_ty.is_int_signed():
872
+ return self.tensor(self.builder.create_fp_to_si(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
873
+ else:
874
+ return self.tensor(self.builder.create_fp_to_ui(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
875
+
876
+ # Casting integer types to standard floating types
877
+ if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating():
878
+ if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
879
+ return self.tensor(self.builder.create_ui_to_fp(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
880
+ else:
881
+ return self.tensor(self.builder.create_si_to_fp(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
882
+
883
+ # Casting pointer types to integer types
884
+ if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
885
+ bitwidth = dst_sca_ty.int_bitwidth
886
+ if bitwidth == 64:
887
+ return self.tensor(self.builder.create_ptr_to_int(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
888
+ if bitwidth == 1:
889
+ return self.not_equal(self.cast(input, tl.int64), self.tensor(self.builder.get_int64(0), tl.int64))
890
+
891
+ # Casting integer types to pointer types
892
+ if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
893
+ return self.tensor(self.builder.create_int_to_ptr(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
894
+
895
+ # Casting pointer types to pointer types
896
+ if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
897
+ return self.tensor(self.builder.create_bitcast(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
898
+
899
+ assert False, f'cannot cast {input} to {dst_ty}'
900
+
901
+ # ===----------------------------------------------------------------------===//
902
+ # Memory Operators
903
+ # ===----------------------------------------------------------------------===//
904
+
905
+ def _str_to_load_cache_modifier(self, cache_modifier):
906
+ cache = ir.CACHE_MODIFIER.NONE # default
907
+ if cache_modifier:
908
+ if cache_modifier == ".ca":
909
+ cache = ir.CACHE_MODIFIER.CA
910
+ elif cache_modifier == ".cg":
911
+ cache = ir.CACHE_MODIFIER.CG
912
+ elif cache_modifier == ".cv":
913
+ cache = ir.CACHE_MODIFIER.CV
914
+ else:
915
+ raise ValueError(f"Cache modifier {cache_modifier} not supported")
916
+ return cache
917
+
918
+ def _str_to_store_cache_modifier(self, cache_modifier):
919
+ cache = ir.CACHE_MODIFIER.NONE # default
920
+ if cache_modifier:
921
+ if cache_modifier == ".wb":
922
+ cache = ir.CACHE_MODIFIER.WB
923
+ elif cache_modifier == ".cg":
924
+ cache = ir.CACHE_MODIFIER.CG
925
+ elif cache_modifier == ".cs":
926
+ cache = ir.CACHE_MODIFIER.CS
927
+ elif cache_modifier == ".wt":
928
+ cache = ir.CACHE_MODIFIER.WT
929
+ else:
930
+ raise ValueError(f"Cache modifier {cache_modifier} not supported")
931
+ return cache
932
+
933
+ def _str_to_eviction_policy(self, eviction_policy):
934
+ eviction = ir.EVICTION_POLICY.NORMAL # default
935
+ if eviction_policy:
936
+ if eviction_policy == "evict_last":
937
+ eviction = ir.EVICTION_POLICY.EVICT_LAST
938
+ elif eviction_policy == "evict_first":
939
+ eviction = ir.EVICTION_POLICY.EVICT_FIRST
940
+ else:
941
+ raise ValueError(f"Eviction policy {eviction_policy} not supported")
942
+ return eviction
943
+
944
+ def _str_to_padding_option(self, padding_option):
945
+ padding = None # default
946
+ if padding_option:
947
+ if padding_option == "zero":
948
+ padding = ir.PADDING_OPTION.PAD_ZERO
949
+ elif padding_option == "nan":
950
+ padding = ir.PADDING_OPTION.PAD_NAN
951
+ else:
952
+ raise ValueError(f"Padding option {padding_option} not supported")
953
+ return padding
954
+
955
+ def _str_to_sem(self, sem_option):
956
+ sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE
957
+ if sem_option:
958
+ if sem_option == "acquire":
959
+ sem = ir.MEM_SEMANTIC.ACQUIRE
960
+ elif sem_option == "release":
961
+ sem = ir.MEM_SEMANTIC.RELEASE
962
+ elif sem_option == "acq_rel":
963
+ sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE
964
+ elif sem_option == "relaxed":
965
+ sem = ir.MEM_SEMANTIC.RELAXED
966
+ else:
967
+ raise ValueError(f"Memory semantic {sem_option} not supported")
968
+ return sem
969
+
970
+ def _str_to_scope(self, scope_option):
971
+ scope = ir.MEM_SYNC_SCOPE.GPU
972
+ if scope_option:
973
+ if scope_option == "gpu":
974
+ scope = ir.MEM_SYNC_SCOPE.GPU
975
+ elif scope_option == "cta":
976
+ scope = ir.MEM_SYNC_SCOPE.CTA
977
+ elif scope_option == "sys":
978
+ scope = ir.MEM_SYNC_SCOPE.SYSTEM
979
+ else:
980
+ raise ValueError(f"Memory semantic {scope_option} not supported")
981
+ return scope
982
+
983
+ def _canonicalize_boundary_check(self, boundary_check, block_shape):
984
+ if boundary_check:
985
+ if not hasattr(boundary_check, "__iter__"):
986
+ boundary_check = [boundary_check]
987
+ boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check]
988
+ for dim in boundary_check:
989
+ assert isinstance(dim, int) and 0 <= dim < len(block_shape)
990
+ assert len(boundary_check) > 0
991
+ assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`"
992
+ return sorted(boundary_check)
993
+ return ()
994
+
995
+ def _load_block_pointer(self, ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile):
996
+ # Load by a block pointer: `pointer_type<block_type<>>`
997
+ # Block pointer can not have `mask` and `other` arguments
998
+ if mask is not None or other is not None:
999
+ raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")
1000
+
1001
+ elt_ty = ptr.type.element_ty.element_ty
1002
+ assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
1003
+ if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN:
1004
+ raise ValueError("Padding option `nan` is not supported for integer block pointers")
1005
+
1006
+ # `dst_ty` is de-referenced type of the pointer type
1007
+ dst_ty = ptr.type.element_ty
1008
+
1009
+ # Check `boundary_check` argument
1010
+ boundary_check = self._canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes())
1011
+
1012
+ # Build IR
1013
+ return self.tensor(
1014
+ self.builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile),
1015
+ dst_ty)
1016
+
1017
+ def _load_legacy(self, ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile):
1018
+ # Load by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
1019
+ if not ptr.type.scalar.is_ptr():
1020
+ raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`")
1021
+
1022
+ # Check `mask`, `other`, `boundary_check`, and `padding` arguments
1023
+ if mask is None and other is not None:
1024
+ raise ValueError("`other` cannot be provided without `mask`")
1025
+ if padding or boundary_check:
1026
+ raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of"
1027
+ "pointers or loading a scalar. Because the compiler does not know the boundary; please "
1028
+ "use block pointers (defined by `make_block_ptr`) instead")
1029
+
1030
+ # For a pointer of scalar, check the type of `mask` and `other`
1031
+ if not ptr.type.is_block():
1032
+ if mask and mask.type.is_block():
1033
+ raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
1034
+ if other and other.type.is_block():
1035
+ raise ValueError("Other argument cannot be block type if pointer argument is not a block")
1036
+
1037
+ # Make `mask` and `other` into the same shape as `ptr`
1038
+ if ptr.type.is_block():
1039
+ if mask is not None:
1040
+ ptr, mask = self.broadcast_impl_value(ptr, mask)
1041
+ if other is not None:
1042
+ ptr, other = self.broadcast_impl_value(ptr, other)
1043
+
1044
+ # Get `pointer_type<elt_ty>` and `elt_ty`
1045
+ ptr_ty = ptr.type.scalar
1046
+ elt_ty = ptr_ty.element_ty
1047
+
1048
+ # Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
1049
+ is_bool = elt_ty == tl.int1
1050
+ if is_bool:
1051
+ elt_ty = tl.int8
1052
+ ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
1053
+ ptr = self.cast(ptr, ptr_ty)
1054
+
1055
+ # Cast `other` into `elt_ty` type
1056
+ if other is not None:
1057
+ other = self.cast(other, elt_ty)
1058
+
1059
+ # Create loaded result type `dst_ty`
1060
+ if ptr.type.is_block():
1061
+ dst_ty = ptr.type.with_element_ty(elt_ty)
1062
+ else:
1063
+ # Load by de-referencing the pointer of scalar
1064
+ dst_ty = elt_ty
1065
+
1066
+ # Build IR
1067
+ if mask is None:
1068
+ ret = self.tensor(self.builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty)
1069
+ else:
1070
+ ret = self.tensor(
1071
+ self.builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache,
1072
+ eviction, is_volatile), dst_ty)
1073
+ if is_bool:
1074
+ ret = self.cast(ret, tl.int1)
1075
+ return ret
1076
+
1077
+ def load(self, ptr: TensorTy, mask: Optional[TensorTy], other: Optional[TensorTy], boundary_check: Tuple,
1078
+ padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool) -> TensorTy:
1079
+ # Cache, eviction and padding options
1080
+ cache = self._str_to_load_cache_modifier(cache_modifier)
1081
+ eviction = self._str_to_eviction_policy(eviction_policy)
1082
+ padding = self._str_to_padding_option(padding_option)
1083
+
1084
+ if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
1085
+ # Load by a block pointer: `pointer_type<block_type<>>`
1086
+ return self._load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile)
1087
+ else:
1088
+ # Load by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
1089
+ return self._load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile)
1090
+
1091
+ def descriptor_load(self, desc: tl.tensor_descriptor_base, offsets, cache_modifier: str,
1092
+ eviction_policy: str) -> TensorTy:
1093
+ assert isinstance(desc, tl.tensor_descriptor_base)
1094
+ ndim = len(desc.block_shape)
1095
+ assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"
1096
+
1097
+ offsets = self._convert_to_ir_values(offsets, require_i64=False)
1098
+ x = self.builder.create_descriptor_load(desc.handle, offsets, self._str_to_load_cache_modifier(cache_modifier),
1099
+ self._str_to_eviction_policy(eviction_policy))
1100
+ return self.tensor(x, desc.block_type)
1101
+
1102
+ def validate_store_like(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> None:
1103
+ assert isinstance(desc, tl.tensor_descriptor_base)
1104
+ ndim = len(desc.block_shape)
1105
+ assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"
1106
+ assert value.shape == desc.block_shape
1107
+
1108
+ def descriptor_store(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
1109
+ self.validate_store_like(desc, value, offsets)
1110
+ # implicitly cast to the descriptor's type
1111
+ value = self.cast(value, desc.dtype)
1112
+ offsets = self._convert_to_ir_values(offsets, require_i64=False)
1113
+ return self.tensor(self.builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void)
1114
+
1115
+ def descriptor_atomic_add(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
1116
+ self.validate_store_like(desc, value, offsets)
1117
+ assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.float32, tl.float16, tl.bfloat16}, "Unsupported dtype"
1118
+ offsets = self._convert_to_ir_values(offsets, require_i64=False)
1119
+ kind = ir.DESCRIPTOR_REDUCE_KIND.ADD
1120
+ return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void)
1121
+
1122
+ def _has_native_tma(self, ):
1123
+ target = driver.active.get_current_target()
1124
+ return (target.backend == "cuda" and target.arch >= 90)
1125
+
1126
def _descriptor_atomic_min_max_supported(self, dtype):
    """Assert that ``dtype`` may be used with descriptor atomic min/max.

    uint32/int32/uint64/int64 are always accepted; float16/bfloat16 are
    accepted only when ``_has_native_tma()`` is true.
    """
    allowed = {tl.uint32, tl.int32, tl.uint64, tl.int64, tl.float16, tl.bfloat16}
    assert dtype in allowed, "Unsupported dtype"
    is_half_float = dtype in {tl.float16, tl.bfloat16}
    if is_half_float:
        assert self._has_native_tma(), "16-bit float types require native tma support"
1130
+
1131
def descriptor_atomic_min(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
    """Emit an atomic MIN reduction of ``value`` into the descriptor block at ``offsets``.

    Dtype support is checked by ``_descriptor_atomic_min_max_supported``.
    Returns a ``void`` tensor wrapping the reduce op.
    """
    self.validate_store_like(desc, value, offsets)
    self._descriptor_atomic_min_max_supported(desc.dtype)
    ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
    reduce_op = self.builder.create_descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.MIN, desc.handle, value.handle,
                                                      ir_offsets)
    return self.tensor(reduce_op, tl.void)
1137
+
1138
def descriptor_atomic_max(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
    """Emit an atomic MAX reduction of ``value`` into the descriptor block at ``offsets``.

    Dtype support is checked by ``_descriptor_atomic_min_max_supported``.
    Returns a ``void`` tensor wrapping the reduce op.
    """
    self.validate_store_like(desc, value, offsets)
    self._descriptor_atomic_min_max_supported(desc.dtype)
    ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
    reduce_op = self.builder.create_descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.MAX, desc.handle, value.handle,
                                                      ir_offsets)
    return self.tensor(reduce_op, tl.void)
1144
+
1145
def descriptor_atomic_and(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
    """Emit an atomic bitwise-AND reduction of ``value`` into the descriptor block.

    Only 32/64-bit integer element types are accepted. Returns a ``void``
    tensor wrapping the reduce op.
    """
    self.validate_store_like(desc, value, offsets)
    assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype"
    ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
    reduce_op = self.builder.create_descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.AND, desc.handle, value.handle,
                                                      ir_offsets)
    return self.tensor(reduce_op, tl.void)
1151
+
1152
def descriptor_atomic_or(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
    """Emit an atomic bitwise-OR reduction of ``value`` into the descriptor block.

    Only 32/64-bit integer element types are accepted. Returns a ``void``
    tensor wrapping the reduce op.
    """
    self.validate_store_like(desc, value, offsets)
    assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype"
    ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
    reduce_op = self.builder.create_descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.OR, desc.handle, value.handle,
                                                      ir_offsets)
    return self.tensor(reduce_op, tl.void)
1158
+
1159
def descriptor_atomic_xor(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
    """Emit an atomic bitwise-XOR reduction of ``value`` into the descriptor block.

    Only 32/64-bit integer element types are accepted. Returns a ``void``
    tensor wrapping the reduce op.
    """
    self.validate_store_like(desc, value, offsets)
    assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype"
    ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
    reduce_op = self.builder.create_descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.XOR, desc.handle, value.handle,
                                                      ir_offsets)
    return self.tensor(reduce_op, tl.void)
1165
+
1166
def descriptor_gather(self, desc, x_offsets, y_offset, cache_modifier: str, eviction_policy: str) -> TensorTy:
    """Gather rows from a 2D tensor descriptor whose block is a single row.

    ``x_offsets`` is a 1D tensor of row offsets (at least 8 entries) and
    ``y_offset`` a single column offset. The descriptor block must be
    ``[1, C]`` with ``C >= 32 // bitwidth * 8``. Cache modifiers and eviction
    policies are not supported yet and must be empty strings.

    Returns:
        A ``[len(x_offsets), C]`` block of the descriptor's element type.
    """
    assert isinstance(desc, tl.tensor_descriptor_base)
    assert cache_modifier == "", "cache modifier is not supported yet"
    assert eviction_policy == "", "eviction policy is not supported yet"

    # Descriptor: 2D with a single-row block.
    assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}"
    assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}"

    # Row offsets: a 1D tensor.
    assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shape}"

    # Minimum tile size in rows and (dtype-dependent) columns.
    assert x_offsets.shape[0] >= 8, f"descriptor gather must have at least 8 rows, but got {x_offsets.shape}"
    dtype = desc.dtype
    min_cols = 32 // dtype.primitive_bitwidth * 8
    num_cols = desc.block_shape[1]
    assert num_cols >= min_cols, \
        f"descriptor gather of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}"

    result_ty = tl.block_type(desc.dtype, [x_offsets.shape[0], num_cols])
    y_ir = self._convert_to_ir_values((y_offset, ), require_i64=False)[0]
    gather_op = self.builder.create_descriptor_gather(desc.handle, x_offsets.handle, y_ir,
                                                      result_ty.to_ir(self.builder))
    return self.tensor(gather_op, result_ty)
1189
+
1190
def descriptor_scatter(self, desc, value: TensorTy, x_offsets, y_offset) -> TensorTy:
    """Scatter ``value`` into a 2D tensor descriptor whose block is a single row.

    ``x_offsets`` is a 1D tensor of row offsets (at least 8 entries) and
    ``y_offset`` a single column offset. The descriptor block must be
    ``[1, C]`` with ``C >= 32 // bitwidth * 8``.

    Returns:
        A ``void`` tensor (the scatter op produces no value).
    """
    assert isinstance(desc, tl.tensor_descriptor_base)

    # Validate descriptor.
    assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}"
    assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}"

    # Validate offsets. (Previously referenced the non-existent `.shapae`,
    # which raised AttributeError instead of this AssertionError.)
    assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shape}"

    # Validate minimum block size.
    assert x_offsets.shape[0] >= 8, f"descriptor scatter must have at least 8 rows, but got {x_offsets.shape}"
    dtype = desc.dtype
    min_cols = 32 // dtype.primitive_bitwidth * 8
    assert desc.block_shape[
        1] >= min_cols, f"descriptor scatter of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}"

    y_offset = self._convert_to_ir_values((y_offset, ), require_i64=False)[0]
    self.builder.create_descriptor_scatter(desc.handle, value.handle, x_offsets.handle, y_offset)
    return self.tensor(None, tl.void)
1210
+
1211
+ def _store_block_pointer(self, ptr, val, mask, boundary_check, cache, eviction):
1212
+ # Store by a block pointer: `pointer_type<block_type<>>`
1213
+ # Block pointers can not have the `mask` argument
1214
+ if mask is not None:
1215
+ raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")
1216
+
1217
+ # Check same shape and element type
1218
+ block_shape = ptr.type.element_ty.get_block_shapes()
1219
+ if not val.type.is_block():
1220
+ val = self.broadcast_impl_shape(val, block_shape)
1221
+ assert val.type.is_block(), "Value argument must be block type or a scalar"
1222
+ assert block_shape == val.type.get_block_shapes(
1223
+ ), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch"
1224
+ assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch"
1225
+
1226
+ elt_ty = ptr.type.element_ty.element_ty
1227
+ assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
1228
+
1229
+ # Check `boundary_check` argument
1230
+ boundary_check = self._canonicalize_boundary_check(boundary_check, block_shape)
1231
+
1232
+ # Cast to target data type
1233
+ val = self.cast(val, elt_ty)
1234
+
1235
+ # Build IR
1236
+ return self.tensor(
1237
+ self.builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction), tl.void)
1238
+
1239
def _store_legacy(self, ptr, val, mask, boundary_check, cache, eviction):
    """Store through a tensor of pointers or a single scalar pointer.

    Handles ``block_type<pointer_type<>>`` and ``pointer_type<>``. For a block
    of pointers, ``val`` and ``mask`` are broadcast to its shape. ``int1``
    pointees are stored as ``int8``, and ``val`` is cast to the pointee type.

    Raises:
        ValueError: for a non-pointer ``ptr``, a truthy ``boundary_check``,
            block-typed ``val``/``mask`` with a scalar pointer, or a
            non-boolean mask.
    """
    if not ptr.type.scalar.is_ptr():
        raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`")

    if boundary_check:
        raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a "
                         "scalar. Because the compiler does not know the boundary; please use block pointers "
                         "(defined by `make_block_ptr`) instead")

    if ptr.type.is_block():
        # Bring value and mask to the pointer tensor's shape.
        target_shape = ptr.type.get_block_shapes()
        val = self.broadcast_impl_shape(val, target_shape)
        if mask is not None:
            mask = self.broadcast_impl_shape(mask, target_shape)
    else:
        # A scalar pointer only accepts scalar value and mask.
        if val.type.is_block():
            raise ValueError("Value argument cannot be block type if pointer argument is not a block")
        if mask and mask.type.is_block():
            raise ValueError("Mask argument cannot be block type if pointer argument is not a block")

    ptr_ty = ptr.type.scalar
    elt_ty = ptr_ty.element_ty
    if elt_ty == tl.int1:
        # Store booleans through an int8 pointer in the same address space.
        elt_ty = tl.int8
        ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
        ptr = self.cast(ptr, ptr_ty)

    val = self.cast(val, elt_ty)

    if mask is None:
        return self.tensor(self.builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void)
    if not mask.type.scalar.is_bool():
        raise ValueError("Mask must have boolean scalar type")
    masked_store = self.builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction)
    return self.tensor(masked_store, tl.void)
1282
+
1283
def store(self, ptr: TensorTy, val: TensorTy, mask: Optional[TensorTy], boundary_check, cache_modifier: str,
          eviction_policy: str) -> TensorTy:
    """Lower ``tl.store``, dispatching on the kind of pointer.

    Block pointers (``pointer_type<block_type<>>``) go to
    ``_store_block_pointer``. Tensors of pointers and scalar pointers go to
    ``_store_legacy``.

    Raises:
        ValueError: if ``ptr`` or its scalar pointer type is constant.
    """
    cache = self._str_to_store_cache_modifier(cache_modifier)
    eviction = self._str_to_eviction_policy(eviction_policy)

    if ptr.type.is_const() or ptr.type.scalar.is_const():
        raise ValueError("Cannot store to a constant pointer")

    through_block_pointer = ptr.type.is_ptr() and ptr.type.element_ty.is_block()
    lower = self._store_block_pointer if through_block_pointer else self._store_legacy
    return lower(ptr, val, mask, boundary_check, cache, eviction)
1298
+
1299
+ #########
1300
+ # atomic
1301
+ #########
1302
+
1303
def atomic_cas(self, ptr: TensorTy, cmp: TensorTy, val: TensorTy, sem: str, scope: str) -> TensorTy:
    """Emit an atomic compare-and-swap at ``ptr``.

    Only 16-, 32- and 64-bit pointee types are accepted. The result is typed
    like ``val``.

    Raises:
        ValueError: for any other pointee width.
    """
    sem_attr = self._str_to_sem(sem)
    scope_attr = self._str_to_scope(scope)
    width = ptr.type.scalar.element_ty.primitive_bitwidth
    if width not in (16, 32, 64):
        raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
    cas_op = self.builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem_attr, scope_attr)
    return self.tensor(cas_op, val.type)
1310
+
1311
def atom_red_typechecking_impl(self, ptr: TensorTy, val: TensorTy, mask: TensorTy,
                               op: str) -> Tuple[TensorTy, TensorTy, TensorTy]:
    """Validate and normalize the operands of an atomic read-modify-write.

    Checks that ``ptr`` is a non-constant pointer whose element type supports
    ``op``. For a block of pointers, ``mask``/``val`` are broadcast to its
    shape. ``val`` is cast to the pointee type, and an all-true ``int1`` mask
    is materialized when ``mask`` is None.

    Returns:
        The normalized ``(ptr, val, mask)`` triple.

    Raises:
        ValueError: for non-pointer or constant pointers, fp16/bf16 with any
            op other than ``'add'``, and int16/uint16 or sub-16-bit pointees.
    """
    if not ptr.type.scalar.is_ptr():
        raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
    if ptr.type.is_const() or ptr.type.element_ty.is_const():
        raise ValueError("Cannot store to a constant pointer")
    element_ty = ptr.type.scalar.element_ty
    # dtypes compare by value; `is` would only match the exact singleton object.
    if element_ty == tl.float16 and op != 'add':
        raise ValueError("atomic_" + op + " does not support fp16")
    if element_ty == tl.bfloat16 and op != 'add':
        raise ValueError("atomic_" + op + " does not support bf16")
    if element_ty in [tl.int16, tl.uint16] or element_ty.primitive_bitwidth < 16:
        raise ValueError("atomic_" + op + " does not support " + str(element_ty))
    if ptr.type.is_block():
        if mask is not None:
            mask = self.broadcast_impl_shape(mask, ptr.type.get_block_shapes())
        if val is not None:
            val = self.broadcast_impl_shape(val, ptr.type.get_block_shapes())
    val = self.cast(val, ptr.type.scalar.element_ty)
    if mask is None:
        # Default mask: all lanes enabled, splatted to the pointer shape if needed.
        mask_ir = self.builder.get_int1(True)
        mask_ty = tl.int1
        if ptr.type.is_block():
            mask_ty = ptr.type.with_element_ty(tl.int1)
            mask_ir = self.builder.create_splat(mask_ty.to_ir(self.builder), mask_ir)
        mask = self.tensor(mask_ir, mask_ty)
    return ptr, val, mask
1338
+
1339
def _signbit(self, x: TensorTy) -> TensorTy:
    """Return the most significant bit of ``x`` as an ``int1`` tensor.

    ``x`` is bitcast to an unsigned integer of the same width, logically
    shifted right by ``width - 1`` and cast to ``int1``.
    """
    width = x.dtype.primitive_bitwidth
    as_unsigned = self.bitcast(x, tl.get_int_dtype(bitwidth=width, signed=False))
    top_bit = self.lshr(as_unsigned, width - 1)
    return self.cast(top_bit, tl.int1)
1345
+
1346
def atomic_max(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
    """Emit an atomic max at ``ptr``.

    Integers map directly to signed (MAX) or unsigned (UMAX) max. float32 and
    float64 are emulated on the integer bit patterns. Lanes whose sign bit is
    clear use signed-integer MAX, and lanes whose sign bit is set use
    unsigned-integer UMIN. The two partial results are merged with ``where``
    and bitcast back to the float type.

    Raises:
        TypeError: for floating types other than float32/float64.
    """
    ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'max')
    sem = self._str_to_sem(sem)
    scope = self._str_to_scope(scope)
    sca_ty = val.type.scalar

    if sca_ty.is_int():
        rmw_kind = ir.ATOMIC_OP.MAX if sca_ty.is_int_signed() else ir.ATOMIC_OP.UMAX
        rmw = self.builder.create_atomic_rmw(rmw_kind, ptr.handle, val.handle, mask.handle, sem, scope)
        return self.tensor(rmw, val.type)

    if sca_ty not in {tl.float32, tl.float64}:
        raise TypeError(f"atomic_max not supported for dtype {sca_ty}")

    is_f32 = sca_ty == tl.float32
    i_type = tl.int32 if is_f32 else tl.int64
    ui_type = tl.uint32 if is_f32 else tl.uint64
    i_val = self.bitcast(val, i_type)
    i_ptr = self.bitcast(ptr, tl.pointer_type(i_type, 1))
    ui_val = self.bitcast(val, ui_type)
    ui_ptr = self.bitcast(ptr, tl.pointer_type(ui_type, 1))
    neg = self._signbit(val)
    pos = self.not_(neg)

    # Sign bit clear: signed-integer max on the raw bits.
    pos_mask = self.and_(mask, pos)
    pos_ret = self.tensor(
        self.builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, pos_mask.handle, sem, scope),
        i_val.type)
    # Sign bit set: unsigned-integer min on the raw bits.
    neg_mask = self.and_(mask, neg)
    neg_ret = self.tensor(
        self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ui_ptr.handle, ui_val.handle, neg_mask.handle, sem,
                                       scope), ui_val.type)
    merged = self.where(pos, pos_ret, neg_ret)
    return self.bitcast(merged, sca_ty)
1383
+
1384
def atomic_min(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
    """Emit an atomic min at ``ptr``.

    Integers map directly to signed (MIN) or unsigned (UMIN) min. float32 and
    float64 are emulated on the integer bit patterns. Lanes whose sign bit is
    clear use signed-integer MIN, and lanes whose sign bit is set use
    unsigned-integer UMAX. The two partial results are merged with ``where``
    and bitcast back to the float type.

    Raises:
        TypeError: for floating types other than float32/float64.
    """
    ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'min')
    sem = self._str_to_sem(sem)
    scope = self._str_to_scope(scope)
    sca_ty = val.type.scalar
    # direct call to atomic_min for integers
    if sca_ty.is_int():
        if sca_ty.is_int_signed():
            return self.tensor(
                self.builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope),
                val.type)
        else:
            return self.tensor(
                self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope),
                val.type)
    # for float
    # return atomic_smin(i_ptr, i_val) if sign bit is clear
    # return atomic_umax(i_ptr, i_val) if sign bit is set
    if sca_ty not in {tl.float32, tl.float64}:
        raise TypeError(f"atomic_min not supported for dtype {sca_ty}")

    i_type = tl.int32 if sca_ty == tl.float32 else tl.int64
    i_val = self.bitcast(val, i_type)
    i_ptr = self.bitcast(ptr, tl.pointer_type(i_type, 1))
    ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64
    ui_val = self.bitcast(val, ui_type)
    ui_ptr = self.bitcast(ptr, tl.pointer_type(ui_type, 1))
    neg = self._signbit(val)
    pos = self.not_(neg)
    pos_ret = self.tensor(
        self.builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle,
                                       self.and_(mask, pos).handle, sem, scope), i_val.type)
    # The RMW returns the old *value*, so type it with the value type (as
    # atomic_max does), not the pointer type it was previously tagged with.
    neg_ret = self.tensor(
        self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle,
                                       self.and_(mask, neg).handle, sem, scope), ui_val.type)
    ret = self.where(pos, pos_ret, neg_ret)
    return self.bitcast(ret, sca_ty)
1421
+
1422
def atomic_add(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
    """Emit an atomic add at ``ptr`` (FADD for floating types, ADD otherwise)."""
    ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'add')
    sem = self._str_to_sem(sem)
    scope = self._str_to_scope(scope)
    if val.type.scalar.is_floating():
        rmw_kind = ir.ATOMIC_OP.FADD
    else:
        rmw_kind = ir.ATOMIC_OP.ADD
    rmw = self.builder.create_atomic_rmw(rmw_kind, ptr.handle, val.handle, mask.handle, sem, scope)
    return self.tensor(rmw, val.type)
1430
+
1431
def atomic_and(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
    """Emit an atomic bitwise AND at ``ptr``; the result is typed like ``val``."""
    ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'and')
    sem_attr = self._str_to_sem(sem)
    scope_attr = self._str_to_scope(scope)
    rmw = self.builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem_attr,
                                         scope_attr)
    return self.tensor(rmw, val.type)
1437
+
1438
def atomic_or(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
    """Emit an atomic bitwise OR at ``ptr``; the result is typed like ``val``."""
    ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'or')
    sem_attr = self._str_to_sem(sem)
    scope_attr = self._str_to_scope(scope)
    rmw = self.builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem_attr,
                                         scope_attr)
    return self.tensor(rmw, val.type)
1444
+
1445
def atomic_xor(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
    """Emit an atomic bitwise XOR at ``ptr``; the result is typed like ``val``."""
    ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'xor')
    sem_attr = self._str_to_sem(sem)
    scope_attr = self._str_to_scope(scope)
    rmw = self.builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem_attr,
                                         scope_attr)
    return self.tensor(rmw, val.type)
1451
+
1452
def atomic_xchg(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
    """Emit an atomic exchange at ``ptr``; the result is typed like ``val``."""
    ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'xchg')
    sem_attr = self._str_to_sem(sem)
    scope_attr = self._str_to_scope(scope)
    rmw = self.builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem_attr,
                                         scope_attr)
    return self.tensor(rmw, val.type)
1459
+
1460
+ # ===----------------------------------------------------------------------===//
1461
+ # Linear Algebra
1462
+ # ===----------------------------------------------------------------------===//
1463
+
1464
def _str_to_dot_input_precision(self, input_precision):
    """Map a dot input-precision string to its ``ir.INPUT_PRECISION`` member.

    The lowercase form must appear in the backend's
    ``allowed_dot_input_precisions``. The member name is the upper-cased
    string, except that "TF32X3" is spelled ``TF32x3``.
    """
    allowed = self.builder.options.allowed_dot_input_precisions
    assert input_precision.lower() in allowed, \
        f"input_precision must be one of {self.builder.options.allowed_dot_input_precisions}. Got {input_precision}"
    member_name = input_precision.upper()
    if member_name == "TF32X3":
        member_name = "TF32x3"  # the enum member keeps a lowercase 'x'
    return getattr(ir.INPUT_PRECISION, member_name)
1471
+
1472
def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str],
        max_num_imprecise_acc: int, out_dtype: tl.dtype) -> TensorTy:
    """Lower ``tl.dot`` for 2D or batched 3D block operands.

    Steps:
      * validate operand dtypes (any fp8 x fp8 pair, otherwise one shared
        dtype from a fixed list);
      * upcast fp8e4b15 operands to fp16 (warning if the backend lists the
        type as deprecated), and fp8e4b8/fp8e5b16 operands when the backend
        lists them as deprecated;
      * resolve ``input_precision`` (default from builder options) and check
        ranks, the K dimension and the target's minimum dot sizes;
      * choose the result element type, create a zero accumulator when
        ``acc`` is None, and default/check ``max_num_imprecise_acc``.

    Raises:
        ValueError: for ``out_dtype=bfloat16`` or an fp8 dot whose
            ``max_num_imprecise_acc`` exceeds K.
        AssertionError: for unsupported dtypes or incompatible shapes.
    """
    assert lhs.type.is_block() and rhs.type.is_block()

    if not (lhs.dtype.is_fp8() and rhs.dtype.is_fp8()):
        # fp8 x fp8 pairs may mix formats; everything else must share one supported dtype.
        supported = (tl.int8, tl.uint8, tl.float16, tl.bfloat16, tl.float32, tl.float64)
        assert lhs.dtype in supported, f"Unsupported lhs dtype {lhs.dtype}"
        assert rhs.dtype in supported, f"Unsupported rhs dtype {rhs.dtype}"
        assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}"

    options = self.builder.options

    if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():
        if "fp8e4b15" in options.deprecated_fp8_dot_operand_dtypes:
            warnings.warn(
                "the use of fp8e4b15 is deprecated on Hopper and later architectures and can cause significant slow down. It will be removed in a future triton release"
            )
        # There is no fp8e4b15 type in MLIR, so these operands are always upcast.
        lhs = self.cast(lhs, tl.float16)
        rhs = self.cast(rhs, tl.float16)

    uses_fp8e4b8 = lhs.dtype.is_fp8e4b8() or rhs.dtype.is_fp8e4b8()
    uses_fp8e5b16 = lhs.dtype.is_fp8e5b16() or rhs.dtype.is_fp8e5b16()
    if uses_fp8e4b8 or uses_fp8e5b16:
        type_name = "fp8e4b8" if uses_fp8e4b8 else "fp8e5b16"
        if type_name in options.deprecated_fp8_dot_operand_dtypes:
            arch = options.arch
            warnings.warn(
                f"{type_name} is AMD gfx942 specific and not supported on {arch} so it's upcasted to fp16 and can cause significant slow down. "
                f"Please use OCP fp8 variants on {arch} for performance")
            lhs = self.cast(lhs, tl.float16)
            rhs = self.cast(rhs, tl.float16)

    if input_precision is None:
        input_precision = options.default_dot_input_precision
    input_precision = self._str_to_dot_input_precision(input_precision)

    lhs_rank = len(lhs.shape)
    rhs_rank = len(rhs.shape)
    assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
    assert lhs.shape[-1].value == rhs.shape[
        -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})"
    min_dot_size_fn = self.builder.codegen_fns.get("min_dot_size")
    assert min_dot_size_fn is not None, "target doesn't provide lower shape bounds for dot."
    min_dot_size = min_dot_size_fn(lhs.type, rhs.type)
    assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \
        and rhs.shape[-1].value >= min_dot_size[1], \
        f"Input shapes should have M >= {min_dot_size[0]}, N >= {min_dot_size[1]} and K >= {min_dot_size[2]}"

    # Pick the result element type and the matching zero constant.
    lhs_elem = lhs.type.scalar
    if lhs_elem.is_int():
        assert lhs_elem == tl.int8, "only int8 supported!"
        zero = self.builder.get_int32(0)
        ret_scalar_ty = tl.int32
    elif out_dtype.is_bf16():
        raise ValueError(
            "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`"
        )
    elif lhs_elem.is_fp32() or lhs_elem.is_bf16():
        zero = self.builder.get_fp32(0)
        ret_scalar_ty = tl.float32
    elif lhs_elem.is_fp64():
        zero = self.builder.get_fp64(0)
        ret_scalar_ty = tl.float64
    else:
        zero = self.builder.get_fp16(0) if out_dtype.is_fp16() else self.builder.get_fp32(0)
        ret_scalar_ty = out_dtype

    M = lhs.type.shape[-2]
    N = rhs.type.shape[-1]
    K = lhs.type.shape[-1]
    B = lhs.type.shape[0] if lhs_rank == 3 else None
    ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N])
    if acc is None:
        acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), zero)
    else:
        acc_handle = acc.handle
        assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype

    # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
    fp8_operands = lhs.dtype.is_fp8() and rhs.dtype.is_fp8()
    if max_num_imprecise_acc is None:
        max_num_imprecise_acc = options.max_num_imprecise_acc_default if fp8_operands else 0
    elif fp8_operands and max_num_imprecise_acc > K:
        raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})")

    dot_op = self.builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc)
    return self.tensor(dot_op, ret_ty)
1564
+
1565
def _str_to_fp_type(self, float_format: str):
    """Return the ``ir.ScaleDotElemTypeTY`` member named by upper-cased ``float_format``.

    Raises:
        ValueError: if there is no such member.
    """
    member = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None)
    if member is not None:
        return member
    raise ValueError(f"Invalid float format: {float_format}.")
1570
+
1571
def _bitcast_to_fp_type(self, val: TensorTy, float_format: str):
    """
    If float_format is subbyte ("e2m1"), check that it is packed as uint8 and
    return it unchanged. Otherwise return ``val`` as the named float type,
    bitcasting from the same-width unsigned integer type when necessary.
    """
    native_types = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16": tl.float16}
    storage_types = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16, "fp16": tl.uint16}

    if float_format not in native_types:
        assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}"
        assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}"
        return val

    target = native_types[float_format]
    if val.dtype == target:
        return val
    assert val.dtype == storage_types[float_format], f"Unexpected dtype for {float_format}. Got {val.dtype}"
    return self.bitcast(val, target)
1588
+
1589
def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: TensorTy,
               rhs_scale: Optional[TensorTy], rhs_format: str, acc: TensorTy | None, fast_math: bool,
               lhs_k_pack: bool, rhs_k_pack: bool, out_dtype: tl.dtype) -> TensorTy:
    """Lower ``tl.dot_scaled`` for 2D or batched 3D block operands.

    ``lhs_format``/``rhs_format`` are wrapped values whose ``.value`` is one
    of "e2m1", "e4m3", "e5m2", "bf16", "fp16". Operands are reinterpreted
    via ``_bitcast_to_fp_type``. "e2m1" counts two logical elements per
    stored element, and only "e2m1" operands may be packed along a dimension
    other than K. A scale given as None (or a constexpr holding None) is
    passed to the builder as None. A zero fp32 accumulator is created when
    ``acc`` is None.
    """
    assert lhs.type.is_block() and rhs.type.is_block()
    # TODO: validate types.
    lhs_rank = len(lhs.shape)
    rhs_rank = len(rhs.shape)
    assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"

    # Unwrap the format names and resolve them to IR enum values.
    lhs_format = lhs_format.value
    rhs_format = rhs_format.value
    lhs_format_enum = self._str_to_fp_type(lhs_format)
    rhs_format_enum = self._str_to_fp_type(rhs_format)
    allowed_formats = {"e2m1", "e4m3", "e5m2", "bf16", "fp16"}
    assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}"
    assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}"

    def _scale_absent(scale):
        # A scale may be omitted as plain None or as a constexpr wrapping None.
        return scale is None or (isinstance(scale, tl.constexpr) and scale.value is None)

    rhs_scale_is_none = _scale_absent(rhs_scale)
    lhs_scale_is_none = _scale_absent(lhs_scale)
    lhs = self._bitcast_to_fp_type(lhs, lhs_format)
    rhs = self._bitcast_to_fp_type(rhs, rhs_format)

    assert lhs_k_pack or lhs_format == "e2m1", "only mxfp4 inputs can be packed along a dimension different than K"
    assert rhs_k_pack or rhs_format == "e2m1", "only mxfp4 inputs can be packed along a dimension different than K"
    M, K_LHS = lhs.type.shape[-2:]
    K_RHS, N = rhs.type.shape[-2:]
    lhs_pack = 2 if lhs_format == "e2m1" else 1
    rhs_pack = 2 if rhs_format == "e2m1" else 1
    lhs_k = lhs_pack * K_LHS if lhs_k_pack else K_LHS
    rhs_k = rhs_pack * K_RHS if rhs_k_pack else K_RHS
    assert rhs_k == lhs_k, f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})"

    B = lhs.type.shape[0] if lhs_rank == 3 else None
    # Packing along M/N (instead of K) widens that dimension of the result.
    if not lhs_k_pack:
        M = M * lhs_pack
    if not rhs_k_pack:
        N = N * rhs_pack
    ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N])
    zero = self.builder.get_fp32(0)
    if acc is None:
        acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), zero)
    else:
        acc_handle = acc.handle
        assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype
    rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle
    lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle
    scaled_op = self.builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle,
                                               rhs_scale_handle, rhs_format_enum, fast_math, lhs_k_pack,
                                               rhs_k_pack, acc_handle)
    return self.tensor(scaled_op, ret_ty)
1636
+
1637
+ # ===----------------------------------------------------------------------===//
1638
+ # Indexing
1639
+ # ===----------------------------------------------------------------------===//
1640
+
1641
def where(self, condition: TensorTy, x: TensorTy, y: TensorTy) -> TensorTy:
    """Lower ``tl.where``: elementwise select between ``x`` and ``y``.

    A non-``int1`` condition triggers a deprecation warning and is cast to
    ``int1``. ``x`` and ``y`` are type-checked and promoted together, then
    ``condition``, ``x`` and ``y`` are broadcast to a common shape. The result
    takes the type of the (promoted, broadcast) ``x``.
    """
    if condition.dtype != tl.int1:
        warnings.warn(
            f"tl.where with a non-boolean condition is deprecated and will error out in a future triton release. Got {condition.dtype}"
        )
        condition = self.cast(condition, tl.int1)
    x, y = self.binary_op_type_checking_impl(x, y, True, True)
    if not condition.type.is_block():
        condition, _ = self.broadcast_impl_value(condition, x)
    else:
        condition, x = self.broadcast_impl_value(condition, x)
        x, y = self.broadcast_impl_value(x, y)
    select_op = self.builder.create_select(condition.handle, x.handle, y.handle)
    return self.tensor(select_op, x.type)
1656
+
1657
+ # ===----------------------------------------------------------------------===//
1658
+ # Reduction
1659
+ # ===----------------------------------------------------------------------===
1660
+
1661
def wrap_tensor(self, x, scalar_ty, ret_shape):
    """Wrap IR value ``x`` as a tensor of ``scalar_ty`` shaped ``ret_shape``.

    A falsy ``ret_shape`` (e.g. empty, for a 0-d result) yields a scalar-typed
    tensor instead of a block.
    """
    res_ty = tl.block_type(scalar_ty, ret_shape) if ret_shape else scalar_ty
    return self.tensor(x, res_ty)
1668
+
1669
def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]:
    """Build a reduce op over ``axis`` for one or more same-shaped inputs.

    ``axis=None`` flattens every input (``can_reorder=True``) and reduces
    axis 0. A negative ``axis`` counts from the end, consistent with
    ``associative_scan`` and ``gather``. ``region_builder_fn`` receives the
    new op and is expected to populate its combine region.

    Returns:
        One tensor per input with ``axis`` removed from the shape (a scalar
        when the result is 0-d).

    Raises:
        AssertionError: for an out-of-range axis, mismatched input shapes, or
            an op that fails verification.
    """
    if axis is None:
        inputs = tuple(self.reshape(t, [t.numel.value], can_reorder=True) for t in inputs)
        axis = 0
    # get result shape
    shape = inputs[0].type.shape
    rank = len(shape)
    assert -rank <= axis < rank, f"reduction axis {axis} must be < inputs rank ({rank})"
    # Normalize negative axes; otherwise no dim would be dropped from
    # ret_shape and a negative axis would reach create_reduce.
    if axis < 0:
        axis += rank
    ret_shape = [s for i, s in enumerate(shape) if i != axis]
    assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape"

    reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis)
    region_builder_fn(reduce_op)
    assert reduce_op.verify()

    return tuple(
        self.wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs)))
1686
+
1687
+ # ===----------------------------------------------------------------------===
1688
+ # Associative Scan
1689
+ # ===----------------------------------------------------------------------===
1690
+
1691
def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn,
                     reverse: bool) -> Tuple[TensorTy, ...]:
    """Build a scan op along ``axis`` for one or more same-shaped inputs.

    Negative ``axis`` counts from the end. ``region_builder_fn`` receives the
    new op. Each result keeps its input's element type and the common shape.
    """
    shape = inputs[0].type.shape
    rank = len(shape)

    assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})"
    axis = axis + rank if axis < 0 else axis

    assert all(t.type.shape == shape for t in inputs), "all scan inputs must have the same shape"

    scan_op = self.builder.create_scan([t.handle for t in inputs], axis, reverse)
    region_builder_fn(scan_op)
    assert scan_op.verify()

    return tuple(self.wrap_tensor(scan_op.get_result(idx), t.type.scalar, shape) for idx, t in enumerate(inputs))
1709
+
1710
+ # ===----------------------------------------------------------------------===
1711
+ # Gather
1712
+ # ===----------------------------------------------------------------------===
1713
+
1714
def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy:
    """Gather elements of ``src`` along ``axis`` using the integer ``index`` tensor.

    ``index`` must have the same rank as ``src`` and match it in every
    dimension except ``axis``. Negative ``axis`` counts from the end. The
    result has ``src``'s element type and ``index``'s shape.

    Raises:
        AssertionError: on a non-integer index, rank mismatch, out-of-range
            axis, or a mismatching non-gather dimension.
    """
    assert index.dtype.is_int(), "index must be an integer tensor"

    rank = len(src.type.shape)
    assert len(index.type.shape) == rank, "source and index tensors must have the same rank"

    assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})"
    if axis < 0:
        axis += rank

    for d in range(rank):
        if d == axis:
            continue
        # Report the mismatching dimension `d`, not the gather axis.
        assert index.type.shape[d] == src.type.shape[d], f"index dim {d} must match the corresponding source dim"

    gather_op = self.builder.create_gather(src.handle, index.handle, axis)
    return self.wrap_tensor(gather_op, src.type.scalar, index.type.shape)
1731
+
1732
+ # ===----------------------------------------------------------------------===
1733
+ # Map Elementwise
1734
+ # ===----------------------------------------------------------------------===
1735
+
1736
def broadcast_tensors(self, *inputs):
    """Broadcast every input against the others with ``broadcast_impl_value``.

    Returns a tuple of the same length as ``inputs``, or ``()`` when no
    inputs are given.
    """
    if len(inputs) == 0:
        return ()
    head = inputs[0]
    rest = list(inputs[1:])
    # Two identical passes. The first folds every element into `head`. The
    # second re-broadcasts each element against the final `head`, so earlier
    # elements also see dimensions contributed by later ones (assuming
    # broadcast_impl_value broadcasts both operands to a common shape).
    for _ in range(2):
        for pos, other in enumerate(rest):
            head, rest[pos] = self.broadcast_impl_value(head, other)
    return (head, *rest)
1745
+
1746
def map_elementwise(self, inputs: Sequence[tl.tensor], result_types: Sequence[tl.dtype], pack: int,
                    region_builder_fn) -> Tuple[tl.tensor, ...]:
    """Emit a map-elementwise op over the broadcast ``inputs``.

    The inputs are first broadcast together, and at least one is required.
    Each result type is derived from the first broadcast input's type via
    ``with_element_ty``, using the scalar type of the matching entry in
    ``result_types``. ``region_builder_fn`` receives the created op so the
    caller can fill in its body. Returns one tensor per result type.
    """
    broadcasted = self.broadcast_tensors(*inputs)
    assert len(broadcasted) > 0, "map_elementwise must have at least 1 input tensor"

    out_types = [broadcasted[0].type.with_element_ty(rt.scalar) for rt in result_types]
    input_handles = [t.handle for t in broadcasted]
    ir_types = [rt.to_ir(self.builder) for rt in out_types]

    op = self.builder.create_map_elementwise(input_handles, ir_types, pack)
    region_builder_fn(op)
    # NOTE: op verification is currently disabled here
    # (previously: `assert elementwise_op.verify()`).

    return tuple(self.tensor(op.get_result(pos), rt) for pos, rt in enumerate(out_types))
1761
+
1762
+
1763
+ # ===----------------------------------------------------------------------===
1764
+ # Histogram
1765
+ # ===----------------------------------------------------------------------===
1766
+
1767
def histogram(self, input: TensorTy, num_bins: int, mask: Optional[TensorTy]) -> TensorTy:
    """Emit a histogram of a 1-D integer tensor into ``num_bins`` int32 bins.

    ``mask``, if given, is broadcast to ``input.shape`` and must have a
    boolean scalar type; otherwise ``None`` is forwarded to the builder.

    Raises:
        ValueError: if ``mask`` is not boolean.
    """
    assert len(input.shape) == 1, "histogram only supports 1D input"
    assert input.dtype.is_int(), "histogram only supports integer input"

    mask_handle = None
    if mask is not None:
        mask = self.broadcast_impl_shape(mask, input.shape)
        if not mask.type.scalar.is_bool():
            raise ValueError("Mask must have boolean scalar type")
        mask_handle = mask.handle

    hist = self.builder.create_histogram(input.handle, num_bins, mask_handle)
    return self.tensor(hist, tl.block_type(tl.int32, [num_bins]))
1777
+
1778
def multiple_of(self, x: TensorTy, values: List[int]) -> TensorTy:
    """Attach a ``tt.divisibility`` attribute to ``x``'s handle and return ``x``.

    ``values`` must have one entry per dimension of ``x``; a rank-0 input
    takes exactly one entry.

    Raises:
        ValueError: on a length mismatch.
    """
    expected_len = len(x.shape) or 1  # same as max(1, rank)
    if len(values) != expected_len:
        raise ValueError("Shape of input to multiple_of does not match the length of values")
    handle = x.handle
    handle.set_attr("tt.divisibility", ir.make_attr(values, handle.get_context()))
    return x
1783
+
1784
def max_contiguous(self, x: TensorTy, values: List[int]) -> TensorTy:
    """Attach a ``tt.contiguity`` attribute to ``x``'s handle and return ``x``.

    ``values`` must have exactly one entry per dimension of ``x``.

    Raises:
        ValueError: on a length mismatch.
    """
    if len(values) != len(x.shape):
        raise ValueError("Shape of input to max_contiguous does not match the length of values")
    handle = x.handle
    handle.set_attr("tt.contiguity", ir.make_attr(values, handle.get_context()))
    return x
1789
+
1790
def max_constancy(self, x: TensorTy, values: List[int]) -> TensorTy:
    """Attach a ``tt.constancy`` attribute to ``x``'s handle and return ``x``.

    ``values`` must have exactly one entry per dimension of ``x``.

    Raises:
        ValueError: on a length mismatch.
    """
    if len(values) != len(x.shape):
        raise ValueError("Shape of input to max_constancy does not match the length of values")
    handle = x.handle
    handle.set_attr("tt.constancy", ir.make_attr(values, handle.get_context()))
    return x
1795
+
1796
def debug_barrier(self) -> TensorTy:
    """Emit a barrier op and wrap it as a void-typed tensor."""
    barrier = self.builder.create_barrier()
    return self.tensor(barrier, tl.void)
1798
+
1799
def device_print(self, prefix: str, args: List[TensorTy], hex: bool) -> TensorTy:
    """Emit a device-side print of ``args`` preceded by ``prefix``.

    When there are arguments, the prefix is normalized to end in ": ".
    Any prefix longer than two characters that does not already start with
    a space gets one prepended. ``hex`` and each argument's signedness are
    forwarded to the builder. Returns a void-typed tensor.
    """
    if args:
        # First guarantee a trailing space, then turn that last char into ": ".
        if not prefix.endswith(" "):
            prefix += " "
        if not prefix.endswith(": "):
            prefix = prefix[:-1] + ": "
    if len(prefix) > 2 and not prefix.startswith(" "):
        prefix = " " + prefix

    handles = []
    signedness = []
    for arg in args:
        handles.append(arg.handle)
        signedness.append(arg.dtype.is_int_signed())
    print_op = self.builder.create_print(prefix, hex, handles, signedness)
    return self.tensor(print_op, tl.void)
1812
+
1813
def device_assert(self, cond: TensorTy, msg: str, mask: Optional[TensorTy]) -> TensorTy:
    """Emit a device-side assertion of ``cond`` with message ``msg``.

    Nothing is emitted, and ``None`` is returned, unless
    ``self.builder.options.debug`` is truthy. When ``mask`` is given, the
    asserted condition becomes ``cond | not mask``, so masked-off positions
    always pass.
    """
    if self.builder.options.debug:
        if mask is not None:
            cond = self.or_(cond, self.not_(mask))
        return self.tensor(self.builder.create_assert(cond.handle, msg), tl.void)
    return None
1819
+
1820
def assume(self, cond) -> TensorTy:
    """Emit an assume op on ``cond``'s handle and wrap it as a void-typed tensor."""
    assume_op = self.builder.create_assume(cond.handle)
    return self.tensor(assume_op, tl.void)
1822
+
1823
def _convert_elem_to_ir_value(self, elem, require_i64):
    """Lower one shape/stride/offset entry to an IR integer value.

    - Python ints are wrapped in ``tl.constexpr`` first.
    - A constexpr bool becomes an i1 constant.
    - Other constexprs become i64 constants if ``require_i64``, otherwise
      i32 constants. The value must fit the chosen width.
    - A ``tl.tensor`` must be a single-element integer. With
      ``require_i64`` a non-int64 tensor is cast to i64. Without it an
      int64 tensor is rejected. Otherwise its handle is returned unchanged.

    Any other element type fails an assertion.
    """
    if isinstance(elem, int):
        elem = tl.constexpr(elem)

    if isinstance(elem, tl.constexpr):
        value = elem.value
        if isinstance(value, bool):
            return self.builder.get_int1(value)
        if require_i64:
            assert -2**63 <= value < 2**63, (f"Block pointers only support 64 bit `shape/strides`, "
                                             f"got a value {value} which is out of the range")
            return self.builder.get_int64(value)
        assert -2**31 <= value < 2**31, (f"Block pointers only support 32 bit `offsets/block_shape`, "
                                         f"got a value {value} which is out of the range")
        return self.builder.get_int32(value)

    if isinstance(elem, tl.tensor):
        assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets"
        assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets"
        if require_i64:
            if elem.dtype != tl.int64:
                # Widen to i64, preserving the tensor's signedness.
                return self.builder.create_int_cast(elem.handle, self.builder.get_int64_ty(),
                                                    elem.dtype.is_int_signed())
        elif elem.dtype == tl.int64:
            assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \
                          "add a `.to(tl.int32)` or use regular indexing for 64 bit support"
        return elem.handle

    assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}"
1848
+
1849
def _convert_to_ir_values(self, list_like, require_i64=True):
    """Lower a scalar or iterable of shape/stride/offset entries to IR values.

    Always returns a list. An argument without ``__iter__`` is treated as a
    single entry. Each entry goes through ``_convert_elem_to_ir_value``.
    """
    items = list_like if hasattr(list_like, "__iter__") else [list_like]
    return [self._convert_elem_to_ir_value(item, require_i64) for item in items]
1853
+
1854
def make_block_ptr(self, base: TensorTy, shape, strides, offsets, block_shape, order) -> TensorTy:
    """Build a block pointer over the memory ``base`` points to.

    ``shape``/``strides`` are lowered to 64-bit IR values and ``offsets`` to
    32-bit ones. ``block_shape`` and ``order`` must be compile-time integers
    (plain ints or constexprs), and ``order`` must be a permutation of
    ``range(len(order))``. Every argument accepts a lone scalar in place of a
    one-element list, and all of them must have the same length.

    Returns a tensor typed ``pointer_type<block_type<element_ty, block_shape>>``
    (``tt.ptr<tensor<...>>`` in MLIR).

    Raises:
        ValueError: if ``base`` is not a plain (non-block) pointer.
    """
    # Dynamic arguments first. NOTES(Chenggang): `shape/strides` are
    # `int64_t`, while `offsets/block_shape` are `int32_t`.
    shape = self._convert_to_ir_values(shape)
    strides = self._convert_to_ir_values(strides)
    offsets = self._convert_to_ir_values(offsets, require_i64=False)

    base_ty = base.type
    if not base_ty.is_ptr() or base_ty.element_ty.is_block():
        raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)")

    # Bool pointers are addressed as int8 pointers.
    if base_ty.element_ty == tl.int1:
        base = self.cast(base, tl.pointer_type(tl.int8, base_ty.address_space))

    def _as_static_list(value):
        # Wrap a lone value in a list, then unwrap any constexpr entries.
        seq = value if hasattr(value, "__iter__") else [value]
        return [v.value if isinstance(v, tl.constexpr) else v for v in seq]

    block_shape = _as_static_list(block_shape)
    assert all(isinstance(v, int) and -2**31 <= v < 2**31 for v in block_shape), \
        "Expected a list of constant integers (`int32_t` range) in `block_shape`"

    order = _as_static_list(order)
    assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order"

    assert all(len(block_shape) == len(seq) for seq in [shape, strides, offsets, order]), \
        "Expected shape/strides/offsets/block_shape to have the same length"

    # `base` may have been re-cast above, so read its type afresh.
    handle = self.builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order)
    return self.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape)))
1891
+
1892
def advance(self, base: TensorTy, offsets) -> TensorTy:
    """Advance block pointer ``base`` by ``offsets``.

    The offsets are lowered as 32-bit IR values, and the result has the same
    type as ``base``.
    """
    offset_values = self._convert_to_ir_values(offsets, require_i64=False)
    advanced = self.builder.create_advance(base.handle, offset_values)
    return self.tensor(advanced, base.type)
1898
+
1899
def make_tensor_descriptor(self, base: TensorTy, shape: List[TensorTy], strides: List[TensorTy],
                           block_shape: List[tl.constexpr], padding_option: str = "zero") -> tl.tensor_descriptor:
    """Create a tensor descriptor over the memory ``base`` points to.

    Args:
        base: scalar pointer to the start of the tensor.
        shape: per-dimension sizes (1 to 5 dimensions), lowered as int32 scalars.
        strides: per-dimension strides, lowered as int64 scalars. The last
            stride must be 1.
        block_shape: per-dimension block sizes. The last dimension must span
            at least 16 bytes of the element type.
        padding_option: padding mode name, converted by
            ``_str_to_padding_option``.

    Returns:
        A ``tl.tensor_descriptor`` wrapping the created IR handle.

    Raises:
        ValueError: on a dimension-count mismatch, a block whose last
            dimension is under 16 bytes, a non-unit last stride, or NaN
            padding requested for an integer element type.
    """
    ndim = len(shape)
    if not (1 <= ndim <= 5):
        raise ValueError(f"Expected 1 <= ndim <= 5 but got {ndim} dimensions")
    if len(strides) != ndim:
        raise ValueError(f"Expected {ndim} strides but got {len(strides)}")
    if len(block_shape) != ndim:
        # Report block_shape's own length. len(strides) always equals ndim here.
        raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(block_shape)}")
    assert isinstance(base.dtype, tl.pointer_type)
    elem_size = base.dtype.element_ty.primitive_bitwidth // 8
    contig_dim_size = tl._unwrap_if_constexpr(block_shape[-1])
    if contig_dim_size * elem_size < 16:
        raise ValueError(
            f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes"
        )

    last_stride = tl._unwrap_if_constexpr(strides[-1])
    if last_stride != 1:
        raise ValueError(f"Tensor descriptor last dim must be 1 but got {last_stride}")

    shape = [self.make_scalar(x, tl.int32) for x in shape]
    strides = [self.make_scalar(tl._unwrap_if_constexpr(x), tl.int64) for x in strides]

    # Unwrap block_shape into plain values; it is expected to be static
    # (compile-time) from here on.
    block_shape = tl._unwrap_shape(block_shape)

    # Stricter than the dtype check above: the pointer itself must be scalar.
    assert isinstance(base.type, tl.pointer_type)
    desc_type = tl.block_type(base.type.element_ty, block_shape)
    is_signed_int = base.type.element_ty.is_int_signed()

    padding = self._str_to_padding_option(padding_option)
    if base.type.element_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN:
        raise ValueError("Padding option `nan` is not supported for integer blocks")

    handle = self.builder.create_make_tensor_descriptor(base.handle, [s.handle for s in shape],
                                                        [s.handle for s in strides], block_shape, is_signed_int,
                                                        padding)
    return tl.tensor_descriptor(handle, shape, strides, desc_type)