triton-windows 3.1.0.post17__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (248) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +73 -0
  3. triton/backends/__init__.py +50 -0
  4. triton/backends/amd/compiler.py +262 -0
  5. triton/backends/amd/driver.c +211 -0
  6. triton/backends/amd/driver.py +497 -0
  7. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  25. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  26. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  27. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  28. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  31. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  32. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  40. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  41. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  42. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  43. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  44. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  45. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  46. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  48. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  49. triton/backends/amd/include/hip/device_functions.h +38 -0
  50. triton/backends/amd/include/hip/driver_types.h +468 -0
  51. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  52. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  53. triton/backends/amd/include/hip/hip_common.h +100 -0
  54. triton/backends/amd/include/hip/hip_complex.h +38 -0
  55. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  56. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  57. triton/backends/amd/include/hip/hip_ext.h +159 -0
  58. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  59. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  60. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  61. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  62. triton/backends/amd/include/hip/hip_profile.h +27 -0
  63. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  64. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  65. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  66. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  67. triton/backends/amd/include/hip/hip_version.h +17 -0
  68. triton/backends/amd/include/hip/hiprtc.h +421 -0
  69. triton/backends/amd/include/hip/library_types.h +78 -0
  70. triton/backends/amd/include/hip/math_functions.h +42 -0
  71. triton/backends/amd/include/hip/surface_types.h +63 -0
  72. triton/backends/amd/include/hip/texture_types.h +194 -0
  73. triton/backends/amd/include/hsa/Brig.h +1131 -0
  74. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  75. triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
  76. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  77. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  78. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  79. triton/backends/amd/include/hsa/hsa.h +5729 -0
  80. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  81. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  82. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  83. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  84. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  85. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  87. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  88. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  89. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  90. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  91. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  92. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  93. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  94. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  95. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  96. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  97. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  98. triton/backends/amd/include/roctracer/roctx.h +229 -0
  99. triton/backends/amd/lib/ockl.bc +0 -0
  100. triton/backends/amd/lib/ocml.bc +0 -0
  101. triton/backends/compiler.py +76 -0
  102. triton/backends/driver.py +34 -0
  103. triton/backends/nvidia/__init__.py +0 -0
  104. triton/backends/nvidia/bin/ptxas.exe +0 -0
  105. triton/backends/nvidia/compiler.py +347 -0
  106. triton/backends/nvidia/driver.c +451 -0
  107. triton/backends/nvidia/driver.py +430 -0
  108. triton/backends/nvidia/include/cuda.h +24359 -0
  109. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  110. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  111. triton/compiler/__init__.py +4 -0
  112. triton/compiler/code_generator.py +1302 -0
  113. triton/compiler/compiler.py +416 -0
  114. triton/compiler/errors.py +51 -0
  115. triton/compiler/make_launcher.py +0 -0
  116. triton/errors.py +5 -0
  117. triton/language/__init__.py +284 -0
  118. triton/language/core.py +2621 -0
  119. triton/language/extra/__init__.py +4 -0
  120. triton/language/extra/cuda/__init__.py +8 -0
  121. triton/language/extra/cuda/libdevice.py +1629 -0
  122. triton/language/extra/cuda/utils.py +109 -0
  123. triton/language/extra/hip/__init__.py +3 -0
  124. triton/language/extra/hip/libdevice.py +468 -0
  125. triton/language/extra/libdevice.py +1213 -0
  126. triton/language/math.py +250 -0
  127. triton/language/random.py +207 -0
  128. triton/language/semantic.py +1621 -0
  129. triton/language/standard.py +441 -0
  130. triton/ops/__init__.py +7 -0
  131. triton/ops/blocksparse/__init__.py +7 -0
  132. triton/ops/blocksparse/matmul.py +432 -0
  133. triton/ops/blocksparse/softmax.py +228 -0
  134. triton/ops/cross_entropy.py +96 -0
  135. triton/ops/flash_attention.py +466 -0
  136. triton/ops/matmul.py +219 -0
  137. triton/ops/matmul_perf_model.py +171 -0
  138. triton/runtime/__init__.py +23 -0
  139. triton/runtime/autotuner.py +361 -0
  140. triton/runtime/build.py +129 -0
  141. triton/runtime/cache.py +289 -0
  142. triton/runtime/driver.py +60 -0
  143. triton/runtime/errors.py +26 -0
  144. triton/runtime/interpreter.py +1127 -0
  145. triton/runtime/jit.py +956 -0
  146. triton/runtime/tcc/include/_mingw.h +170 -0
  147. triton/runtime/tcc/include/assert.h +57 -0
  148. triton/runtime/tcc/include/conio.h +409 -0
  149. triton/runtime/tcc/include/ctype.h +281 -0
  150. triton/runtime/tcc/include/dir.h +31 -0
  151. triton/runtime/tcc/include/direct.h +68 -0
  152. triton/runtime/tcc/include/dirent.h +135 -0
  153. triton/runtime/tcc/include/dos.h +55 -0
  154. triton/runtime/tcc/include/errno.h +75 -0
  155. triton/runtime/tcc/include/excpt.h +123 -0
  156. triton/runtime/tcc/include/fcntl.h +52 -0
  157. triton/runtime/tcc/include/fenv.h +108 -0
  158. triton/runtime/tcc/include/float.h +57 -0
  159. triton/runtime/tcc/include/inttypes.h +297 -0
  160. triton/runtime/tcc/include/io.h +418 -0
  161. triton/runtime/tcc/include/limits.h +111 -0
  162. triton/runtime/tcc/include/locale.h +91 -0
  163. triton/runtime/tcc/include/malloc.h +181 -0
  164. triton/runtime/tcc/include/math.h +737 -0
  165. triton/runtime/tcc/include/mem.h +13 -0
  166. triton/runtime/tcc/include/memory.h +40 -0
  167. triton/runtime/tcc/include/process.h +176 -0
  168. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  169. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  170. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  171. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  172. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  173. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  174. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  175. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  176. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  177. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  178. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  179. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  180. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  181. triton/runtime/tcc/include/setjmp.h +160 -0
  182. triton/runtime/tcc/include/share.h +28 -0
  183. triton/runtime/tcc/include/signal.h +63 -0
  184. triton/runtime/tcc/include/stdarg.h +79 -0
  185. triton/runtime/tcc/include/stdbool.h +11 -0
  186. triton/runtime/tcc/include/stddef.h +54 -0
  187. triton/runtime/tcc/include/stdint.h +212 -0
  188. triton/runtime/tcc/include/stdio.h +429 -0
  189. triton/runtime/tcc/include/stdlib.h +580 -0
  190. triton/runtime/tcc/include/string.h +164 -0
  191. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  192. triton/runtime/tcc/include/sys/file.h +14 -0
  193. triton/runtime/tcc/include/sys/locking.h +30 -0
  194. triton/runtime/tcc/include/sys/stat.h +290 -0
  195. triton/runtime/tcc/include/sys/time.h +69 -0
  196. triton/runtime/tcc/include/sys/timeb.h +133 -0
  197. triton/runtime/tcc/include/sys/types.h +118 -0
  198. triton/runtime/tcc/include/sys/unistd.h +14 -0
  199. triton/runtime/tcc/include/sys/utime.h +146 -0
  200. triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
  201. triton/runtime/tcc/include/tcclib.h +80 -0
  202. triton/runtime/tcc/include/tchar.h +1102 -0
  203. triton/runtime/tcc/include/time.h +287 -0
  204. triton/runtime/tcc/include/vadefs.h +11 -0
  205. triton/runtime/tcc/include/values.h +4 -0
  206. triton/runtime/tcc/include/varargs.h +12 -0
  207. triton/runtime/tcc/include/wchar.h +873 -0
  208. triton/runtime/tcc/include/wctype.h +172 -0
  209. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  210. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  211. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  212. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  213. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  214. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  215. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  216. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  217. triton/runtime/tcc/include/winapi/winbase.h +2951 -0
  218. triton/runtime/tcc/include/winapi/wincon.h +301 -0
  219. triton/runtime/tcc/include/winapi/windef.h +293 -0
  220. triton/runtime/tcc/include/winapi/windows.h +127 -0
  221. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  222. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  223. triton/runtime/tcc/include/winapi/winnt.h +5835 -0
  224. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  225. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  226. triton/runtime/tcc/include/winapi/winver.h +160 -0
  227. triton/runtime/tcc/lib/cuda.def +697 -0
  228. triton/runtime/tcc/lib/gdi32.def +337 -0
  229. triton/runtime/tcc/lib/kernel32.def +770 -0
  230. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  231. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  232. triton/runtime/tcc/lib/python3.def +810 -0
  233. triton/runtime/tcc/lib/user32.def +658 -0
  234. triton/runtime/tcc/libtcc.dll +0 -0
  235. triton/runtime/tcc/tcc.exe +0 -0
  236. triton/testing.py +496 -0
  237. triton/tools/__init__.py +0 -0
  238. triton/tools/build_extern.py +365 -0
  239. triton/tools/compile.c +67 -0
  240. triton/tools/compile.h +14 -0
  241. triton/tools/compile.py +145 -0
  242. triton/tools/disasm.py +142 -0
  243. triton/tools/link.py +322 -0
  244. triton/windows_utils.py +373 -0
  245. triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
  246. triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
  247. triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
  248. triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
@@ -0,0 +1,1621 @@
1
+ from __future__ import annotations # remove after python 3.11
2
+
3
+ from typing import List, Optional, Sequence, Tuple, TypeVar
4
+
5
+ from .._C.libtriton import ir
6
+ from . import core as tl
7
+ from . import math
8
+
9
+ T = TypeVar('T')
10
+
11
+
12
class IncompatibleTypeErrorImpl(Exception):
    """Error raised when two operand types cannot be combined.

    The constructor stores both operand types unchanged on ``type_a`` and
    ``type_b`` and keeps the rendered text on ``message``; the same text is
    forwarded to ``Exception`` so ``str(err)`` yields it too.
    """

    def __init__(self, type_a, type_b):
        self.type_a = type_a
        self.type_b = type_b
        lhs_text = self.type_a.__repr__()
        rhs_text = self.type_b.__repr__()
        self.message = f"invalid operands of type {lhs_text} and {rhs_text}"
        super().__init__(self.message)
19
+
20
+
21
+ # ===----------------------------------------------------------------------===##
22
+ # Programming Model
23
+ # ===----------------------------------------------------------------------===##
24
+
25
+
26
def program_id(axis: int, builder: ir.builder) -> tl.tensor:
    """Create the op that reads the program id along ``axis``.

    :param axis: axis index; only 0, 1 and 2 are accepted.
    :param builder: IR builder used to create the op.
    :return: the created value wrapped as a ``tl.int32`` tensor.
    :raises ValueError: if ``axis`` is not 0, 1 or 2.
    """
    allowed_axes = (0, 1, 2)
    if axis in allowed_axes:
        handle = builder.create_get_program_id(axis)
        return tl.tensor(handle, tl.int32)
    raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}")
30
+
31
+
32
def num_programs(axis: int, builder: ir.builder) -> tl.tensor:
    """Create the op that reads the number of programs along ``axis``.

    :param axis: axis index; only 0, 1 and 2 are accepted.
    :param builder: IR builder used to create the op.
    :return: the created value wrapped as a ``tl.int32`` tensor.
    :raises ValueError: if ``axis`` is not 0, 1 or 2.
    """
    allowed_axes = (0, 1, 2)
    if axis in allowed_axes:
        handle = builder.create_get_num_programs(axis)
        return tl.tensor(handle, tl.int32)
    raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}")
36
+
37
+
38
+ # ===----------------------------------------------------------------------===//
39
+ # Implicit Casting Utilities
40
+ # ===----------------------------------------------------------------------===//
41
+
42
+
43
def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype:
    """Choose the common integer type for two integer operand types.

    Signedness rules follow the "Usual arithmetic conversions" described on
    https://en.cppreference.com/w/c/language/conversion.

    :param a_ty: first integer type.
    :param b_ty: second integer type.
    :return: ``a_ty`` or ``b_ty``, whichever the rules below select.
    :raises TypeError: if the signedness values differ and neither one is
        ``tl.dtype.SIGNEDNESS.UNSIGNED``.
    """
    a_bits, b_bits = a_ty.int_bitwidth, b_ty.int_bitwidth
    a_sign, b_sign = a_ty.int_signedness, b_ty.int_signedness

    if a_sign == b_sign:
        # Same signedness: the strictly wider type wins, ties go to b_ty.
        return a_ty if a_bits > b_bits else b_ty

    unsigned = tl.dtype.SIGNEDNESS.UNSIGNED
    if a_sign == unsigned:
        # The unsigned operand wins unless the other one is strictly wider.
        return a_ty if a_bits >= b_bits else b_ty
    if b_sign == unsigned:
        return b_ty if b_bits >= a_bits else a_ty
    raise TypeError(f"unexpected signedness {a_sign} and {b_sign}")
57
+
58
+
59
def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype:
    """Return the scalar type in which a binary arithmetic op is computed.

    :param a_ty: scalar type of the first operand.
    :param b_ty: scalar type of the second operand.
    :param div_or_mod: True when the operation is a division or modulo;
        this widens fp16/bf16 to fp32 and forbids mixed-signedness integers.
    :return: the common computation type.
    :raises TypeError: if an operand is neither one of the handled floating
        types nor an integer, or if ``div_or_mod`` is set and the integer
        operands differ in signedness.
    """
    # 1) if one operand is double, the other is implicitly converted to double
    if a_ty.is_fp64() or b_ty.is_fp64():
        return tl.float64
    # 2) if one operand is float, the other is implicitly converted to float
    if a_ty.is_fp32() or b_ty.is_fp32():
        return tl.float32
    # 3) if one operand is half, the other is implicitly converted to half
    #    unless we're doing / or %, which do not exist natively in PTX for fp16.
    #    Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp
    if a_ty.is_fp16() or b_ty.is_fp16():
        return tl.float32 if div_or_mod else tl.float16
    # 4) return bf16 only if both operands are bf16 and this is not / or %;
    #    every other combination involving bf16 is computed in fp32
    if a_ty.is_bf16() or b_ty.is_bf16():
        if div_or_mod:
            return tl.float32
        if a_ty.is_bf16() and b_ty.is_bf16():
            return tl.bfloat16
        return tl.float32
    if not a_ty.is_int() or not b_ty.is_int():
        raise TypeError(f"unexpected type {a_ty} and {b_ty}")
    # 5) both operands are integer and undergo integer promotion
    if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
        # The message names the Python operators /, // and %, and keeps a
        # space after the semicolon so the two sentences do not run together.
        raise TypeError(f"Cannot use /, //, or % with {a_ty!r} and {b_ty!r} "
                        "because they have different signedness; "
                        "this is unlikely to result in a useful answer. Cast them to the same signedness.")
    return integer_promote_impl(a_ty, b_ty)
92
+
93
+
94
+ # ===----------------------------------------------------------------------===//
95
+ # Binary Operators
96
+ # ===----------------------------------------------------------------------===//
97
+
98
+
99
def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None:
    """Validate ``type_a`` against ``type_b`` when ``type_a`` is a pointer.

    Nothing is checked when ``type_a`` is not a pointer.

    :param type_a: scalar type being validated.
    :param type_b: scalar type of the other operand.
    :param allow_ptr_a: whether ``type_a`` may be a pointer at all.
    :raises IncompatibleTypeErrorImpl: if ``type_a`` is a pointer and either
        pointers are not allowed, ``type_b`` is a different pointer type, or
        ``type_b`` is a floating type.
    """
    if not type_a.is_ptr():
        return
    if not allow_ptr_a:
        raise IncompatibleTypeErrorImpl(type_a, type_b)
    # T* + U* with T != U
    mismatched_pointers = type_b.is_ptr() and (type_a != type_b)
    # T* + float
    if mismatched_pointers or type_b.is_floating():
        raise IncompatibleTypeErrorImpl(type_a, type_b)
109
+
110
+
111
def binary_op_type_checking_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, allow_lhs_ptr=False,
                                 allow_rhs_ptr=False, arithmetic_check=True,
                                 div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]:
    """Broadcast, validate and (optionally) type-unify two binary operands.

    :param lhs: left operand.
    :param rhs: right operand.
    :param builder: IR builder forwarded to the broadcast and cast helpers.
    :param allow_lhs_ptr: whether ``lhs`` may have a pointer scalar type.
    :param allow_rhs_ptr: whether ``rhs`` may have a pointer scalar type.
    :param arithmetic_check: when true and neither operand is a pointer, both
        operands are cast to the type chosen by ``computation_type_impl``.
    :param div_or_mod: forwarded to ``computation_type_impl``.
    :return: the ``(lhs, rhs)`` pair after broadcasting and any casting.
    """
    # implicit broadcasting
    lhs, rhs = broadcast_impl_value(lhs, rhs, builder)

    # pointer legality is checked in both directions
    lhs_elem_ty, rhs_elem_ty = lhs.type.scalar, rhs.type.scalar
    check_ptr_type_impl(lhs_elem_ty, rhs_elem_ty, allow_lhs_ptr)
    check_ptr_type_impl(rhs_elem_ty, lhs_elem_ty, allow_rhs_ptr)

    # implicit typecasting is skipped when disabled or when a pointer is involved
    if not arithmetic_check:
        return lhs, rhs
    if lhs_elem_ty.is_ptr() or rhs_elem_ty.is_ptr():
        return lhs, rhs
    common_ty = computation_type_impl(lhs_elem_ty, rhs_elem_ty, div_or_mod)
    return cast(lhs, common_ty, builder), cast(rhs, common_ty, builder)
126
+
127
+
128
def add(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build ``input + other`` for pointer, floating or integer operands.

    :param input: left operand.
    :param other: right operand.
    :param builder: IR builder used to create the op.
    :return: the result tensor, typed like the (possibly swapped) left operand.
    :raises TypeError: if both operands are pointers, or the scalar type is
        neither pointer, floating nor integer.
    """
    input, other = binary_op_type_checking_impl(input, other, builder, True, True)
    lhs_ty, rhs_ty = input.type.scalar, other.type.scalar
    if lhs_ty.is_ptr() and rhs_ty.is_ptr():
        raise TypeError("cannot add pointers together")

    # Rewrite "offset + ptr" as "ptr + offset" so the pointer is always on the left.
    if rhs_ty.is_ptr() and not lhs_ty.is_ptr():
        input, other = other, input
        lhs_ty, rhs_ty = input.type.scalar, other.type.scalar

    if lhs_ty.is_ptr():
        # ptr + offset
        result = builder.create_addptr(input.handle, other.handle)
    elif lhs_ty.is_floating():
        # float + float
        result = builder.create_fadd(input.handle, other.handle)
    elif lhs_ty.is_int():
        # int + int
        result = builder.create_add(input.handle, other.handle)
    else:
        raise TypeError(f"unexpected type {lhs_ty}")
    return tl.tensor(result, input.type)
150
+
151
+
152
def sub(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build ``input - other``.

    Only the left operand may be a pointer; in that case the result is
    produced by adding the negated right operand to the pointer.

    :param input: left operand.
    :param other: right operand.
    :param builder: IR builder used to create the op.
    :return: the result tensor, typed like ``input`` after type checking.
    :raises TypeError: if the scalar type is neither pointer, floating nor integer.
    """
    input, other = binary_op_type_checking_impl(input, other, builder, True, False)
    elem_ty = input.type.scalar
    if elem_ty.is_ptr():
        # ptr - offset  ==  ptr + (-offset)
        result = builder.create_addptr(input.handle, minus(other, builder).handle)
    elif elem_ty.is_floating():
        # float - float
        result = builder.create_fsub(input.handle, other.handle)
    elif elem_ty.is_int():
        # int - int
        result = builder.create_sub(input.handle, other.handle)
    else:
        raise TypeError(f"unexpected type {elem_ty}")
    return tl.tensor(result, input.type)
165
+
166
+
167
def mul(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build ``input * other`` for floating or integer operands.

    :param input: left operand.
    :param other: right operand.
    :param builder: IR builder used to create the op.
    :return: the product tensor, typed like ``input`` after type checking.
    :raises TypeError: if the scalar type is neither floating nor integer.
    """
    input, other = binary_op_type_checking_impl(input, other, builder)
    elem_ty = input.type.scalar
    if elem_ty.is_floating():
        # float * float
        product = builder.create_fmul(input.handle, other.handle)
    elif elem_ty.is_int():
        # int * int
        product = builder.create_mul(input.handle, other.handle)
    else:
        raise TypeError(f"unexpected type {elem_ty}")
    return tl.tensor(product, input.type)
177
+
178
+
179
def truediv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build the true (floating) division ``input / other``.

    Integer operands are converted to a floating type before the floating
    division op is created.

    :param input: dividend.
    :param other: divisor.
    :param builder: IR builder used to create the op.
    :return: the quotient tensor, typed like ``input`` after all casts.
    :raises TypeError: if the operand types match none of the handled
        float/int combinations.
    """
    input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
    lhs_ty = input.type.scalar
    rhs_ty = other.type.scalar
    if lhs_ty.is_floating() and rhs_ty.is_int():
        # float / int: divisor follows the dividend's type
        other = cast(other, lhs_ty, builder)
    elif lhs_ty.is_int() and rhs_ty.is_floating():
        # int / float: dividend follows the divisor's type
        input = cast(input, rhs_ty, builder)
    elif lhs_ty.is_int() and rhs_ty.is_int():
        # int / int: both sides go to tl.float32
        input = cast(input, tl.float32, builder)
        other = cast(other, tl.float32, builder)
    elif lhs_ty.is_floating() and rhs_ty.is_floating():
        # float / float: cast to the type with the wider mantissa
        # (on a tie the divisor's type is used)
        if lhs_ty.fp_mantissa_width > rhs_ty.fp_mantissa_width:
            other = cast(other, lhs_ty, builder)
        else:
            input = cast(input, rhs_ty, builder)
    else:
        # unreachable
        raise TypeError(f"unexpected type {lhs_ty}")
    quotient = builder.create_fdiv(input.handle, other.handle)
    return tl.tensor(quotient, input.type)
203
+
204
+
205
def floordiv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build the integer division ``input // other``.

    :param input: dividend.
    :param other: divisor.
    :param builder: IR builder used to create the op.
    :return: the quotient tensor; a signed or unsigned division op is chosen
        from the promoted integer type.
    :raises TypeError: if either operand is not an integer.
    """
    input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
    lhs_ty = input.type.scalar
    rhs_ty = other.type.scalar
    if not (lhs_ty.is_int() and rhs_ty.is_int()):
        raise TypeError(f"unexpected type {lhs_ty}")

    promoted_ty = integer_promote_impl(lhs_ty, rhs_ty)
    input = cast(input, promoted_ty, builder)
    other = cast(other, promoted_ty, builder)
    if promoted_ty.is_int_signed():
        quotient = builder.create_sdiv(input.handle, other.handle)
    else:
        quotient = builder.create_udiv(input.handle, other.handle)
    return tl.tensor(quotient, input.type)
218
+
219
+
220
def fdiv(input: tl.tensor, other: tl.tensor, ieee_rounding: bool, builder: ir.builder) -> tl.tensor:
    """Build a floating division that requires both operands to be floating.

    :param input: dividend; must have a floating scalar type.
    :param other: divisor; must have a floating scalar type.
    :param ieee_rounding: accepted for interface compatibility; this
        function body does not read it.
    :param builder: IR builder used to create the op.
    :return: the quotient tensor, typed like ``input`` after broadcasting.
    :raises TypeError: if either operand is not floating.
    """
    lhs_ty = input.type.scalar
    rhs_ty = other.type.scalar
    if not (lhs_ty.is_floating() and rhs_ty.is_floating()):
        raise TypeError("both operands of fdiv must have floating scalar type")
    # arithmetic_check is off: operands are broadcast but not cast to a common type
    input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True)
    quotient = builder.create_fdiv(input.handle, other.handle)
    return tl.tensor(quotient, input.type)
228
+
229
+
230
def mod(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build ``input % other``.

    Floating operands are computed as ``input - floor(input / other) * other``;
    integer operands use a signed or unsigned remainder op.

    :param input: dividend.
    :param other: divisor.
    :param builder: IR builder used to create the ops.
    :return: the remainder tensor.
    :raises TypeError: if the integer operands differ in signedness, or the
        scalar type is neither floating nor integer.
    """
    input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
    lhs_ty = input.type.scalar
    rhs_ty = other.type.scalar

    if lhs_ty.is_floating():
        # float % float: input - input.div(other, rounding_mode="floor") * other
        # (``math`` here is the sibling module imported by this file)
        quotient = fdiv(input, other, False, builder)
        floored = math.floor(quotient, _builder=builder)
        return sub(input, mul(floored, other, builder), builder)

    if not lhs_ty.is_int():
        raise TypeError(f"unexpected type {lhs_ty}")

    # int % int
    if lhs_ty.int_signedness != rhs_ty.int_signedness:
        raise TypeError("Cannot mod " + lhs_ty.__repr__() + " by " + rhs_ty.__repr__() + " "
                        "because they have different signedness;"
                        "this is unlikely to result in a useful answer. Cast them to the same signedness.")
    if lhs_ty.is_int_signed():
        remainder = builder.create_srem(input.handle, other.handle)
    else:
        remainder = builder.create_urem(input.handle, other.handle)
    return tl.tensor(remainder, input.type)
250
+
251
+
252
+ ##############
253
+ # other arithmetic ops
254
+ ##############
255
+
256
+
257
def minimum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder):
    """Build the element-wise minimum of ``x`` and ``y``.

    :param x: first operand.
    :param y: second operand.
    :param propagate_nan: for floating operands, selects between the
        ``minimumf`` op (``ALL``) and the ``minnumf`` op (``NONE``).
    :param builder: IR builder used to create the op.
    :return: the result tensor, typed like ``x`` after type checking.
    :raises ValueError: if ``propagate_nan`` is neither ``ALL`` nor ``NONE``
        for floating operands.
    :raises TypeError: if the dtype is neither floating nor integer.
    """
    x, y = binary_op_type_checking_impl(x, y, builder)
    elem_ty = x.dtype
    if elem_ty.is_floating():
        if propagate_nan == tl.PropagateNan.ALL:
            result = builder.create_minimumf(x.handle, y.handle)
        elif propagate_nan == tl.PropagateNan.NONE:
            result = builder.create_minnumf(x.handle, y.handle)
        else:
            raise ValueError(f"Unexpected propagate_nan {propagate_nan}")
    elif elem_ty.is_int_signed():
        result = builder.create_minsi(x.handle, y.handle)
    elif elem_ty.is_int_unsigned():
        result = builder.create_minui(x.handle, y.handle)
    else:
        raise TypeError(f"Unexpected dtype {elem_ty}")
    return tl.tensor(result, x.type)
273
+
274
+
275
def maximum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder):
    """Build the element-wise maximum of ``x`` and ``y``.

    :param x: first operand.
    :param y: second operand.
    :param propagate_nan: for floating operands, selects between the
        ``maximumf`` op (``ALL``) and the ``maxnumf`` op (``NONE``).
    :param builder: IR builder used to create the op.
    :return: the result tensor, typed like ``x`` after type checking.
    :raises ValueError: if ``propagate_nan`` is neither ``ALL`` nor ``NONE``
        for floating operands.
    :raises TypeError: if the dtype is neither floating nor integer.
    """
    x, y = binary_op_type_checking_impl(x, y, builder)
    elem_ty = x.dtype
    if elem_ty.is_floating():
        if propagate_nan == tl.PropagateNan.ALL:
            result = builder.create_maximumf(x.handle, y.handle)
        elif propagate_nan == tl.PropagateNan.NONE:
            result = builder.create_maxnumf(x.handle, y.handle)
        else:
            raise ValueError(f"Unexpected propagate_nan {propagate_nan}")
    elif elem_ty.is_int_signed():
        result = builder.create_maxsi(x.handle, y.handle)
    elif elem_ty.is_int_unsigned():
        result = builder.create_maxui(x.handle, y.handle)
    else:
        raise TypeError(f"Unexpected dtype {elem_ty}")
    return tl.tensor(result, x.type)
291
+
292
+
293
def clamp(x: tl.tensor, min: tl.tensor, max: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder):
    """Build a floating-point clamp of ``x`` to the range ``[min, max]``.

    :param x: value to clamp.
    :param min: lower bound.
    :param max: upper bound.
    :param propagate_nan: forwarded unchanged to the clamp op.
    :param builder: IR builder used to create the op.
    :return: the clamped tensor, typed like ``x`` after type checking.
    :raises TypeError: if the dtype after type checking is not floating.
    """
    # Type-check the operands pairwise: the bounds first, then x against each bound.
    min, max = binary_op_type_checking_impl(min, max, builder)
    x, min = binary_op_type_checking_impl(x, min, builder)
    x, max = binary_op_type_checking_impl(x, max, builder)

    elem_ty = x.dtype
    if not elem_ty.is_floating():
        raise TypeError(f"Unexpected dtype {elem_ty}. Only floating point clamp is supported")
    clamped = builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan)
    return tl.tensor(clamped, x.type)
303
+
304
+
305
+ ##############
306
+ # bitwise ops
307
+ ##############
308
+
309
+
310
def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor,
                                  builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
    """Prepare two operands for a bitwise op.

    Operands are broadcast (pointers disallowed, no arithmetic casting),
    required to be integers, and then cast to their promoted integer type
    where they differ from it.

    :param input: left operand.
    :param other: right operand.
    :param builder: IR builder forwarded to the helpers.
    :return: the ``(input, other)`` pair sharing one integer scalar type.
    :raises IncompatibleTypeErrorImpl: if either operand is not an integer.
    """
    input, other = binary_op_type_checking_impl(input, other, builder, False, False, False)
    lhs_ty = input.type.scalar
    rhs_ty = other.type.scalar
    if not (lhs_ty.is_int() and rhs_ty.is_int()):
        raise IncompatibleTypeErrorImpl(lhs_ty, rhs_ty)

    promoted_ty = integer_promote_impl(lhs_ty, rhs_ty)
    # Only cast the side(s) whose type differs from the promoted one.
    if promoted_ty != lhs_ty:
        input = cast(input, promoted_ty, builder)
    if promoted_ty != rhs_ty:
        other = cast(other, promoted_ty, builder)
    return input, other
323
+
324
+
325
def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build the bitwise AND of two integer operands.

    :return: the result tensor, typed like ``input`` after integer promotion.
    """
    lhs, rhs = bitwise_op_type_checking_impl(input, other, builder)
    handle = builder.create_and(lhs.handle, rhs.handle)
    return tl.tensor(handle, lhs.type)
328
+
329
+
330
def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build the bitwise OR of two integer operands.

    :return: the result tensor, typed like ``input`` after integer promotion.
    """
    lhs, rhs = bitwise_op_type_checking_impl(input, other, builder)
    handle = builder.create_or(lhs.handle, rhs.handle)
    return tl.tensor(handle, lhs.type)
333
+
334
+
335
def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build the bitwise XOR of two integer operands.

    :return: the result tensor, typed like ``input`` after integer promotion.
    """
    lhs, rhs = bitwise_op_type_checking_impl(input, other, builder)
    handle = builder.create_xor(lhs.handle, rhs.handle)
    return tl.tensor(handle, lhs.type)
338
+
339
+
340
def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build a logical AND: operands not already ``int1`` are bitcast to
    ``int1`` and then combined with ``and_``.
    """
    lhs = input if input.type.is_int1() else bitcast(input, tl.dtype("int1"), builder)
    rhs = other if other.type.is_int1() else bitcast(other, tl.dtype("int1"), builder)
    return and_(lhs, rhs, builder)
346
+
347
+
348
def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build a logical OR: operands not already ``int1`` are bitcast to
    ``int1`` and then combined with ``or_``.
    """
    lhs = input if input.type.is_int1() else bitcast(input, tl.dtype("int1"), builder)
    rhs = other if other.type.is_int1() else bitcast(other, tl.dtype("int1"), builder)
    return or_(lhs, rhs, builder)
354
+
355
+
356
def not_(input: tl.tensor, builder: ir.builder):
    """Build a logical NOT: the operand is bitcast to ``int1`` unless it
    already is one, then passed to ``invert``.
    """
    as_bool = input if input.type.is_int1() else bitcast(input, tl.dtype("int1"), builder)
    return invert(as_bool, builder)
360
+
361
+
362
def lshr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build a right shift of ``input`` by ``other`` using the builder's ``lshr`` op.

    :return: the result tensor, typed like ``input`` after integer promotion.
    """
    value, amount = bitwise_op_type_checking_impl(input, other, builder)
    handle = builder.create_lshr(value.handle, amount.handle)
    return tl.tensor(handle, value.type)
365
+
366
+
367
def ashr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build a right shift of ``input`` by ``other`` using the builder's ``ashr`` op.

    :return: the result tensor, typed like ``input`` after integer promotion.
    """
    value, amount = bitwise_op_type_checking_impl(input, other, builder)
    handle = builder.create_ashr(value.handle, amount.handle)
    return tl.tensor(handle, value.type)
370
+
371
+
372
def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build a left shift of ``input`` by ``other`` using the builder's ``shl`` op.

    :return: the result tensor, typed like ``input`` after integer promotion.
    """
    value, amount = bitwise_op_type_checking_impl(input, other, builder)
    handle = builder.create_shl(value.handle, amount.handle)
    return tl.tensor(handle, value.type)
375
+
376
+
377
+ # ===----------------------------------------------------------------------===//
378
+ # Unary Operators
379
+ # ===----------------------------------------------------------------------===//
380
+
381
+
382
def plus(input: tl.tensor) -> tl.tensor:
    """Unary plus: hand the operand back unchanged (same object)."""
    result = input
    return result
384
+
385
+
386
def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Build the unary negation of ``input`` as ``0 - input``.

    :param input: operand; must not have a pointer scalar type.
    :param builder: IR builder used to create the zero value and the op.
    :return: the result of ``sub(zero, input, builder)``.
    :raises ValueError: if the scalar type is a pointer.
    """
    elem_ty = input.type.scalar
    if elem_ty.is_ptr():
        raise ValueError("wrong type argument to unary minus (" + elem_ty.__repr__() + ")")
    # Null value of the operand's scalar type, used as the left-hand side.
    zero_handle = builder.get_null_value(elem_ty.to_ir(builder))
    zero = tl.tensor(zero_handle, elem_ty)
    return sub(zero, input, builder)
392
+
393
+
394
def invert(input: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Unary invert, implemented as `xor_(input, all_ones)`.

    Args:
        input: operand to invert.
        builder: IR builder used to create the all-ones constant and the xor.
            (Previously annotated as `tl.tensor`; it is used as a builder,
            like in every sibling function.)

    Raises:
        ValueError: if the scalar type of `input` is a pointer or floating type.
    """
    input_sca_ty = input.type.scalar
    if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
        raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
    # All-ones constant of the operand's scalar type.
    _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
    return xor_(input, _1, builder)
400
+
401
+
402
+ # ===----------------------------------------------------------------------===//
403
+ # Comparison Operators
404
+ # ===----------------------------------------------------------------------===//
405
def _bool_like(v: tl.tensor) -> tl.block_type:
    """Return the int1 type matching `v`: a block of `v`'s shape, or plain `tl.int1` for non-blocks."""
    if v.type.is_block():
        return tl.block_type(tl.int1, v.type.shape)
    return tl.int1
410
+
411
+
412
def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Emit the comparison `input > other`.

    The result type is `_bool_like` of the type-checked left operand.

    Raises:
        TypeError: if the operands' scalar type is neither floating nor integer.
    """
    lhs, rhs = binary_op_type_checking_impl(input, other, builder)
    elem_ty = lhs.type.scalar
    if elem_ty.is_floating():
        # float > float
        emit = builder.create_fcmpOGT
    elif elem_ty.is_int():
        # int > int: signed or unsigned predicate depending on the operand type
        emit = builder.create_icmpSGT if elem_ty.is_int_signed() else builder.create_icmpUGT
    else:
        raise TypeError(f"unexpected type {elem_ty}")
    return tl.tensor(emit(lhs.handle, rhs.handle), _bool_like(lhs))
425
+
426
+
427
def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Emit the comparison `input >= other`.

    The result type is `_bool_like` of the type-checked left operand.

    Raises:
        TypeError: if the operands' scalar type is neither floating nor integer.
    """
    lhs, rhs = binary_op_type_checking_impl(input, other, builder)
    elem_ty = lhs.type.scalar
    if elem_ty.is_floating():
        # float >= float
        emit = builder.create_fcmpOGE
    elif elem_ty.is_int():
        # int >= int: signed or unsigned predicate depending on the operand type
        emit = builder.create_icmpSGE if elem_ty.is_int_signed() else builder.create_icmpUGE
    else:
        raise TypeError(f"unexpected type {elem_ty}")
    return tl.tensor(emit(lhs.handle, rhs.handle), _bool_like(lhs))
440
+
441
+
442
def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Emit the comparison `input < other`.

    The result type is `_bool_like` of the type-checked left operand.

    Raises:
        TypeError: if the operands' scalar type is neither floating nor integer.
    """
    lhs, rhs = binary_op_type_checking_impl(input, other, builder)
    elem_ty = lhs.type.scalar
    if elem_ty.is_floating():
        # float < float
        emit = builder.create_fcmpOLT
    elif elem_ty.is_int():
        # int < int: signed or unsigned predicate depending on the operand type
        emit = builder.create_icmpSLT if elem_ty.is_int_signed() else builder.create_icmpULT
    else:
        raise TypeError(f"unexpected type {elem_ty}")
    return tl.tensor(emit(lhs.handle, rhs.handle), _bool_like(lhs))
455
+
456
+
457
def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Emit the comparison `input <= other`.

    The result type is `_bool_like` of the type-checked left operand.

    Raises:
        TypeError: if the operands' scalar type is neither floating nor integer.
    """
    lhs, rhs = binary_op_type_checking_impl(input, other, builder)
    elem_ty = lhs.type.scalar
    if elem_ty.is_floating():
        # float <= float
        emit = builder.create_fcmpOLE
    elif elem_ty.is_int():
        # int <= int: signed or unsigned predicate depending on the operand type
        emit = builder.create_icmpSLE if elem_ty.is_int_signed() else builder.create_icmpULE
    else:
        raise TypeError(f"unexpected type {elem_ty}")
    return tl.tensor(emit(lhs.handle, rhs.handle), _bool_like(lhs))
470
+
471
+
472
def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Emit the comparison `input == other`.

    The result type is `_bool_like` of the type-checked left operand.

    Raises:
        TypeError: if the operands' scalar type is neither floating nor integer.
    """
    lhs, rhs = binary_op_type_checking_impl(input, other, builder)
    elem_ty = lhs.type.scalar
    if elem_ty.is_floating():
        # float == float
        emit = builder.create_fcmpOEQ
    elif elem_ty.is_int():
        # int == int
        emit = builder.create_icmpEQ
    else:
        raise TypeError(f"unexpected type {elem_ty}")
    return tl.tensor(emit(lhs.handle, rhs.handle), _bool_like(lhs))
482
+
483
+
484
def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Emit the comparison `input != other`.

    The result type is `_bool_like` of the type-checked left operand.

    Raises:
        TypeError: if the operands' scalar type is neither floating nor integer.
    """
    lhs, rhs = binary_op_type_checking_impl(input, other, builder)
    elem_ty = lhs.type.scalar
    if elem_ty.is_floating():
        # float != float (note: this one uses the UNE predicate, unlike the O* siblings)
        emit = builder.create_fcmpUNE
    elif elem_ty.is_int():
        # int != int
        emit = builder.create_icmpNE
    else:
        raise TypeError(f"unexpected type {elem_ty}")
    return tl.tensor(emit(lhs.handle, rhs.handle), _bool_like(lhs))
494
+
495
+
496
+ # ===----------------------------------------------------------------------===//
497
+ # Block Creation
498
+ # ===----------------------------------------------------------------------===//
499
+
500
+
501
def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
    """Create a 1-D int32 block of length `end - start` via `create_make_range(start, end)`.

    Raises:
        ValueError: if either bound is not a Python int, if `start >> 32` or
            `end >> 32` is non-zero, if `end <= start`, or if `end - start`
            is not a power of two.
    """
    if not (isinstance(start, int) and isinstance(end, int)):
        raise ValueError("arange's arguments must be of type tl.constexpr")
    if bool(start >> 32) or bool(end >> 32):
        raise ValueError("arange must fit in int32")
    if end <= start:
        raise ValueError("arange's end argument must be greater than the start argument")
    length = end - start
    # A power of two has a single set bit, so clearing the lowest set bit leaves zero.
    if length & (length - 1):
        raise ValueError("arange's range must be a power of 2")
    ret_ty = tl.block_type(tl.int32, [length])
    return tl.tensor(builder.create_make_range(start, end), ret_ty)
516
+
517
+
518
def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
    """Build a tensor of the given `shape` filled with `value`.

    `value` is either a size-1 tensor (cast to `dtype`) or a Python scalar
    (turned into a constant of `dtype`); the scalar is then splatted.

    Raises:
        ValueError: if `value` is not a tensor and `dtype` is None.
    """
    if isinstance(value, tl.tensor):
        assert value.numel.value == 1, "only accepts size-1 tensor"
        scalar = cast(value, dtype, builder)
        return splat(scalar, shape, builder)

    # Python scalar path.
    if dtype is None:
        raise ValueError("dtype must be specified when value is not a tensor")
    if value == 0:
        handle = builder.get_null_value(dtype.to_ir(builder))
    else:
        # Look up the builder's constant constructor by dtype name, e.g. `get_<name>`.
        make_constant = getattr(builder, f"get_{dtype.name}")
        handle = make_constant(value)
    return splat(tl.tensor(handle, dtype), shape, builder)
534
+
535
+
536
+ # ===----------------------------------------------------------------------===//
537
+ # Shape Manipulation
538
+ # ===----------------------------------------------------------------------===//
539
+
540
+
541
def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
    """Splat a non-block `value` to `shape`; an empty `shape` returns `value` unchanged."""
    assert not value.type.is_block(), "Cannot splat a block tensor"
    if len(shape) != 0:
        block_ty = tl.block_type(value.dtype, shape)
        return tl.tensor(builder.create_splat(value.handle, shape), block_ty)
    return value
547
+
548
+
549
def reshape(input: tl.tensor, dst_shape: List[int], can_reorder: bool, builder: ir.builder) -> tl.tensor:
    """Reshape `input` to `dst_shape` via `create_reshape`.

    Raises:
        ValueError: if the product of `dst_shape` differs from `input.type.numel`.
    """
    total = 1
    for extent in dst_shape:
        total = total * extent
    if input.type.numel != total:
        raise ValueError("reshape() cannot change total number of elements in tensor")
    out_ty = tl.block_type(input.type.scalar, dst_shape)
    handle = builder.create_reshape(input.handle, dst_shape, can_reorder)
    return tl.tensor(handle, out_ty)
557
+
558
+
559
def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    """Insert a size-1 dimension at `axis`.

    Block inputs go through `create_expand_dims`; non-block inputs are
    splatted to the new shape instead.
    """
    new_shape = list(map(tl._constexpr_to_value, input.shape))
    new_shape.insert(axis, 1)

    if input.type.is_block():
        out_ty = tl.block_type(input.type.scalar, new_shape)
        return tl.tensor(builder.create_expand_dims(input.handle, axis), out_ty)
    return splat(input, shape=new_shape, builder=builder)
568
+
569
+
570
def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor:
    """Concatenate two tensors via `create_cat`.

    `can_reorder` must be truthy and `lhs` must be 1-D (both asserted). The
    result is a 1-D block whose length is the sum of the inputs' first extents.
    """
    assert can_reorder, "current implementation of `cat` always may reorder elements"
    assert len(lhs.shape) == 1
    combined_len = lhs.shape[0] + rhs.shape[0]
    out_ty = tl.block_type(lhs.type.scalar, [combined_len])
    return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), out_ty)
575
+
576
+
577
def join(a: tl.tensor, b: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Join `a` and `b` along a new trailing dimension of size 2.

    The operands are broadcast against each other first. Two scalars are
    joined into a tensor of shape `[2]`.
    """
    a, b = broadcast_impl_value(a, b, builder)

    # The IR can't join two scalars directly: lift them to 1-D, join, then
    # reshape the result back down to shape [2].
    inputs_were_scalars = a.shape == []
    if inputs_were_scalars:
        a = expand_dims(a, 0, builder)
        b = expand_dims(b, 0, builder)

    # Keep the new extent in the same representation as the existing extents.
    two = tl.constexpr(2) if isinstance(a.shape[-1], tl.constexpr) else 2
    joined_ty = tl.block_type(a.type.scalar, a.shape + [two])
    joined = tl.tensor(builder.create_join(a.handle, b.handle), joined_ty)

    if inputs_were_scalars:
        return reshape(joined, [2], can_reorder=False, builder=builder)
    return joined
600
+
601
+
602
def split(a: tl.tensor, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
    """Split `a` along its last dimension (which must have size 2) into two tensors.

    Both results have `a`'s shape with the last dimension dropped.
    """
    assert (len(a.shape) > 0)
    assert (tl._constexpr_to_value(a.shape[-1]) == 2)

    half_ty = tl.block_type(a.type.scalar, a.shape[:-1])
    lhs_handle, rhs_handle = builder.create_split(a.handle)
    return tl.tensor(lhs_handle, half_ty), tl.tensor(rhs_handle, half_ty)
613
+
614
+
615
def permute(input: tl.tensor, dims: Tuple[int], builder: ir.builder) -> tl.tensor:
    """Permute the dimensions of `input` according to `dims` via `create_trans`.

    Raises:
        ValueError: if `dims` has a different length than the input shape, or
            is not a permutation of `0..n-1`.
    """
    if len(dims) != len(input.shape):
        raise ValueError("permute dims must have the same length as input shape")
    dim_values = sorted(tl._constexpr_to_value(d) for d in dims)
    if dim_values != list(range(len(dims))):
        raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}")

    permuted_shape = [input.shape[d] for d in dims]
    out_ty = tl.block_type(input.type.scalar, permuted_shape)
    return tl.tensor(builder.create_trans(input.handle, dims), out_ty)
623
+
624
+
625
def broadcast_impl_shape(input: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
    """Broadcast `input` to `shape`.

    Non-block inputs are splatted. Block inputs must have the same rank as
    `shape`, and every source dimension must either match the target or be 1.

    Raises:
        ValueError: on rank mismatch or an incompatible non-singleton dimension.
    """
    if not input.type.is_block():
        splat_ty = tl.block_type(input.type, shape)
        return tl.tensor(builder.create_splat(input.handle, shape), splat_ty)

    src_shape = input.type.get_block_shapes()
    if len(src_shape) != len(shape):
        raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
    if shape == src_shape:
        # Nothing to do.
        return input

    for i, (target, item) in enumerate(zip(shape, src_shape)):
        if target != item and item != 1:
            raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({target})"
                             f" must match the existing size ({item}) at non-singleton dimension"
                             f" {i}: {src_shape}, {shape}")

    out_ty = tl.block_type(input.type.scalar, shape)
    return tl.tensor(builder.create_broadcast(input.handle, shape), out_ty)
641
+
642
+
643
def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
    """Make `lhs` and `rhs` shape-compatible and return them as a pair.

    * block / scalar: the scalar is splatted to the block's shape.
    * block / block: the lower-rank operand gets leading size-1 axes, then
      each operand is broadcast to the common shape (per dimension: equal
      extents, or one of them is 1).
    * scalar / scalar: both are returned unchanged.

    Returns:
        The `(lhs, rhs)` pair after broadcasting. (The annotation previously
        claimed a single `tl.tensor`.)

    Raises:
        ValueError: if two block extents differ and neither is 1.
    """
    lhs_ty = lhs.type
    rhs_ty = rhs.type

    # make_shape_compatible(block, scalar)
    if lhs_ty.is_block() and not rhs_ty.is_block():
        rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape)
        rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty)
    # make_shape_compatible(scalar, block)
    elif not lhs_ty.is_block() and rhs_ty.is_block():
        lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape)
        lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty)
    # make_shape_compatible(block, block)
    elif lhs_ty.is_block() and rhs_ty.is_block():
        lhs_shape = lhs_ty.get_block_shapes()
        rhs_shape = rhs_ty.get_block_shapes()

        if len(lhs_shape) < len(rhs_shape):
            # Add new leading axes to lhs, one per missing rank.
            for _ in range(len(lhs_shape), len(rhs_shape)):
                lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0),
                                tl.block_type(lhs_ty.scalar, [1] + lhs_shape))
                lhs_ty = lhs.type
                lhs_shape = lhs_ty.get_block_shapes()
        elif len(rhs_shape) < len(lhs_shape):
            # Add new leading axes to rhs, one per missing rank.
            for _ in range(len(rhs_shape), len(lhs_shape)):
                rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0),
                                tl.block_type(rhs_ty.scalar, [1] + rhs_shape))
                rhs_ty = rhs.type
                rhs_shape = rhs_ty.get_block_shapes()
        assert len(rhs_shape) == len(lhs_shape)

        # Per-dimension common extent.
        ret_shape = []
        for i, left in enumerate(lhs_shape):
            right = rhs_shape[i]
            if left == 1:
                ret_shape.append(right)
            elif (right == 1) or (right == left):
                ret_shape.append(left)
            else:
                raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
                                 "at index " + str(i) + ": " + str(left) + " and " + str(right))
        # Only emit a broadcast for an operand whose shape actually changes.
        if lhs_shape != ret_shape:
            ret_ty = tl.block_type(lhs_ty.scalar, ret_shape)
            lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty)
        if rhs_shape != ret_shape:
            ret_ty = tl.block_type(rhs_ty.scalar, ret_shape)
            rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty)
    # (scalar, scalar) => returns original blocks
    return lhs, rhs
694
+
695
+
696
+ #######
697
+ # cast
698
+ #######
699
+
700
+
701
+ def _str_to_rounding_mode(rounding_mode: Optional[str]):
702
+ if rounding_mode is None:
703
+ return None
704
+ if rounding_mode == 'rtne':
705
+ return ir.ROUNDING_MODE.RTNE
706
+ if rounding_mode == 'rtz':
707
+ return ir.ROUNDING_MODE.RTZ
708
+ raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz'.")
709
+
710
+
711
def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor:
    """Reinterpret the bits of `input` as `dst_ty`.

    For block inputs the destination becomes a block of the same shape. If a
    pointer type is involved on either side the work is delegated to `cast`.

    Raises:
        ValueError: if source and destination scalar types differ in bit width.
    """
    src_ty = input.type
    if src_ty.is_block():
        dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
    if src_ty == dst_ty:
        return input

    src_elem = src_ty.scalar
    dst_elem = dst_ty.scalar
    if src_elem.is_ptr() or dst_elem.is_ptr():
        return cast(input, dst_ty, builder)

    # Plain bitcast: widths must agree.
    src_width = src_elem.primitive_bitwidth
    dst_width = dst_elem.primitive_bitwidth
    if src_width != dst_width:
        raise ValueError(f"Cannot bitcast data-type of size {src_width!s} to data-type of size {dst_width!s}")
    return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)
728
+
729
+
730
def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder,
         fp_downcast_rounding: Optional[str] = None) -> tl.tensor:
    """Convert `input` to `dst_ty`, choosing the IR op from the source/destination scalar types.

    `dst_ty` and `fp_downcast_rounding` may be wrapped in `tl.constexpr`. For
    block inputs the destination becomes a block of the same shape. The
    conversion kinds are tried in the order they appear below; the first match
    wins.

    Raises:
        ValueError: if `fp_downcast_rounding` is given for anything other than
            a narrowing float-to-float conversion (or is not a valid mode name).
        AssertionError: if no conversion rule applies.
    """
    src_ty = input.type
    if isinstance(dst_ty, tl.constexpr):
        dst_ty = dst_ty.value
    if isinstance(fp_downcast_rounding, tl.constexpr):
        fp_downcast_rounding = fp_downcast_rounding.value
    if src_ty.is_block():
        dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
    if src_ty == dst_ty:
        return input

    src_elem = src_ty.scalar
    dst_elem = dst_ty.scalar

    # Narrowing fp conversions default to RTNE; every other conversion must
    # leave the rounding mode unset.
    rounding = _str_to_rounding_mode(fp_downcast_rounding)
    use_custom_rounding = False
    is_fp_downcast = (dst_elem.is_floating() and src_elem.is_floating()
                      and dst_elem.primitive_bitwidth < src_elem.primitive_bitwidth)
    if is_fp_downcast:
        if rounding is None:
            rounding = ir.ROUNDING_MODE.RTNE
        elif rounding != ir.ROUNDING_MODE.RTNE:
            use_custom_rounding = True
    elif rounding is not None:
        raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. "
                         "Source scalar type is " + str(src_elem) + " and destination type is " + str(dst_elem))

    if src_elem.is_fp8e4nv() or dst_elem.is_fp8e4nv():
        assert builder.options.allow_fp8e4nv, "fp8e4nv data type is not supported on CUDA arch < 89"

    # fp8e4b15 is converted by a target-provided hook.
    if src_elem.is_fp8e4b15() or dst_elem.is_fp8e4b15():
        assert builder.codegen_fns.get(
            "convert_custom_types") is not None, "target doesn't provide conversion for this type."
        return builder.codegen_fns["convert_custom_types"](input, dst_ty, rounding, _builder=builder)

    # Customized floating types (fp8 <=> bf16, fp16, fp32, fp64) and
    # non-default rounding modes for downcasting.
    involves_fp8 = (src_elem.is_fp8() and dst_elem.is_floating()) or \
        (src_elem.is_floating() and dst_elem.is_fp8())
    if involves_fp8 or use_custom_rounding:
        return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), rounding), dst_ty)

    # fp16 / bf16 source with a non-fp32 destination: go through fp32.
    if (src_elem.is_fp16() or src_elem.is_bf16()) and not dst_elem.is_fp32():
        via_fp32 = cast(input, tl.float32, builder)
        return cast(via_fp32, dst_elem, builder)

    # Standard floating types: truncation
    # fp64 => fp32, fp16, bf16
    # fp32 => fp16, bf16
    if src_elem.is_floating() and dst_elem.is_floating() and \
            src_elem.primitive_bitwidth > dst_elem.primitive_bitwidth:
        return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty)

    # Standard floating types: extension
    # fp32 => fp64
    # fp16 => fp32, fp64
    # bf16 => fp32, fp64
    if src_elem.is_floating() and dst_elem.is_floating() and \
            src_elem.primitive_bitwidth < dst_elem.primitive_bitwidth:
        return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty)

    # Integer <=> integer (different width or signedness).
    if src_elem.is_int() and dst_elem.is_int() and \
            (src_elem.int_bitwidth != dst_elem.int_bitwidth or src_elem.int_signedness != dst_elem.int_signedness):
        sign_extend = src_elem.is_int_signed() and not src_elem.is_bool()
        if dst_elem.is_bool():
            # To bool: compare against zero.
            zero = tl.tensor(builder.get_null_value(input.dtype.to_ir(builder)), input.dtype)
            return not_equal(input, zero, builder)
        return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty)

    # Standard floating types => integer types.
    if src_elem.is_standard_floating() and dst_elem.is_int():
        if dst_elem.is_bool():
            # To bool: compare against zero.
            zero = tl.tensor(builder.get_null_value(input.dtype.to_ir(builder)), input.dtype)
            return not_equal(input, zero, builder)
        if dst_elem.is_int_signed():
            return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty)
        return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty)

    # Integer types => standard floating types (bool is treated as unsigned).
    if src_elem.is_int() and dst_elem.is_standard_floating():
        if src_elem.is_bool() or not src_elem.is_int_signed():
            return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty)
        return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty)

    # Pointer types => integer types (only 64-bit and 1-bit destinations are handled).
    if src_elem.is_ptr() and dst_elem.is_int():
        width = dst_elem.int_bitwidth
        if width == 64:
            return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty)
        if width == 1:
            return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder)

    # Integer types => pointer types.
    if src_elem.is_int() and dst_elem.is_ptr():
        return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty)

    # Pointer types => pointer types.
    if src_elem.is_ptr() and dst_elem.is_ptr():
        return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)

    assert False, f'cannot cast {input} to {dst_ty}'
842
+
843
+
844
+ # ===----------------------------------------------------------------------===//
845
+ # Memory Operators
846
+ # ===----------------------------------------------------------------------===//
847
+
848
+
849
def _str_to_load_cache_modifier(cache_modifier):
    """Map a load cache-modifier string ('.ca', '.cg') to `ir.CACHE_MODIFIER`.

    A falsy argument yields `ir.CACHE_MODIFIER.NONE`.

    Raises:
        ValueError: for any other non-empty value.
    """
    default = ir.CACHE_MODIFIER.NONE
    if not cache_modifier:
        return default
    if cache_modifier == ".ca":
        return ir.CACHE_MODIFIER.CA
    if cache_modifier == ".cg":
        return ir.CACHE_MODIFIER.CG
    raise ValueError(f"Cache modifier {cache_modifier} not supported")
859
+
860
+
861
def _str_to_store_cache_modifier(cache_modifier):
    """Map a store cache-modifier string ('.wb', '.cg', '.cs', '.wt') to `ir.CACHE_MODIFIER`.

    A falsy argument yields `ir.CACHE_MODIFIER.NONE`.

    Raises:
        ValueError: for any other non-empty value.
    """
    default = ir.CACHE_MODIFIER.NONE
    if not cache_modifier:
        return default
    if cache_modifier == ".wb":
        return ir.CACHE_MODIFIER.WB
    if cache_modifier == ".cg":
        return ir.CACHE_MODIFIER.CG
    if cache_modifier == ".cs":
        return ir.CACHE_MODIFIER.CS
    if cache_modifier == ".wt":
        return ir.CACHE_MODIFIER.WT
    raise ValueError(f"Cache modifier {cache_modifier} not supported")
875
+
876
+
877
def _str_to_eviction_policy(eviction_policy):
    """Map 'evict_last' / 'evict_first' to `ir.EVICTION_POLICY`.

    A falsy argument yields `ir.EVICTION_POLICY.NORMAL`.

    Raises:
        ValueError: for any other non-empty value.
    """
    default = ir.EVICTION_POLICY.NORMAL
    if not eviction_policy:
        return default
    if eviction_policy == "evict_last":
        return ir.EVICTION_POLICY.EVICT_LAST
    if eviction_policy == "evict_first":
        return ir.EVICTION_POLICY.EVICT_FIRST
    raise ValueError(f"Eviction policy {eviction_policy} not supported")
887
+
888
+
889
+ def _str_to_padding_option(padding_option):
890
+ padding = None # default
891
+ if padding_option:
892
+ if padding_option == "zero":
893
+ padding = ir.PADDING_OPTION.PAD_ZERO
894
+ elif padding_option == "nan":
895
+ padding = ir.PADDING_OPTION.PAD_NAN
896
+ else:
897
+ raise ValueError(f"Padding option {padding_option} not supported")
898
+ return padding
899
+
900
+
901
def _str_to_sem(sem_option):
    """Map 'acquire' / 'release' / 'acq_rel' / 'relaxed' to `ir.MEM_SEMANTIC`.

    A falsy argument yields `ir.MEM_SEMANTIC.ACQUIRE_RELEASE`.

    Raises:
        ValueError: for any other non-empty value.
    """
    default = ir.MEM_SEMANTIC.ACQUIRE_RELEASE
    if not sem_option:
        return default
    if sem_option == "acquire":
        return ir.MEM_SEMANTIC.ACQUIRE
    if sem_option == "release":
        return ir.MEM_SEMANTIC.RELEASE
    if sem_option == "acq_rel":
        return ir.MEM_SEMANTIC.ACQUIRE_RELEASE
    if sem_option == "relaxed":
        return ir.MEM_SEMANTIC.RELAXED
    raise ValueError(f"Memory semantic {sem_option} not supported")
915
+
916
+
917
def _str_to_scope(scope_option):
    """Map 'gpu' / 'cta' / 'sys' to `ir.MEM_SYNC_SCOPE`.

    A falsy argument yields `ir.MEM_SYNC_SCOPE.GPU`.

    Raises:
        ValueError: for any other non-empty value. The message now names the
            sync scope; it previously said "Memory semantic", copied from
            `_str_to_sem`.
    """
    scope = ir.MEM_SYNC_SCOPE.GPU
    if scope_option:
        if scope_option == "gpu":
            scope = ir.MEM_SYNC_SCOPE.GPU
        elif scope_option == "cta":
            scope = ir.MEM_SYNC_SCOPE.CTA
        elif scope_option == "sys":
            scope = ir.MEM_SYNC_SCOPE.SYSTEM
        else:
            raise ValueError(f"Memory sync scope {scope_option} not supported")
    return scope
929
+
930
+
931
+ def _canonicalize_boundary_check(boundary_check, block_shape):
932
+ if boundary_check:
933
+ if not hasattr(boundary_check, "__iter__"):
934
+ boundary_check = [boundary_check]
935
+ boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check]
936
+ for dim in boundary_check:
937
+ assert isinstance(dim, int) and 0 <= dim < len(block_shape)
938
+ assert len(boundary_check) > 0
939
+ assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`"
940
+ return sorted(boundary_check)
941
+ return ()
942
+
943
+
944
def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
    """Load through a block pointer (`pointer_type<block_type<>>`).

    The result has the pointer's de-referenced (block) type.

    Raises:
        ValueError: if `mask` or `other` is given, or if NaN padding is
            requested for an integer element type.
    """
    # Block pointers accept neither `mask` nor `other`.
    if not (mask is None and other is None):
        raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")

    # De-referenced type of the pointer, and its element type.
    block_ty = ptr.type.element_ty
    elem_ty = block_ty.element_ty
    assert elem_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`"
    if elem_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN:
        raise ValueError("Padding option `nan` is not supported for integer block pointers")

    # Validate and normalize the dimensions to boundary-check.
    checked_dims = _canonicalize_boundary_check(boundary_check, block_ty.get_block_shapes())

    # Build IR
    handle = builder.create_tensor_pointer_load(ptr.handle, checked_dims, padding, cache, eviction, is_volatile)
    return tl.tensor(handle, block_ty)
964
+
965
+
966
def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
    """Load through a tensor of pointers or a scalar pointer.

    Handles `block_type<pointer_type<>>` and `pointer_type<>`. For block
    pointers, `mask` and `other` are broadcast to the pointer's shape.
    Pointers to `tl.int1` are treated as pointers to `tl.int8`, and `other`
    is cast to the element type.

    Changes from the previous version: the `padding_option`/`boundary_check`
    error message had "of" and "pointers" fused by implicit string
    concatenation; and `mask` / `other` presence is tested with
    `is not None` rather than by truthiness of a tensor object.

    Raises:
        ValueError: if `ptr` is not pointer-typed, `other` is given without
            `mask`, `padding`/`boundary_check` is given, or a block
            `mask`/`other` is used with a scalar pointer.
    """
    if not ptr.type.scalar.is_ptr():
        raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`")

    # Check `mask`, `other`, `boundary_check`, and `padding` arguments
    if mask is None and other is not None:
        raise ValueError("`other` cannot be provided without `mask`")
    if padding or boundary_check:
        raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of "
                         "pointers or loading a scalar. Because the compiler does not know the boundary; please "
                         "use block pointers (defined by `make_block_ptr`) instead")

    if ptr.type.is_block():
        # Make `mask` and `other` into the same shape as `ptr`
        if mask is not None:
            mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
        if other is not None:
            other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder)
    else:
        # For a pointer of scalar, `mask` and `other` must be scalars too
        if mask is not None and mask.type.is_block():
            raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
        if other is not None and other.type.is_block():
            raise ValueError("Other argument cannot be block type if pointer argument is not a block")

    # Get `pointer_type<elt_ty>` and `elt_ty`
    ptr_ty = ptr.type.scalar
    elt_ty = ptr_ty.element_ty

    # Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
    if elt_ty == tl.int1:
        elt_ty = tl.int8
        ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
        ptr = cast(ptr, ptr_ty, builder)

    # Cast `other` into `elt_ty` type
    if other is not None:
        other = cast(other, elt_ty, builder)

    # Create loaded result type `dst_ty`
    if ptr.type.is_block():
        dst_ty = tl.block_type(elt_ty, ptr.type.get_block_shapes())
    else:
        # Load by de-referencing the pointer of scalar
        dst_ty = elt_ty

    # Build IR
    if mask is None:
        return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty)
    other_handle = other.handle if other is not None else None
    return tl.tensor(
        builder.create_masked_load(ptr.handle, mask.handle, other_handle, cache, eviction, is_volatile), dst_ty)
1022
+
1023
+
1024
def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple,
         padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool,
         builder: ir.builder) -> tl.tensor:
    """Entry point for loads: decode the string options, then dispatch on the pointer kind.

    Block pointers (`pointer_type<block_type<>>`) go to `_load_block_pointer`;
    tensors of pointers and scalar pointers go to `_load_legacy`.
    """
    # Cache, eviction and padding options
    cache = _str_to_load_cache_modifier(cache_modifier)
    eviction = _str_to_eviction_policy(eviction_policy)
    padding = _str_to_padding_option(padding_option)

    is_block_ptr = ptr.type.is_ptr() and ptr.type.element_ty.is_block()
    loader = _load_block_pointer if is_block_ptr else _load_legacy
    return loader(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder)
1038
+
1039
+
1040
def descriptor_load(desc_ptr: tl.tensor, offsets, cache_modifier: str, eviction_policy: str, type,
                    builder: ir.builder) -> tl.tensor:
    """Emit `create_descriptor_load` for `desc_ptr` at `offsets`; the result has type `type`."""
    ir_offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
    handle = builder.create_descriptor_load(
        desc_ptr.handle,
        ir_offsets,
        type.to_ir(builder),
        _str_to_load_cache_modifier(cache_modifier),
        _str_to_eviction_policy(eviction_policy),
    )
    return tl.tensor(handle, type)
1047
+
1048
+
1049
def descriptor_store(desc_ptr: tl.tensor, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
    """Emit `create_descriptor_store` of `value` through `desc_ptr` at `offsets`; returns a void tensor."""
    ir_offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
    handle = builder.create_descriptor_store(desc_ptr.handle, value.handle, ir_offsets)
    return tl.tensor(handle, tl.void)
1052
+
1053
+
1054
def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder):
    """Store `val` through a block pointer (`pointer_type<block_type<>>`).

    A non-block `val` is broadcast to the pointer's block shape. Shape and
    element type are then asserted to match the pointer's block type.

    Raises:
        ValueError: if `mask` is given. The message previously mentioned
            "`other`" and "loading", copied from the load path.
    """
    # Block pointers can not have the `mask` argument
    if mask is not None:
        raise ValueError("`mask` argument cannot be specified for storing block pointers")

    # Check same shape and element type
    block_shape = ptr.type.element_ty.get_block_shapes()
    if not val.type.is_block():
        val = broadcast_impl_shape(val, block_shape, builder)
    assert val.type.is_block(), "Value argument must be block type or a scalar"
    assert block_shape == val.type.get_block_shapes(
    ), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch"
    assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch"

    elt_ty = ptr.type.element_ty.element_ty
    assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`"

    # Check `boundary_check` argument
    boundary_check = _canonicalize_boundary_check(boundary_check, block_shape)

    # Cast to target data type
    val = cast(val, elt_ty, builder)

    # Build IR
    return tl.tensor(builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction),
                     tl.void)
1081
+
1082
+
1083
def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder):
    """Store `val` through a tensor of pointers or a scalar pointer.

    Handles `block_type<pointer_type<>>` and `pointer_type<>`. For block
    pointers, `val` and `mask` are broadcast to the pointer's shape. Pointers
    to `tl.int1` are treated as pointers to `tl.int8`, and `val` is cast to
    the element type.

    Change from the previous version: whether a mask was supplied is decided
    with `mask is None` (as `_load_legacy` does) instead of by truthiness of
    a tensor object.

    Raises:
        ValueError: if `ptr` is not pointer-typed, `boundary_check` is given,
            a block `val`/`mask` is used with a scalar pointer, or the mask's
            scalar type is not boolean.
    """
    if not ptr.type.scalar.is_ptr():
        raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`")

    # Check `boundary_check` argument
    if boundary_check:
        raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a "
                         "scalar. Because the compiler does not know the boundary; please use block pointers "
                         "(defined by `make_block_ptr`) instead")

    if ptr.type.is_block():
        # Make `mask` and `val` into the same shape as `ptr`
        val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
        if mask is not None:
            mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
    else:
        # For a pointer of scalar, `val` and `mask` must be scalars too
        if val.type.is_block():
            raise ValueError("Value argument cannot be block type if pointer argument is not a block")
        if mask is not None and mask.type.is_block():
            raise ValueError("Mask argument cannot be block type if pointer argument is not a block")

    ptr_ty = ptr.type.scalar
    elt_ty = ptr_ty.element_ty

    # Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
    if elt_ty == tl.int1:
        elt_ty = tl.int8
        ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
        ptr = cast(ptr, ptr_ty, builder)

    # Cast to target data type
    val = cast(val, elt_ty, builder)

    # Build IR
    if mask is None:
        return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void)
    if not mask.type.scalar.is_bool():
        raise ValueError("Mask must have boolean scalar type")
    return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void)
1125
+
1126
+
1127
def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_check, cache_modifier: str,
          eviction_policy: str, builder: ir.builder) -> tl.tensor:
    """Dispatch a store to the block-pointer or the legacy implementation.

    The cache modifier and eviction policy strings are translated first; a
    constant destination pointer is rejected with `ValueError`. A
    `pointer_type<block_type<>>` destination goes to `_store_block_pointer`,
    everything else to `_store_legacy`.
    """
    # Cache and eviction options
    cache_opt = _str_to_store_cache_modifier(cache_modifier)
    eviction_opt = _str_to_eviction_policy(eviction_policy)

    dst_ty = ptr.type
    if dst_ty.is_const() or dst_ty.scalar.is_const():
        raise ValueError("Cannot store to a constant pointer")

    # Block pointer: `pointer_type<block_type<>>`.
    # Otherwise: `block_type<pointer_type<>>` or `pointer_type<>`.
    is_block_pointer = dst_ty.is_ptr() and dst_ty.element_ty.is_block()
    store_impl = _store_block_pointer if is_block_pointer else _store_legacy
    return store_impl(ptr, val, mask, boundary_check, cache_opt, eviction_opt, builder)
1142
+
1143
+
1144
+ #########
1145
+ # atomic
1146
+ #########
1147
+
1148
+
1149
def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
    """Emit an atomic compare-and-swap on `ptr`.

    The `sem` and `scope` strings are converted via `_str_to_sem` and
    `_str_to_scope`. The pointee's `primitive_bitwidth` must be 16, 32 or 64,
    otherwise `ValueError` is raised. The result tensor carries `val.type`.
    """
    mem_sem = _str_to_sem(sem)
    mem_scope = _str_to_scope(scope)
    pointee_ty = ptr.type.scalar.element_ty
    if pointee_ty.primitive_bitwidth not in (16, 32, 64):
        raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
    cas_handle = builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, mem_sem, mem_scope)
    return tl.tensor(cas_handle, val.type)
1156
+
1157
+
1158
def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str,
                               builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
    """Validate and normalise the operands of an atomic read-modify-write.

    Checks that `ptr` is a non-constant pointer (or block of pointers) and
    that its pointee type is allowed for `op`. For block pointers, `mask` and
    `val` are broadcast to the pointer's block shape. `val` is cast to the
    pointee type, and a missing mask is replaced by an all-true `int1` mask
    (splatted to the block shape for block pointers).

    Args:
        op: operation suffix used in error messages and for the fp16 rule
            (only `'add'` is accepted for fp16 pointees).

    Returns:
        The `(ptr, val, mask)` triple with `val` and `mask` normalised.

    Raises:
        ValueError: on a non-pointer or constant `ptr`, or on an unsupported
            pointee type.
    """
    if not ptr.type.scalar.is_ptr():
        raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
    if ptr.type.is_const() or ptr.type.element_ty.is_const():
        raise ValueError("Cannot store to a constant pointer")
    element_ty = ptr.type.scalar.element_ty
    # Compare dtypes by equality, matching the `in [...]` check below; an
    # identity test misses an equal dtype object that is a distinct instance.
    if element_ty == tl.float16 and op != 'add':
        raise ValueError("atomic_" + op + " does not support fp16")
    if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]:
        raise ValueError("atomic_" + op + " does not support " + str(element_ty))
    if ptr.type.is_block():
        if mask is not None:
            mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
        if val is not None:
            val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
    val = cast(val, element_ty, builder)
    # Only an absent mask gets the default; a falsy mask object must not be
    # silently replaced by an all-true mask.
    if mask is None:
        mask_ir = builder.get_int1(True)
        mask_ty = tl.int1
        if ptr.type.is_block():
            mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes())
            mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes())
        mask = tl.tensor(mask_ir, mask_ty)
    return ptr, val, mask
1183
+
1184
+
1185
def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
    """Emit an atomic maximum of `val` into `ptr`.

    Integer values use a single RMW op (`MAX` when signed, `UMAX` otherwise).
    For float32/float64 the value and pointer are bitcast to same-width
    integers and two masked RMW ops are issued: a signed `MAX` where
    `val >= 0` and an unsigned `UMIN` where `val < 0`. The two results are
    merged with `where` and bitcast back to the float type.

    Raises:
        TypeError: for non-integer scalar types other than float32/float64.
    """
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
    sem = _str_to_sem(sem)
    scope = _str_to_scope(scope)
    scalar_ty = val.type.scalar

    # direct call to atomic_max for integers
    if scalar_ty.is_int():
        int_op = ir.ATOMIC_OP.MAX if scalar_ty.is_int_signed() else ir.ATOMIC_OP.UMAX
        int_rmw = builder.create_atomic_rmw(int_op, ptr.handle, val.handle, mask.handle, sem, scope)
        return tl.tensor(int_rmw, val.type)

    # for float
    # return atomic_smax(i_ptr, i_val) if val >= 0
    # return atomic_umin(i_ptr, i_val) if val < 0
    if scalar_ty not in {tl.float32, tl.float64}:
        raise TypeError(f"atomic_max not supported for dtype {scalar_ty}")

    zero = full([], 0.0, scalar_ty, builder)

    # Signed integer view of the value and the pointer.
    signed_ty = tl.int32 if scalar_ty == tl.float32 else tl.int64
    signed_val = bitcast(val, signed_ty, builder)
    signed_ptr = bitcast(ptr, tl.pointer_type(signed_ty, 1), builder)

    # Unsigned integer view of the value and the pointer.
    unsigned_ty = tl.uint32 if scalar_ty == tl.float32 else tl.uint64
    unsigned_val = bitcast(val, unsigned_ty, builder)
    unsigned_ptr = bitcast(ptr, tl.pointer_type(unsigned_ty, 1), builder)

    is_non_negative = greater_equal(val, zero, builder)
    is_negative = less_than(val, zero, builder)

    non_negative_mask = and_(mask, is_non_negative, builder)
    non_negative_ret = tl.tensor(
        builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, signed_ptr.handle, signed_val.handle, non_negative_mask.handle,
                                  sem, scope), signed_val.type)

    negative_mask = and_(mask, is_negative, builder)
    negative_ret = tl.tensor(
        builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, unsigned_ptr.handle, unsigned_val.handle, negative_mask.handle,
                                  sem, scope), unsigned_val.type)

    merged = where(is_non_negative, non_negative_ret, negative_ret, builder)
    return bitcast(merged, scalar_ty, builder)
1222
+
1223
+
1224
def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
    """Emit an atomic minimum of `val` into `ptr`.

    Integer values use a single RMW op (`MIN` when signed, `UMIN` otherwise).
    For float32/float64 the value and pointer are bitcast to same-width
    integers and two masked RMW ops are issued: a signed `MIN` where
    `val >= 0` and an unsigned `UMAX` where `val < 0`. The two results are
    merged with `where` and bitcast back to the float type.

    Raises:
        TypeError: for non-integer scalar types other than float32/float64.
    """
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
    sem = _str_to_sem(sem)
    scope = _str_to_scope(scope)
    sca_ty = val.type.scalar
    # direct call to atomic_min for integers
    if sca_ty.is_int():
        if sca_ty.is_int_signed():
            return tl.tensor(
                builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
        else:
            return tl.tensor(
                builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
    # for float
    # return atomic_smin(i_ptr, i_val) if val >= 0
    # return atomic_umax(i_ptr, i_val) if val < 0
    if sca_ty not in {tl.float32, tl.float64}:
        raise TypeError(f"atomic_min not supported for dtype {sca_ty}")

    zero = full([], 0.0, sca_ty, builder)

    i_type = tl.int32 if sca_ty == tl.float32 else tl.int64
    i_val = bitcast(val, i_type, builder)
    i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder)
    ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64
    ui_val = bitcast(val, ui_type, builder)
    ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder)
    pos = greater_equal(val, zero, builder)
    neg = less_than(val, zero, builder)
    pos_ret = tl.tensor(
        builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle,
                                  and_(mask, pos, builder).handle, sem, scope), i_val.type)
    # The RMW result is typed after the *value* operand (as in `atomic_max`),
    # not after the pointer it was applied to.
    neg_ret = tl.tensor(
        builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle,
                                  and_(mask, neg, builder).handle, sem, scope), ui_val.type)
    ret = where(pos, pos_ret, neg_ret, builder)
    return bitcast(ret, sca_ty, builder)
1261
+
1262
+
1263
def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
    """Emit an atomic add of `val` into `ptr`.

    Operands are normalised by `atom_red_typechecking_impl`. The RMW op is
    `FADD` when the value's scalar type is floating and `ADD` otherwise. The
    result tensor carries `val.type`.
    """
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
    mem_sem = _str_to_sem(sem)
    mem_scope = _str_to_scope(scope)
    if val.type.scalar.is_floating():
        rmw_op = ir.ATOMIC_OP.FADD
    else:
        rmw_op = ir.ATOMIC_OP.ADD
    rmw_handle = builder.create_atomic_rmw(rmw_op, ptr.handle, val.handle, mask.handle, mem_sem, mem_scope)
    return tl.tensor(rmw_handle, val.type)
1270
+
1271
+
1272
def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
    """Emit an atomic `AND` RMW of `val` into `ptr`; the result carries `val.type`."""
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
    mem_sem = _str_to_sem(sem)
    mem_scope = _str_to_scope(scope)
    rmw_handle = builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, mem_sem,
                                           mem_scope)
    return tl.tensor(rmw_handle, val.type)
1278
+
1279
+
1280
def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
    """Emit an atomic `OR` RMW of `val` into `ptr`; the result carries `val.type`."""
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
    mem_sem = _str_to_sem(sem)
    mem_scope = _str_to_scope(scope)
    rmw_handle = builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, mem_sem,
                                           mem_scope)
    return tl.tensor(rmw_handle, val.type)
1286
+
1287
+
1288
def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
    """Emit an atomic `XOR` RMW of `val` into `ptr`; the result carries `val.type`."""
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
    mem_sem = _str_to_sem(sem)
    mem_scope = _str_to_scope(scope)
    rmw_handle = builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, mem_sem,
                                           mem_scope)
    return tl.tensor(rmw_handle, val.type)
1294
+
1295
+
1296
def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str,
                builder: ir.builder) -> tl.tensor:
    """Emit an atomic `XCHG` RMW of `val` into `ptr`; the result carries `val.type`."""
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
    mem_sem = _str_to_sem(sem)
    mem_scope = _str_to_scope(scope)
    rmw_handle = builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, mem_sem,
                                           mem_scope)
    return tl.tensor(rmw_handle, val.type)
1303
+
1304
+
1305
+ # ===----------------------------------------------------------------------===//
1306
+ # Linear Algebra
1307
+ # ===----------------------------------------------------------------------===//
1308
+
1309
+
1310
def _str_to_dot_input_precision(input_precision, builder):
    """Map an input-precision string to the matching `ir.INPUT_PRECISION` member.

    The lower-cased name must be listed in
    `builder.options.allowed_dot_input_precisions`. The member is looked up by
    the upper-cased name, except that `"TF32X3"` is spelled `"TF32x3"`.
    """
    lowered = input_precision.lower()
    assert lowered in builder.options.allowed_dot_input_precisions, \
        f"input_precision must be one of {builder.options.allowed_dot_input_precisions}. Got {input_precision}"
    member_name = input_precision.upper()
    # The enum member keeps a lower-case "x" in its name.
    member_name = "TF32x3" if member_name == "TF32X3" else member_name
    return getattr(ir.INPUT_PRECISION, member_name)
1317
+
1318
+
1319
def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int,
        out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
    """Emit a (optionally batched) matrix product of two block tensors.

    Both operands must be blocks of equal rank 2 or 3 whose inner dimensions
    agree, with every non-batch dimension >= 16. Operand dtypes are validated
    against `builder.options`; fp8e4b15 operands are cast to float16 first.

    Args:
        acc: optional accumulator; when `None` a zero splat of the result
            shape is used, otherwise its type must equal the result type.
        input_precision: precision name, or `None` to use
            `builder.options.default_dot_input_precision`.
        max_num_imprecise_acc: when `None`, defaults to
            `builder.options.max_num_imprecise_acc_default` for fp8 x fp8
            and to 0 otherwise.
        out_dtype: requested result scalar type (bfloat16 is rejected for
            non-integer operands).

    Returns:
        A block tensor of shape `[M, N]` or `[B, M, N]`.

    Raises:
        ValueError: if `out_dtype` is bfloat16 for non-integer operands.
        AssertionError: on unsupported dtypes or incompatible shapes.
    """

    def assert_dtypes_valid(lhs_dtype, rhs_dtype, options):
        if not options.allow_fp8e4nv:
            assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(
            ), "Dot op does not support fp8e4nv on CUDA arch < 90"
            if lhs_dtype.is_fp8() and rhs_dtype.is_fp8():
                return
            assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
        else:
            if lhs_dtype.is_int() or rhs_dtype.is_int():
                assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})"
                assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(
                ), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})"
            elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8():
                if options.allow_fp8e4b15:
                    allowed_types = ['fp8e4nv', 'fp8e5', 'fp8e4b15']
                else:
                    allowed_types = ['fp8e4nv', 'fp8e5']

                def _validate_dtype(dtype, allowed_types, operand_name):
                    if not any(getattr(dtype, f'is_{dtype_name}')() for dtype_name in allowed_types):
                        supported_types = ', '.join(allowed_types)
                        raise AssertionError(f"Only supports {supported_types}. {operand_name} ({dtype})")

                _validate_dtype(lhs_dtype, allowed_types, "First operand")
                _validate_dtype(rhs_dtype, allowed_types, "Second operand")
            else:
                assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(
                ), f"Unsupported dtype {lhs_dtype}"
                assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(
                ), f"Unsupported dtype {rhs_dtype}"
                assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"

    assert lhs.type.is_block() and rhs.type.is_block()
    assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.options)
    if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():
        lhs = cast(lhs, tl.float16, builder)
        rhs = cast(rhs, tl.float16, builder)

    if input_precision is None:
        input_precision = builder.options.default_dot_input_precision

    input_precision = _str_to_dot_input_precision(input_precision, builder)

    lhs_rank = len(lhs.shape)
    rhs_rank = len(rhs.shape)
    assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
    assert lhs.shape[-1].value == rhs.shape[
        -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})"
    assert lhs.shape[-2].value >= 16 and lhs.shape[-1].value >= 16 \
        and rhs.shape[-1].value >= 16, \
        f"All non-batch values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
    if lhs.type.scalar.is_int():
        assert lhs.type.scalar == tl.int8, "only int8 supported!"
        # TODO: This is CUDA specific, check if ROCm has the same limitation
        # Index the reduction (K) dimension from the end so the check also
        # targets K, not M, when the operands carry a leading batch dimension.
        assert lhs.shape[-1].value >= 32, "small blocks not supported!"
        _0 = builder.get_int32(0)
        ret_scalar_ty = tl.int32
    elif out_dtype.is_bf16():
        raise ValueError(
            "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
    elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
        _0 = builder.get_fp32(0)
        ret_scalar_ty = tl.float32
    else:
        _0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0)
        ret_scalar_ty = out_dtype

    M = lhs.type.shape[-2]
    N = rhs.type.shape[-1]
    B = lhs.type.shape[0] if lhs_rank == 3 else None
    # Result shape, shared by the result type and the default accumulator.
    ret_shape = [B, M, N] if B else [M, N]
    ret_ty = tl.block_type(ret_scalar_ty, ret_shape)
    if acc is None:
        acc_handle = builder.create_splat(_0, ret_shape)
    else:
        acc_handle = acc.handle
        assert acc.type == ret_ty

    # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
    if max_num_imprecise_acc is None:
        if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
            max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default
        else:
            max_num_imprecise_acc = 0

    return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc),
                     ret_ty)
1408
+
1409
+
1410
+ # ===----------------------------------------------------------------------===//
1411
+ # Indexing
1412
+ # ===----------------------------------------------------------------------===//
1413
+
1414
+
1415
def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Emit an element-wise select between `x` and `y` driven by `condition`.

    `condition` is cast to `tl.int1`. The operands are broadcast against each
    other and `x`/`y` are passed through `binary_op_type_checking_impl`; the
    result tensor carries the type of `x` after that step.
    """
    condition = cast(condition, tl.int1, builder)
    if condition.type.is_block():
        # Broadcast pairwise; the condition/x pair is repeated after x has
        # been broadcast against y.
        condition, x = broadcast_impl_value(condition, x, builder)
        x, y = broadcast_impl_value(x, y, builder)
        condition, x = broadcast_impl_value(condition, x, builder)

    x, y = binary_op_type_checking_impl(x, y, builder, True, True)
    if not condition.type.is_block():
        # A non-block condition is broadcast against x only after type checking.
        condition, _ = broadcast_impl_value(condition, x, builder)
    ret_ty = x.type
    return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
1427
+
1428
+
1429
+ # ===----------------------------------------------------------------------===//
1430
+ # Reduction
1431
+ # ===----------------------------------------------------------------------===
1432
+
1433
+
1434
def wrap_tensor(x, scalar_ty, ret_shape):
    """Wrap IR value `x` in a `tl.tensor` of the given scalar type and shape.

    A non-empty `ret_shape` yields a block type; an empty one (0d-tensor)
    yields the bare scalar type.
    """
    # 0d-tensor -> scalar
    result_ty = tl.block_type(scalar_ty, ret_shape) if ret_shape else scalar_ty
    return tl.tensor(x, result_ty)
1441
+
1442
+
1443
def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]:
    """Emit a reduce op over `axis` for one or more same-shaped inputs.

    Args:
        inputs: tensors to reduce together; all must share one shape.
        axis: dimension to reduce. `None` flattens every input to 1D
            (reordering allowed) and reduces axis 0. Negative values count
            from the last dimension.
        region_builder_fn: callback that receives the created reduce op and
            populates its combine region.

    Returns:
        One result tensor per input, with the reduced dimension removed
        (a scalar type when no dimensions remain).
    """
    if axis is None:
        inputs = tuple(reshape(t, [t.numel.value], can_reorder=True, builder=builder) for t in inputs)
        axis = 0
    # get result shape
    shape = inputs[0].type.shape
    rank = len(shape)
    assert -rank <= axis < rank, f"reduction axis must be < inputs rank ({rank})"
    # Normalise a negative axis (as `associative_scan` does) so the
    # `i != axis` filter below actually removes the reduced dimension.
    if axis < 0:
        axis += rank
    ret_shape = [s for i, s in enumerate(shape) if i != axis]
    assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape"

    reduce_op = builder.create_reduce([t.handle for t in inputs], axis)
    region_builder_fn(reduce_op)
    reduce_op.verify()

    return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs)))
1459
+
1460
+
1461
+ # ===----------------------------------------------------------------------===
1462
+ # Associative Scan
1463
+ # ===----------------------------------------------------------------------===
1464
+
1465
+
1466
def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, reverse: bool,
                     builder: ir.builder) -> Tuple[tl.tensor, ...]:
    """Emit a scan op along `axis` for one or more same-shaped inputs.

    `axis` may be negative, in which case it is offset by the input rank.
    `region_builder_fn` receives the created scan op to populate its region.
    Returns one result per input, each with the inputs' shape.
    """
    common_shape = inputs[0].type.shape
    ndim = len(common_shape)

    assert -ndim <= axis < ndim, f"scan axis {axis} must be < inputs rank ({ndim})"

    scan_axis = axis + ndim if axis < 0 else axis

    assert all(t.type.shape == common_shape for t in inputs), "all scan inputs must have the same shape"

    handles = [t.handle for t in inputs]
    scan_op = builder.create_scan(handles, scan_axis, reverse)
    region_builder_fn(scan_op)
    scan_op.verify()

    return tuple(
        wrap_tensor(scan_op.get_result(idx), inputs[idx].type.scalar, common_shape) for idx in range(len(inputs)))
1484
+
1485
+
1486
+ # ===----------------------------------------------------------------------===
1487
+ # Histogram
1488
+ # ===----------------------------------------------------------------------===
1489
+
1490
+
1491
def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor:
    """Emit a histogram op over a 1D integer tensor.

    The result is a block of `tl.int32` with shape `(num_bins, )`.
    """
    assert len(input.shape) == 1, "histogram only supports 1D input"
    assert input.dtype.is_int(), "histogram only supports integer input"
    hist_handle = builder.create_histogram(input.handle, num_bins)
    hist_ty = tl.block_type(tl.int32, (num_bins, ))
    return tl.tensor(hist_handle, hist_ty)
1495
+
1496
+
1497
+ ##
1498
+
1499
+
1500
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
    """Attach a `tt.divisibility` attribute built from `values` to `x` and return `x`.

    `values` must have one entry per dimension of `x`, with a minimum of one
    entry (so a tensor with an empty shape takes a single value).
    """
    expected_len = max(1, len(x.shape))
    if expected_len != len(values):
        raise ValueError("Shape of input to multiple_of does not match the length of values")
    handle = x.handle
    handle.set_attr("tt.divisibility", ir.make_attr(values, handle.get_context()))
    return x
1505
+
1506
+
1507
def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
    """Attach a `tt.contiguity` attribute built from `values` to `x` and return `x`.

    `values` must have exactly one entry per dimension of `x`.
    """
    ndim = len(x.shape)
    if ndim != len(values):
        raise ValueError("Shape of input to max_contiguous does not match the length of values")
    handle = x.handle
    handle.set_attr("tt.contiguity", ir.make_attr(values, handle.get_context()))
    return x
1512
+
1513
+
1514
def max_constancy(x: tl.tensor, values: List[int]) -> tl.tensor:
    """Attach a `tt.constancy` attribute built from `values` to `x` and return `x`.

    `values` must have exactly one entry per dimension of `x`.
    """
    ndim = len(x.shape)
    if ndim != len(values):
        raise ValueError("Shape of input to max_constancy does not match the length of values")
    handle = x.handle
    handle.set_attr("tt.constancy", ir.make_attr(values, handle.get_context()))
    return x
1519
+
1520
+
1521
def debug_barrier(builder: ir.builder) -> tl.tensor:
    """Emit a barrier op and wrap it in a `tl.void` tensor."""
    barrier_handle = builder.create_barrier()
    return tl.tensor(barrier_handle, tl.void)
1523
+
1524
+
1525
def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.builder) -> tl.tensor:
    """Emit a print op for `args` with a normalised `prefix`.

    When there are arguments, the prefix is made to end in ": ". A prefix
    longer than two characters is made to start with a space. The result is
    a `tl.void` tensor wrapping the print op.
    """
    # It makes sense visually for prefix to end in ": "; make it so. Also,
    # non-empty prefixes should start with " ".
    if args:
        if not prefix.endswith(" "):
            prefix = prefix + " "
        if not prefix.endswith(": "):
            # Replace the trailing character with the ": " separator.
            prefix = prefix[:-1] + ": "
    if len(prefix) > 2 and not prefix.startswith(" "):
        prefix = " " + prefix

    arg_handles = [arg.handle for arg in args]
    print_handle = builder.create_print(prefix, hex, arg_handles)
    return tl.tensor(print_handle, tl.void)
1537
+
1538
+
1539
def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor:
    """Emit an assert op on `cond`, splatting a non-block condition to shape `(1, )` first.

    The result is a `tl.void` tensor wrapping the assert op.
    """
    if not cond.type.is_block():
        unit_shape = (1, )
        block_ty = tl.block_type(cond.type.scalar, unit_shape)
        cond = tl.tensor(builder.create_splat(cond.handle, unit_shape), block_ty)
    assert_handle = builder.create_assert(cond.handle, msg, file_name, func_name, lineno)
    return tl.tensor(assert_handle, tl.void)
1545
+
1546
+
1547
def _convert_elem_to_ir_value(builder, elem, require_i64):
    """Convert one shape/stride/offset element to an IR integer value.

    Args:
        elem: a Python `int`, a `tl.constexpr`, or a single-element integer
            `tl.tensor`.
        require_i64: when true the result is 64-bit (tensors of another
            integer type are cast); otherwise the result must be 32-bit.

    Returns:
        An IR value: a constant for `int`/`tl.constexpr` input, or the
        tensor's (possibly cast) handle.

    Raises:
        AssertionError: if a constant is out of range, a tensor is not a
            single integer element, a non-int32 tensor is given where 32 bits
            are required, or `elem` has an unsupported type.
    """
    if isinstance(elem, int):
        elem = tl.constexpr(elem)
    if isinstance(elem, tl.constexpr):
        if require_i64:
            assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \
                f"got a value {elem.value} which is out of the range"
            return builder.get_int64(elem.value)
        else:
            assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \
                f"got a value {elem.value} which is out of the range"
            return builder.get_int32(elem.value)
    elif isinstance(elem, tl.tensor):
        assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets"
        assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets"
        if elem.dtype != tl.int64 and require_i64:
            return builder.create_int_cast(elem.handle, builder.get_int64_ty(), elem.dtype.is_int_signed())
        elif elem.dtype != tl.int32 and not require_i64:
            # Raise explicitly: an `assert False` is stripped under `-O`, which
            # would let a non-int32 handle through unchecked.
            raise AssertionError("Block pointers only support 32 bit `offsets/block_shape`, "
                                 "add a `.to(tl.int32)` or use regular indexing for 64 bit support")
        return elem.handle
    # Raise explicitly so that an unsupported element never yields `None`
    # when assertions are disabled.
    raise AssertionError(f"Unsupported element type in shape/strides/offsets: {type(elem)}")
1569
+
1570
+
1571
def _convert_to_ir_values(builder, list_like, require_i64=True):
    """Convert an iterable (or a single element) to a list of IR integer values.

    An argument without `__iter__` is treated as a one-element list.
    """
    elems = list_like if hasattr(list_like, "__iter__") else [list_like]
    return [_convert_elem_to_ir_value(builder, item, require_i64) for item in elems]
1575
+
1576
+
1577
def make_block_ptr(base: tl.tensor, shape, strides, offsets, block_shape, order, builder: ir.builder) -> tl.tensor:
    """Build a block pointer from a scalar base pointer.

    `shape` and `strides` are converted to 64-bit IR values and `offsets` to
    32-bit IR values. `block_shape` must be constant integers in the
    `int32_t` range and `order` a permutation of `0..len(order)-1`; all five
    lists must have the same length. A `pointer_type<tl.int1>` base is cast
    to `pointer_type<tl.int8>`.

    Returns:
        A tensor of type `pointer_type<block_type<element_ty, block_shape>>`.

    Raises:
        ValueError: if `base` is not a pointer, or is already a block pointer.
    """

    def _unwrap_static(seq):
        # Accept a single element or an iterable; unwrap `tl.constexpr` items.
        if not hasattr(seq, "__iter__"):
            seq = [seq]
        return [item.value if isinstance(item, tl.constexpr) else item for item in seq]

    # Convert dynamic arguments to IR values
    # NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t`
    shape = _convert_to_ir_values(builder, shape)
    strides = _convert_to_ir_values(builder, strides)
    offsets = _convert_to_ir_values(builder, offsets, require_i64=False)

    # Check `base` type
    if not base.type.is_ptr() or base.type.element_ty.is_block():
        raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)")

    # Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
    if base.type.element_ty == tl.int1:
        int8_ptr_ty = tl.pointer_type(tl.int8, base.type.address_space)
        base = cast(base, int8_ptr_ty, builder)

    # Check whether `block_shape` is static
    block_shape = _unwrap_static(block_shape)
    assert all(isinstance(dim, int) and -2**31 <= dim < 2**31 for dim in block_shape), \
        "Expected a list of constant integers (`int32_t` range) in `block_shape`"

    # Check `order`
    order = _unwrap_static(order)
    assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order"

    # Must have same length
    ndim = len(block_shape)
    assert all(ndim == len(seq) for seq in (shape, strides, offsets, order)), \
        "Expected shape/strides/offsets/block_shape to have the same length"

    # Build value, the type is:
    #   `pointer_type<blocked<shape, element_type>>` in Python
    #   `tt.ptr<tensor<shape, element_type>>` in MLIR
    handle = builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order)
    block_ptr_ty = tl.pointer_type(tl.block_type(base.type.element_ty, block_shape))
    return tl.tensor(handle, block_ptr_ty)
1614
+
1615
+
1616
def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
    """Emit an advance op on block pointer `base` by `offsets`.

    `offsets` are converted to 32-bit IR values. The returned tensor has the
    same type as `base`.
    """
    # Convert dynamic offsets to IR values
    ir_offsets = _convert_to_ir_values(builder, offsets, require_i64=False)

    # Advanced block pointer type is the same as before
    advanced_handle = builder.create_advance(base.handle, ir_offsets)
    return tl.tensor(advanced_handle, base.type)