triton-windows 3.1.0.post17__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic; consult the registry's advisory page for more details.

Files changed (248) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +73 -0
  3. triton/backends/__init__.py +50 -0
  4. triton/backends/amd/compiler.py +262 -0
  5. triton/backends/amd/driver.c +211 -0
  6. triton/backends/amd/driver.py +497 -0
  7. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  25. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  26. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  27. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  28. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  31. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  32. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  40. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  41. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  42. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  43. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  44. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  45. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  46. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  48. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  49. triton/backends/amd/include/hip/device_functions.h +38 -0
  50. triton/backends/amd/include/hip/driver_types.h +468 -0
  51. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  52. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  53. triton/backends/amd/include/hip/hip_common.h +100 -0
  54. triton/backends/amd/include/hip/hip_complex.h +38 -0
  55. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  56. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  57. triton/backends/amd/include/hip/hip_ext.h +159 -0
  58. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  59. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  60. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  61. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  62. triton/backends/amd/include/hip/hip_profile.h +27 -0
  63. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  64. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  65. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  66. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  67. triton/backends/amd/include/hip/hip_version.h +17 -0
  68. triton/backends/amd/include/hip/hiprtc.h +421 -0
  69. triton/backends/amd/include/hip/library_types.h +78 -0
  70. triton/backends/amd/include/hip/math_functions.h +42 -0
  71. triton/backends/amd/include/hip/surface_types.h +63 -0
  72. triton/backends/amd/include/hip/texture_types.h +194 -0
  73. triton/backends/amd/include/hsa/Brig.h +1131 -0
  74. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  75. triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
  76. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  77. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  78. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  79. triton/backends/amd/include/hsa/hsa.h +5729 -0
  80. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  81. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  82. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  83. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  84. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  85. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  87. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  88. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  89. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  90. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  91. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  92. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  93. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  94. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  95. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  96. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  97. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  98. triton/backends/amd/include/roctracer/roctx.h +229 -0
  99. triton/backends/amd/lib/ockl.bc +0 -0
  100. triton/backends/amd/lib/ocml.bc +0 -0
  101. triton/backends/compiler.py +76 -0
  102. triton/backends/driver.py +34 -0
  103. triton/backends/nvidia/__init__.py +0 -0
  104. triton/backends/nvidia/bin/ptxas.exe +0 -0
  105. triton/backends/nvidia/compiler.py +347 -0
  106. triton/backends/nvidia/driver.c +451 -0
  107. triton/backends/nvidia/driver.py +430 -0
  108. triton/backends/nvidia/include/cuda.h +24359 -0
  109. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  110. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  111. triton/compiler/__init__.py +4 -0
  112. triton/compiler/code_generator.py +1302 -0
  113. triton/compiler/compiler.py +416 -0
  114. triton/compiler/errors.py +51 -0
  115. triton/compiler/make_launcher.py +0 -0
  116. triton/errors.py +5 -0
  117. triton/language/__init__.py +284 -0
  118. triton/language/core.py +2621 -0
  119. triton/language/extra/__init__.py +4 -0
  120. triton/language/extra/cuda/__init__.py +8 -0
  121. triton/language/extra/cuda/libdevice.py +1629 -0
  122. triton/language/extra/cuda/utils.py +109 -0
  123. triton/language/extra/hip/__init__.py +3 -0
  124. triton/language/extra/hip/libdevice.py +468 -0
  125. triton/language/extra/libdevice.py +1213 -0
  126. triton/language/math.py +250 -0
  127. triton/language/random.py +207 -0
  128. triton/language/semantic.py +1621 -0
  129. triton/language/standard.py +441 -0
  130. triton/ops/__init__.py +7 -0
  131. triton/ops/blocksparse/__init__.py +7 -0
  132. triton/ops/blocksparse/matmul.py +432 -0
  133. triton/ops/blocksparse/softmax.py +228 -0
  134. triton/ops/cross_entropy.py +96 -0
  135. triton/ops/flash_attention.py +466 -0
  136. triton/ops/matmul.py +219 -0
  137. triton/ops/matmul_perf_model.py +171 -0
  138. triton/runtime/__init__.py +23 -0
  139. triton/runtime/autotuner.py +361 -0
  140. triton/runtime/build.py +129 -0
  141. triton/runtime/cache.py +289 -0
  142. triton/runtime/driver.py +60 -0
  143. triton/runtime/errors.py +26 -0
  144. triton/runtime/interpreter.py +1127 -0
  145. triton/runtime/jit.py +956 -0
  146. triton/runtime/tcc/include/_mingw.h +170 -0
  147. triton/runtime/tcc/include/assert.h +57 -0
  148. triton/runtime/tcc/include/conio.h +409 -0
  149. triton/runtime/tcc/include/ctype.h +281 -0
  150. triton/runtime/tcc/include/dir.h +31 -0
  151. triton/runtime/tcc/include/direct.h +68 -0
  152. triton/runtime/tcc/include/dirent.h +135 -0
  153. triton/runtime/tcc/include/dos.h +55 -0
  154. triton/runtime/tcc/include/errno.h +75 -0
  155. triton/runtime/tcc/include/excpt.h +123 -0
  156. triton/runtime/tcc/include/fcntl.h +52 -0
  157. triton/runtime/tcc/include/fenv.h +108 -0
  158. triton/runtime/tcc/include/float.h +57 -0
  159. triton/runtime/tcc/include/inttypes.h +297 -0
  160. triton/runtime/tcc/include/io.h +418 -0
  161. triton/runtime/tcc/include/limits.h +111 -0
  162. triton/runtime/tcc/include/locale.h +91 -0
  163. triton/runtime/tcc/include/malloc.h +181 -0
  164. triton/runtime/tcc/include/math.h +737 -0
  165. triton/runtime/tcc/include/mem.h +13 -0
  166. triton/runtime/tcc/include/memory.h +40 -0
  167. triton/runtime/tcc/include/process.h +176 -0
  168. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  169. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  170. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  171. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  172. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  173. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  174. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  175. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  176. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  177. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  178. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  179. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  180. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  181. triton/runtime/tcc/include/setjmp.h +160 -0
  182. triton/runtime/tcc/include/share.h +28 -0
  183. triton/runtime/tcc/include/signal.h +63 -0
  184. triton/runtime/tcc/include/stdarg.h +79 -0
  185. triton/runtime/tcc/include/stdbool.h +11 -0
  186. triton/runtime/tcc/include/stddef.h +54 -0
  187. triton/runtime/tcc/include/stdint.h +212 -0
  188. triton/runtime/tcc/include/stdio.h +429 -0
  189. triton/runtime/tcc/include/stdlib.h +580 -0
  190. triton/runtime/tcc/include/string.h +164 -0
  191. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  192. triton/runtime/tcc/include/sys/file.h +14 -0
  193. triton/runtime/tcc/include/sys/locking.h +30 -0
  194. triton/runtime/tcc/include/sys/stat.h +290 -0
  195. triton/runtime/tcc/include/sys/time.h +69 -0
  196. triton/runtime/tcc/include/sys/timeb.h +133 -0
  197. triton/runtime/tcc/include/sys/types.h +118 -0
  198. triton/runtime/tcc/include/sys/unistd.h +14 -0
  199. triton/runtime/tcc/include/sys/utime.h +146 -0
  200. triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
  201. triton/runtime/tcc/include/tcclib.h +80 -0
  202. triton/runtime/tcc/include/tchar.h +1102 -0
  203. triton/runtime/tcc/include/time.h +287 -0
  204. triton/runtime/tcc/include/vadefs.h +11 -0
  205. triton/runtime/tcc/include/values.h +4 -0
  206. triton/runtime/tcc/include/varargs.h +12 -0
  207. triton/runtime/tcc/include/wchar.h +873 -0
  208. triton/runtime/tcc/include/wctype.h +172 -0
  209. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  210. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  211. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  212. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  213. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  214. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  215. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  216. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  217. triton/runtime/tcc/include/winapi/winbase.h +2951 -0
  218. triton/runtime/tcc/include/winapi/wincon.h +301 -0
  219. triton/runtime/tcc/include/winapi/windef.h +293 -0
  220. triton/runtime/tcc/include/winapi/windows.h +127 -0
  221. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  222. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  223. triton/runtime/tcc/include/winapi/winnt.h +5835 -0
  224. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  225. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  226. triton/runtime/tcc/include/winapi/winver.h +160 -0
  227. triton/runtime/tcc/lib/cuda.def +697 -0
  228. triton/runtime/tcc/lib/gdi32.def +337 -0
  229. triton/runtime/tcc/lib/kernel32.def +770 -0
  230. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  231. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  232. triton/runtime/tcc/lib/python3.def +810 -0
  233. triton/runtime/tcc/lib/user32.def +658 -0
  234. triton/runtime/tcc/libtcc.dll +0 -0
  235. triton/runtime/tcc/tcc.exe +0 -0
  236. triton/testing.py +496 -0
  237. triton/tools/__init__.py +0 -0
  238. triton/tools/build_extern.py +365 -0
  239. triton/tools/compile.c +67 -0
  240. triton/tools/compile.h +14 -0
  241. triton/tools/compile.py +145 -0
  242. triton/tools/disasm.py +142 -0
  243. triton/tools/link.py +322 -0
  244. triton/windows_utils.py +373 -0
  245. triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
  246. triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
  247. triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
  248. triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
@@ -0,0 +1,1302 @@
1
+ import ast
2
+ import inspect
3
+ import re
4
+ import sys
5
+ import warnings
6
+ import os
7
+ import textwrap
8
+ from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
9
+ from .. import language
10
+ from .._C.libtriton import ir
11
+ from ..language import constexpr, tensor, str_to_ty
12
+ from ..runtime.jit import _normalize_ty
13
+ # ideally we wouldn't need any runtime component
14
+ from ..runtime import JITFunction
15
+ from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
16
+ from types import ModuleType
17
+
18
+
19
def mangle_ty(ty):
    """Return a short, LLVM-identifier-safe mangled name for a triton type."""
    if ty.is_ptr():
        # Pointers are 'P' followed by the pointee's mangling (recursive).
        return 'P' + mangle_ty(ty.element_ty)
    if ty.is_int():
        signed = language.dtype.SIGNEDNESS.SIGNED
        sign_char = 'i' if ty.int_signedness == signed else 'u'
        return f'{sign_char}{ty.int_bitwidth}'
    if ty.is_floating():
        return str(ty)
    if ty.is_block():
        # Block (tile) types encode the scalar type plus the shape: e.g. fp32S16_16S.
        scalar_part = mangle_ty(ty.scalar)
        shape_part = '_'.join(str(dim) for dim in ty.shape)
        return f'{scalar_part}S{shape_part}S'
    if ty.is_void():
        return 'V'
    assert False, "Unsupported type"
35
+
36
+
37
def mangle_fn(name, arg_tys, constants):
    """Build a unique, LLVM-safe symbol name from a function name, its
    argument types and its constexpr values.

    The return type is intentionally not mangled: it must be a function of
    the argument types.
    """
    ty_part = '_'.join(mangle_ty(ty) for ty in arg_tys)
    const_part = '_'.join(f'{key}c{repr(constants[key])}' for key in sorted(constants))
    # '.', quotes, '[' and ']' are not allowed in LLVM identifiers — rewrite them.
    for bad, safe in (('.', '_d_'), ("'", '_sq_'), ('[', '_'), (']', '_')):
        const_part = const_part.replace(bad, safe)
    return f'{name}__{ty_part}__{const_part}'
47
+
48
+
49
+ def _is_triton_tensor(o: Any) -> bool:
50
+ return isinstance(o, tensor)
51
+
52
+
53
+ def _is_constexpr(o: Any) -> bool:
54
+ return isinstance(o, constexpr)
55
+
56
+
57
+ def _is_triton_scalar(o: Any) -> bool:
58
+ return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1)
59
+
60
+
61
+ def _is_list_like(o: Any) -> bool:
62
+ return isinstance(o, (list, tuple))
63
+
64
+
65
+ def _unwrap_if_constexpr(o: Any):
66
+ return o.value if isinstance(o, constexpr) else o
67
+
68
+
69
+ def _check_fn_args(node, fn, args):
70
+ if fn.noinline:
71
+ for idx, arg in enumerate(args):
72
+ if not _is_constexpr(arg) and not _is_triton_scalar(arg):
73
+ raise UnsupportedLanguageConstruct(
74
+ fn.src, node,
75
+ f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
76
+ )
77
+
78
+
79
def _get_fn_file_line(fn):
    """Return (file_name, line_number) of the `def` line of the underlying
    JITFunction, peeling off decorator wrappers such as autotune/heuristics."""
    inner = fn
    while not isinstance(inner, JITFunction):
        inner = inner.fn
    file_name = inner.fn.__code__.co_filename
    lines, begin_line = inspect.getsourcelines(inner.fn)
    # co_firstlineno points at the first decorator; advance to the actual def:
    #   @triton.autotune(...)   <- begin_line starts here
    #   @triton.heuristics(...)
    #   @triton.jit
    #   def foo(...):           <- we want this line
    for offset, src_line in enumerate(lines):
        if src_line.strip().startswith("def "):
            begin_line += offset
            break
    return file_name, begin_line
95
+
96
+
97
# Plain Python types that may appear as an `if`/`while` condition inside a kernel.
_condition_types = {bool, int, type(None)}
98
+
99
+
100
class enter_sub_region:
    """Context manager that snapshots the generator's scope and builder
    insertion point on entry, and restores both on exit.

    Used when lowering structured control flow (if/for/while) so the body is
    generated in an isolated scope.
    """

    def __init__(self, generator):
        self.generator = generator

    def __enter__(self):
        gen = self.generator
        # Record the live scope and locally-defined names of the parent region.
        self.liveins = gen.lscope.copy()
        self.prev_defs = gen.local_defs.copy()
        gen.local_defs = {}
        self.insert_block = gen.builder.get_insertion_block()
        self.insert_point = gen.builder.get_insertion_point()
        return self.liveins, self.insert_block

    def __exit__(self, *args, **kwargs):
        gen = self.generator
        gen.builder.restore_insertion_point(self.insert_point)
        gen.lscope = self.liveins
        gen.local_defs = self.prev_defs
118
+
119
+
120
+ # Check if the given syntax node has an "early" return
121
class ContainsReturnChecker(ast.NodeVisitor):
    """Statically decide whether an AST subtree may execute a `return`.

    Used to detect "early" returns inside control flow. Visiting a node
    returns True if any reachable statement in it is (or inlines a function
    containing) a return statement. Calls into JITFunctions found in *gscope*
    are analyzed recursively, since non-noinline functions are inlined.
    """

    def __init__(self, gscope):
        self.gscope = gscope

    def _visit_stmts(self, body) -> bool:
        # True as soon as any statement in the suite may return.
        return any(self.visit(s) for s in body)

    def _visit_function(self, fn) -> bool:
        # Currently we only support JITFunctions defined in the global scope.
        # Only non-noinline JITFunctions are inlined, so only they can
        # propagate an early return into the caller.
        if isinstance(fn, JITFunction) and not fn.noinline:
            fn_node = fn.parse()
            return ContainsReturnChecker(self.gscope).visit(fn_node)
        return False

    def generic_visit(self, node) -> bool:
        # Fold over all AST children; `or` short-circuits once a return is found.
        ret = False
        for _, value in ast.iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        ret = ret or self.visit(item)
            elif isinstance(value, ast.AST):
                ret = ret or self.visit(value)
        return ret

    def visit_Attribute(self, node: ast.Attribute) -> bool:
        # If the left part is a name, it's possible that
        # we call a triton native function or a jit function from another module.
        # If the left part is not a name, it must return a tensor or a constexpr
        # whose methods do not contain return statements
        # e.g., (tl.load(x)).to(y)
        # So we only check if the expressions within value have return or not
        if isinstance(node.value, ast.Name):
            if node.value.id in self.gscope:
                value = self.gscope[node.value.id]
                fn = getattr(value, node.attr)
                return self._visit_function(fn)
            return False
        return self.visit(node.value)

    def visit_Name(self, node: ast.Name) -> bool:
        # Assignment targets can never be calls, hence cannot return.
        if isinstance(node.ctx, ast.Store):
            return False
        if node.id in self.gscope:
            fn = self.gscope[node.id]
            return self._visit_function(fn)
        return False

    def visit_Return(self, node: ast.Return) -> bool:
        return True

    def visit_Assign(self, node: ast.Assign) -> bool:
        # There couldn't be an early return in `x = ...`
        return False

    def visit_AugAssign(self, node: ast.AugAssign) -> bool:
        # There couldn't be an early return in `x += ...`
        return False

    def visit_Module(self, node: ast.Module) -> bool:
        return self._visit_stmts(node.body)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> bool:
        return self._visit_stmts(node.body)

    def visit_If(self, node: ast.If) -> bool:
        # TODO: optimize the following case in which we actually don't have
        # a return when static_cond is false:
        # if dynamic_cond
        #   if static_cond
        #     func_with_return
        #   else
        #     func_without_return
        ret = self._visit_stmts(node.body)
        if node.orelse:
            ret = ret or self._visit_stmts(node.orelse)
        return ret

    def visit_IfExp(self, node: ast.IfExp) -> bool:
        return self.visit(node.body) or self.visit(node.orelse)

    def visit_Call(self, node: ast.Call) -> bool:
        return self.visit(node.func)
210
+
211
+
212
+ class CodeGenerator(ast.NodeVisitor):
213
+
214
    def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options,
                 codegen_fns, debug=None, module=None, is_kernel=False, function_types: Optional[Dict] = None,
                 noinline=False, file_name: Optional[str] = None, begin_line=0):
        """Lower one @jit function's AST into triton IR.

        prototype: the function's IR type signature (has param_types/ret_types).
        gscope: the Python global scope the kernel was defined in.
        attributes/constants: per-argument IR attributes and specialized
            constexpr argument values, keyed by argument index.
        jit_fn: the JITFunction being compiled (used for error reporting).
        module: existing IR module to append to; a fresh one when None.
        is_kernel: kernels get "public" visibility and may not return values.
        function_types: shared cache of already-lowered functions' return types.
        file_name/begin_line: source location for debug info.
        """
        self.context = context
        self.builder = ir.builder(context)
        self.file_name = file_name
        # node.lineno starts from 1, so we need to subtract 1
        self.begin_line = begin_line - 1
        self.builder.set_loc(file_name, begin_line, 0)
        self.builder.options = options
        # dict of functions provided by the backend. Below are the list of possible functions:
        # Convert custom types not natively supported on HW.
        # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None)
        self.builder.codegen_fns = codegen_fns
        self.module = self.builder.create_module() if module is None else module
        self.function_ret_types = {} if function_types is None else function_types
        self.prototype = prototype
        self.gscope = gscope
        self.lscope = dict()
        self.attributes = attributes
        self.constants = constants
        self.jit_fn = jit_fn
        self.function_name = function_name
        self.is_kernel = is_kernel
        self.cur_node = None
        # debug defaults to the compile options' debug flag unless overridden
        self.debug = options.debug if debug is None else debug
        self.noinline = noinline
        self.scf_stack = []
        self.ret_type = None
        # SSA-construction
        # name => language.tensor
        self.local_defs: Dict[str, tensor] = {}
        self.dereference_name: Callable[[str], Any] = self._define_name_lookup()
        self.fn = None
        # Are we currently visiting an ast.arg's default value? These have some
        # special handling.
        self.visiting_arg_default_value = False
251
+
252
    # Python builtins that may be called from inside a kernel; `print`, `min`
    # and `max` are remapped to their triton device-side equivalents.
    builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)}
    builtin_namespace.update((
        ('print', language.core.device_print),
        ('min', language.minimum),
        ('max', language.maximum),
    ))
258
+
259
    def _unsupported(self, node, message):
        # Build (not raise) an UnsupportedLanguageConstruct pointing at *node*
        # within this function's source.
        return UnsupportedLanguageConstruct(self.jit_fn.src, node, message)
261
+
262
+ def _is_constexpr_global(self, name):
263
+ absent_marker = object()
264
+ val = self.gscope.get(name, absent_marker)
265
+ if val is absent_marker:
266
+ return False
267
+
268
+ if _is_constexpr(val):
269
+ return True
270
+
271
+ if a := self.gscope.get("__annotations__", {}).get(name):
272
+ return _normalize_ty(a) == "constexpr"
273
+
274
+ return False
275
+
276
    def _define_name_lookup(self):
        """Build and return the name-resolution function used when visiting
        `ast.Name` loads.

        Resolution order: local scope, then global scope (restricted to
        constexpr-like values — see global_lookup), then builtins. The returned
        callable raises NameError if the name resolves nowhere.
        """

        def local_lookup(name: str, absent):
            # this needs to be re-fetched from `self` every time, because it gets switched occasionally
            return self.lscope.get(name, absent)

        def global_lookup(name: str, absent):
            val = self.gscope.get(name, absent)
            # The high-level rule is that only constexpr globals are allowed.
            # But actually a bunch of other things, such as module imports, are
            # technically Python globals. We have to allow these too!
            if (val is absent  #
                    or name in self.builtin_namespace  #
                    or type(val) == ModuleType  #
                    or isinstance(val, JITFunction)  #
                    or getattr(val, "__triton_builtin__", False)  #
                    or getattr(val, "__module__", "").startswith("triton.language")  #
                    or isinstance(val, language.dtype)  #
                    or self._is_constexpr_global(name)  #
                    # Allow accesses to globals while visiting an ast.arg
                    # because you should be able to do
                    # @triton.jit def fn(x: tl.constexpr = GLOBAL): ...
                    or self.visiting_arg_default_value  #
                    or os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1"):
                return val
            raise NameError(
                textwrap.dedent(f"""\
                Cannot access global variable {name} from within @jit'ed
                function. Triton kernels can only access global variables that
                are annotated as constexpr (`x: triton.language.constexpr = 42`
                or `x = triton.language.constexpr(42)`). Alternatively, set the
                envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not
                promise to support this forever.""").replace("\n", " "))

        absent_marker = object()

        def name_lookup(name: str) -> Any:
            absent = absent_marker
            # First match wins: locals shadow globals shadow builtins.
            for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get:
                value = lookup_function(name, absent)
                if value is not absent:
                    return value
            raise NameError(f'{name} is not defined')

        return name_lookup
321
+
322
    def set_value(self, name: str, value: Union[tensor, constexpr]) -> None:
        ''' Bind *name* to *value* (an lvalue store).

        Called by visit_Assign() & visit_FunctionDef():
        1. record the name as locally defined in the current region
           (FIXME: should consider control flow)
        2. store the value in the live scope
        '''
        self.lscope[name] = value
        self.local_defs[name] = value
330
+
331
    def _get_insertion_point_and_loc(self):
        # XXX: this is a hack to get the location of the insertion point.
        # The insertion point's location could be invalid sometimes,
        # so we need to explicitly set the location
        loc = self.builder.get_loc()
        ip = self.builder.get_insertion_point()
        return ip, loc
338
+
339
    def _set_insertion_point_and_loc(self, ip, loc):
        # Counterpart of _get_insertion_point_and_loc: restore both the
        # insertion point and its (separately-tracked) source location.
        self.builder.restore_insertion_point(ip)
        self.builder.set_loc(loc)
342
+
343
+ #
344
+ # AST visitor
345
+ #
346
+ def visit_compound_statement(self, stmts):
347
+ # Ensure that stmts is iterable
348
+ if not _is_list_like(stmts):
349
+ stmts = [stmts]
350
+ for stmt in stmts:
351
+ self.visit(stmt)
352
+
353
+ # Stop parsing as soon as we hit a `return` statement; everything
354
+ # after this is dead code.
355
+ if isinstance(stmt, ast.Return):
356
+ break
357
+
358
    def visit_Module(self, node):
        # A module wraps exactly the jitted function's def; just recurse.
        ast.NodeVisitor.generic_visit(self, node)
360
+
361
+ def visit_List(self, node):
362
+ ctx = self.visit(node.ctx)
363
+ assert ctx is None
364
+ elts = [self.visit(elt) for elt in node.elts]
365
+ return elts
366
+
367
+ # By design, only non-kernel functions can return
368
+ def visit_Return(self, node):
369
+ ret_value = self.visit(node.value)
370
+ # ret_block = self.builder.create_block()
371
+ # post_ret_block = self.builder.create_block()
372
+ # self.builder.create_branch(ret_block)
373
+ # self.builder.set_insertion_point_to_end(ret_block)
374
+ if ret_value is None:
375
+ self.builder.ret([])
376
+ ret_ty = language.void
377
+ elif isinstance(ret_value, tuple):
378
+ ret_values = [language.core._to_tensor(v, self.builder) for v in ret_value]
379
+ ret_types = [v.type for v in ret_values]
380
+ self.builder.ret([v.handle for v in ret_values])
381
+ ret_ty = tuple(ret_types)
382
+ else:
383
+ ret = language.core._to_tensor(ret_value, self.builder)
384
+ self.builder.ret([ret.handle])
385
+ ret_ty = ret.type
386
+ # self.builder.create_branch(post_ret_block)
387
+ # self.builder.set_insertion_point_to_end(post_ret_block)
388
+
389
+ if self.ret_type is None:
390
+ self.ret_type = ret_ty
391
+ elif self.ret_type != ret_ty:
392
+ raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}')
393
+
394
    def visit_FunctionDef(self, node):
        """Lower the jitted function's `def` node: materialize default
        arguments, create the IR function, bind arguments into scope, lower
        the body, and finalize the return type."""
        arg_names, kwarg_names = self.visit(node.args)
        if self.fn:
            raise self._unsupported(node, "nested function definition is not supported.")
        # initialize defaults: defaults align with the *last* len(defaults)
        # positional args, hence the -i - 1 indexing.
        for i, default_value in enumerate(node.args.defaults):
            arg_node = node.args.args[-i - 1]
            annotation = arg_node.annotation
            name = arg_node.arg
            st_target = ast.Name(id=name, ctx=ast.Store())
            # Lower the default as a synthetic (Ann)Assign so it goes through
            # the normal assignment path (constexpr handling included).
            if annotation is None:
                init_node = ast.Assign(targets=[st_target], value=default_value)
            else:
                init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)

            try:
                assert not self.visiting_arg_default_value
                # Relaxes the global-lookup rules while the default is visited.
                self.visiting_arg_default_value = True
                self.visit(init_node)
            finally:
                self.visiting_arg_default_value = False

        # initialize function
        visibility = "public" if self.is_kernel else "private"
        self.fn = self.builder.get_or_insert_function(self.module, self.function_name,
                                                      self.prototype.to_ir(self.builder), visibility, self.noinline)
        self.module.push_back(self.fn)
        entry = self.fn.add_entry_block()
        arg_values = []
        # idx counts only non-constexpr args — these are the real IR arguments.
        idx = 0
        for i, arg_name in enumerate(arg_names):
            if i in self.constants:
                # Specialized constexpr argument: no IR value, just the constant.
                cst = self.constants[i]
                if not _is_constexpr(cst):
                    cst = constexpr(self.constants[i])
                arg_values.append(cst)
                continue
            else:
                if i in self.attributes:
                    for name, value in self.attributes[i]:
                        self.fn.set_arg_attr(idx, name, value)
                arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx]))
                idx += 1

        insert_pt = self.builder.get_insertion_block()
        for arg_name, arg_value in zip(arg_names, arg_values):
            self.set_value(arg_name, arg_value)
        self.builder.set_insertion_point_to_start(entry)
        # visit function body
        self.visit_compound_statement(node.body)
        # finalize function
        if self.ret_type is None or self.ret_type == language.void:
            self.ret_type = language.void
            self.builder.ret([])
        else:
            # update return type now that all return sites have been seen
            if isinstance(self.ret_type, tuple):
                self.prototype.ret_types = list(self.ret_type)
                self.fn.reset_type(self.prototype.to_ir(self.builder))
            else:
                self.prototype.ret_types = [self.ret_type]
                self.fn.reset_type(self.prototype.to_ir(self.builder))
        if insert_pt:
            self.builder.set_insertion_point_to_end(insert_pt)
        # Remove dead code
        self.fn.finalize()
460
+
461
+ def visit_arguments(self, node):
462
+ arg_names = []
463
+ for arg in node.args:
464
+ arg_names += [self.visit(arg)]
465
+ kwarg_names = self.visit(node.kwarg)
466
+ return arg_names, kwarg_names
467
+
468
    def visit_arg(self, node):
        # Visit the annotation/defaults subtree, then yield the argument name.
        ast.NodeVisitor.generic_visit(self, node)
        return node.arg
471
+
472
+ def visit_AnnAssign(self, node):
473
+ # extract attributes
474
+ annotation = self.visit(node.annotation)
475
+ target = self.visit(node.target)
476
+ value = self.visit(node.value)
477
+ # constexpr
478
+ if annotation == constexpr:
479
+ if target in self.lscope:
480
+ raise ValueError(f'{target} is already defined.'
481
+ f' constexpr cannot be reassigned.')
482
+ if not _is_constexpr(value):
483
+ value = constexpr(value)
484
+ self.lscope[target] = value
485
+ return self.lscope[target]
486
+ # default: call visit_Assign
487
+ return self.visit_Assign(node)
488
+
489
+ def visit_Assign(self, node):
490
+ _names = []
491
+ for target in node.targets:
492
+ _names += [self.visit(target)]
493
+ if len(_names) > 1:
494
+ raise self._unsupported(node, "simultaneous multiple assignment is not supported.")
495
+ names = _names[0]
496
+ values = self.visit(node.value)
497
+ if not _is_list_like(names):
498
+ names = [names]
499
+ if not _is_list_like(values):
500
+ values = [values]
501
+ native_nontensor_types = (language.dtype, )
502
+ for name, value in zip(names, values):
503
+ # by default, constexpr are assigned into python variable
504
+ value = _unwrap_if_constexpr(value)
505
+ if value is not None and \
506
+ not _is_triton_tensor(value) and \
507
+ not isinstance(value, native_nontensor_types):
508
+ value = language.core._to_tensor(value, self.builder)
509
+ self.set_value(name, value)
510
+
511
+ def visit_AugAssign(self, node):
512
+ name = node.target.id
513
+ lhs = ast.Name(id=name, ctx=ast.Load())
514
+ rhs = ast.BinOp(lhs, node.op, node.value)
515
+ assign = ast.Assign(targets=[node.target], value=rhs)
516
+ self.visit(assign)
517
+ return self.dereference_name(name)
518
+
519
+ def visit_Name(self, node):
520
+ if type(node.ctx) == ast.Store:
521
+ return node.id
522
+ return self.dereference_name(node.id)
523
+
524
+ def visit_Store(self, node):
525
+ ast.NodeVisitor.generic_visit(self, node)
526
+
527
+ def visit_Load(self, node):
528
+ ast.NodeVisitor.generic_visit(self, node)
529
+
530
+ def visit_Tuple(self, node):
531
+ args = [self.visit(x) for x in node.elts]
532
+ return tuple(args)
533
+
534
+ def _apply_binary_method(self, method_name, lhs, rhs):
535
+ # TODO: raise something meaningful if getattr fails below, esp for reverse method
536
+ if _is_triton_tensor(lhs):
537
+ return getattr(lhs, method_name)(rhs, _builder=self.builder)
538
+ if _is_triton_tensor(rhs):
539
+ reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name)
540
+ return getattr(rhs, reverse_method_name)(lhs, _builder=self.builder)
541
+ return getattr(lhs, method_name)(rhs)
542
+
543
+ def visit_BinOp(self, node):
544
+ lhs = self.visit(node.left)
545
+ rhs = self.visit(node.right)
546
+ method_name = self._method_name_for_bin_op.get(type(node.op))
547
+ if method_name is None:
548
+ raise self._unsupported(node,
549
+ "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
550
+ return self._apply_binary_method(method_name, lhs, rhs)
551
+
552
    # Lookup table mapping Python AST binary-operator node types to the
    # dunder method name dispatched via _apply_binary_method.
    _method_name_for_bin_op: Dict[Type[ast.operator], str] = {
        ast.Add: '__add__',
        ast.Sub: '__sub__',
        ast.Mult: '__mul__',
        ast.Div: '__truediv__',
        ast.FloorDiv: '__floordiv__',
        ast.Mod: '__mod__',
        ast.Pow: '__pow__',
        ast.LShift: '__lshift__',
        ast.RShift: '__rshift__',
        ast.BitAnd: '__and__',
        ast.BitOr: '__or__',
        ast.BitXor: '__xor__',
    }
566
+
567
    def visit_then_else_blocks(self, node, liveins, then_block, else_block):
        """Generate IR for both branches of an ``if`` and reconcile the
        variables each branch defines.

        Returns ``(then_defs, else_defs, then_block, else_block, names,
        ret_types, ir_ret_types)`` where ``names`` lists the variables that
        must be carried out of the conditional and ``ret_types`` /
        ``ir_ret_types`` give their triton and IR types respectively.
        """
        # then block
        self.builder.set_insertion_point_to_start(then_block)
        self.visit_compound_statement(node.body)
        # the body may have opened new blocks; re-read the actual last block
        then_block = self.builder.get_insertion_block()
        then_defs = self.local_defs.copy()
        # else block
        else_defs = {}
        if node.orelse:
            self.builder.set_insertion_point_to_start(else_block)
            # reset scope so the else branch does not see then-branch defs
            self.lscope = liveins.copy()
            self.local_defs = {}
            self.visit_compound_statement(node.orelse)
            else_defs = self.local_defs.copy()
            else_block = self.builder.get_insertion_block()

        # update block arguments
        names = []
        ret_types = []
        ir_ret_types = []
        # variables in livein whose value is updated in `if`
        for name in liveins:
            # check type
            for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]:
                if name in defs:
                    assert defs[name].type == liveins[name].type, \
                        f'initial value for `{name}` is of type {liveins[name].type}, '\
                        f'but the {block_name} block redefines it as {defs[name].type}'
            if name in then_defs or name in else_defs:
                names.append(name)
                ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type)
                ir_ret_types.append(then_defs[name].handle.get_type() if name in
                                    then_defs else else_defs[name].handle.get_type())
            # variable defined in then but not in else
            if name in then_defs and name not in else_defs:
                # the else side passes the original value through unchanged
                else_defs[name] = liveins[name]
            # variable defined in else but not in then
            if name in else_defs and name not in then_defs:
                then_defs[name] = liveins[name]
        # variables that are both in then and else but not in liveins
        # TODO: could probably be cleaned up
        for name in then_defs.keys() & else_defs.keys():
            if name in names:
                continue
            then_ty = then_defs[name].type
            else_ty = else_defs[name].type
            assert then_ty == else_ty, \
                f'mismatched type for {name} between then block ({then_ty}) '\
                f'and else block ({else_ty})'
            names.append(name)
            ret_types.append(then_ty)
            ir_ret_types.append(then_defs[name].handle.get_type())

        return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types
621
+
622
    def visit_if_top_level(self, cond, node):
        """Lower a top-level ``if`` (one whose branches may contain
        ``return``) using explicit basic blocks and conditional branches
        rather than an scf.if op.
        """
        has_endif_block = True
        with enter_sub_region(self) as sr:
            liveins, ip_block = sr
            then_block = self.builder.create_block()
            else_block = self.builder.create_block()
            # create basic-block after conditional
            endif_block = self.builder.create_block()
            # create branch
            self.builder.set_insertion_point_to_end(ip_block)
            self.builder.create_cond_branch(cond.handle, then_block, else_block)
            # visit then and else blocks
            then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types = \
                self.visit_then_else_blocks(node, liveins, then_block, else_block)
            # then terminator
            self.builder.set_insertion_point_to_end(then_block)
            # if both branches return, the merge block is unreachable — drop it
            if then_block.has_return() and else_block.has_return():
                has_endif_block = False
                endif_block.erase()
            if not then_block.has_terminator() and has_endif_block:
                self.builder.create_branch(endif_block, [then_defs[n].handle for n in names])
            # else terminator
            self.builder.set_insertion_point_to_end(else_block)
            if not else_block.has_terminator() and has_endif_block:
                self.builder.create_branch(endif_block, [else_defs[n].handle for n in names])
            if has_endif_block:
                # the merge block receives one argument per live-out variable
                for ty in ir_ret_types:
                    endif_block.add_argument(ty)
        if has_endif_block:
            # change block
            self.builder.set_insertion_point_to_start(endif_block)
            # update value
            for i, name in enumerate(names):
                new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i])
                self.set_value(name, new_tensor)
657
+
658
+ # TODO: refactor
659
+ def visit_if_scf(self, cond, node):
660
+ with enter_sub_region(self) as sr:
661
+ liveins, _ = sr
662
+ ip, last_loc = self._get_insertion_point_and_loc()
663
+ then_block = self.builder.create_block()
664
+ else_block = self.builder.create_block() if node.orelse else None
665
+ then_defs, else_defs, then_block, else_block, names, ret_types, _ = \
666
+ self.visit_then_else_blocks(node, liveins, then_block, else_block)
667
+ # create if op
668
+ self._set_insertion_point_and_loc(ip, last_loc)
669
+ if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
670
+ then_block.merge_block_before(if_op.get_then_block())
671
+ self.builder.set_insertion_point_to_end(if_op.get_then_block())
672
+ if len(names) > 0:
673
+ self.builder.create_yield_op([then_defs[n].handle for n in names])
674
+ if not node.orelse:
675
+ else_block = if_op.get_else_block()
676
+ else:
677
+ else_block.merge_block_before(if_op.get_else_block())
678
+ self.builder.set_insertion_point_to_end(if_op.get_else_block())
679
+ if len(names) > 0:
680
+ self.builder.create_yield_op([else_defs[n].handle for n in names])
681
+ # update values
682
+ for i, name in enumerate(names):
683
+ new_tensor = language.core.tensor(if_op.get_result(i), ret_types[i])
684
+ self.set_value(name, new_tensor)
685
+
686
    def visit_If(self, node):
        """Dispatch an ``if`` statement.

        Tensor conditions become IR (scf.if when possible; explicit blocks
        when a branch contains ``return``); plain Python conditions are
        resolved at compile time and only one branch is visited.
        """
        cond = self.visit(node.test)
        if _is_triton_tensor(cond):
            cond = cond.to(language.int1, _builder=self.builder)
            contains_return = ContainsReturnChecker(self.gscope).visit(node)
            if self.scf_stack and contains_return:
                # scf region ops cannot model an early function return
                raise self._unsupported(
                    node, "Cannot have `return` statements inside `while` or `for` statements in triton "
                    "(note that this also applies to `return` statements that are inside functions "
                    "transitively called from within `while`/`for` statements)")
            elif self.scf_stack or not contains_return:
                self.visit_if_scf(cond, node)
            else:
                self.visit_if_top_level(cond, node)
        else:
            cond = _unwrap_if_constexpr(cond)
            # not isinstance - we insist the real thing, no subclasses and no ducks
            if type(cond) not in _condition_types:
                raise self._unsupported(
                    node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
                        ', '.join(_.__name__ for _ in _condition_types),
                        type(cond).__name__))
            if cond:
                self.visit_compound_statement(node.body)
            else:
                self.visit_compound_statement(node.orelse)
712
+
713
    def visit_IfExp(self, node):
        """Handle the ternary expression ``a if cond else b``.

        Tensor conditions lower to an scf.if with one result; Python
        conditions are decided at compile time and only one side is
        evaluated.
        """
        cond = self.visit(node.test)
        if _is_triton_tensor(cond):
            cond = cond.to(language.int1, _builder=self.builder)
            # TODO: Deal w/ more complicated return types (e.g tuple)
            with enter_sub_region(self):
                ip, last_loc = self._get_insertion_point_and_loc()

                then_block = self.builder.create_block()
                self.builder.set_insertion_point_to_start(then_block)
                then_val = language.core._to_tensor(self.visit(node.body), self.builder)
                then_block = self.builder.get_insertion_block()

                else_block = self.builder.create_block()
                self.builder.set_insertion_point_to_start(else_block)
                # do not need to reset lscope since
                # ternary expressions cannot define new variables
                else_val = language.core._to_tensor(self.visit(node.orelse), self.builder)
                else_block = self.builder.get_insertion_block()

                self._set_insertion_point_and_loc(ip, last_loc)

                assert then_val.type == else_val.type, \
                    f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}'
                ret_type = then_val.type

                # void-typed branches produce no scf.if result
                ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else []
                if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True)
                then_block.merge_block_before(if_op.get_then_block())
                if ret_type_ir:
                    self.builder.set_insertion_point_to_end(if_op.get_then_block())
                    self.builder.create_yield_op([then_val.handle])

                self.builder.set_insertion_point_to_end(if_op.get_then_block())
                else_block.merge_block_before(if_op.get_else_block())
                if ret_type_ir:
                    self.builder.set_insertion_point_to_end(if_op.get_else_block())
                    self.builder.create_yield_op([else_val.handle])
                return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None
        else:
            cond = _unwrap_if_constexpr(cond)

            # not isinstance - we insist the real thing, no subclasses and no ducks
            if type(cond) not in _condition_types:
                raise self._unsupported(
                    node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
                        ', '.join(_.__name__ for _ in _condition_types),
                        type(cond).__name__))
            if cond:
                return self.visit(node.body)
            else:
                return self.visit(node.orelse)
765
+
766
+ def visit_Pass(self, node):
767
+ pass
768
+
769
+ def visit_Compare(self, node):
770
+ if not (len(node.comparators) == 1 and len(node.ops) == 1):
771
+ raise self._unsupported(node, "simultaneous multiple comparison is not supported")
772
+ lhs = self.visit(node.left)
773
+ rhs = self.visit(node.comparators[0])
774
+ lhs_value = _unwrap_if_constexpr(lhs)
775
+ rhs_value = _unwrap_if_constexpr(rhs)
776
+ if type(node.ops[0]) == ast.Is:
777
+ return constexpr(lhs_value is rhs_value)
778
+ if type(node.ops[0]) == ast.IsNot:
779
+ return constexpr(lhs_value is not rhs_value)
780
+ method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
781
+ if method_name is None:
782
+ raise self._unsupported(
783
+ node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
784
+ return self._apply_binary_method(method_name, lhs, rhs)
785
+
786
    # Lookup table mapping AST comparison-operator node types to dunder
    # method names (identity ops `is`/`is not` are handled separately).
    _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {
        ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__'
    }
789
+
790
+ def visit_UnaryOp(self, node):
791
+ operand = self.visit(node.operand)
792
+ fn = self._method_name_for_unary_op.get(type(node.op))
793
+ if fn is None:
794
+ raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.")
795
+ if _is_triton_tensor(operand):
796
+ return getattr(operand, fn)(_builder=self.builder)
797
+ try:
798
+ return getattr(operand, fn)()
799
+ except AttributeError:
800
+ raise self._unsupported(
801
+ node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}")
802
+
803
    # Lookup table mapping AST unary-operator node types to dunder method names.
    _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {
        ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'
    }
806
+
807
+ def visit_While(self, node):
808
+ with enter_sub_region(self) as sr:
809
+ liveins, insert_block = sr
810
+ ip, last_loc = self._get_insertion_point_and_loc()
811
+
812
+ # loop body (the after region)
813
+ # loop_block = self.builder.create_block()
814
+ dummy = self.builder.create_block()
815
+ self.builder.set_insertion_point_to_start(dummy)
816
+ self.scf_stack.append(node)
817
+ self.visit_compound_statement(node.body)
818
+ self.scf_stack.pop()
819
+ loop_defs = self.local_defs
820
+ dummy.erase()
821
+
822
+ # collect loop-carried values
823
+ names = []
824
+ ret_types = []
825
+ init_args = []
826
+ for name in loop_defs:
827
+ if name in liveins:
828
+ # We should not def new constexpr
829
+ assert _is_triton_tensor(loop_defs[name]), f'cannot reassign constxpr {name} in the loop'
830
+ assert _is_triton_tensor(liveins[name]), f'cannot reasign constexpr {name} in the loop'
831
+ assert loop_defs[name].type == liveins[name].type, \
832
+ f'Loop-carried variable {name} has initial type {liveins[name].type} '\
833
+ f'but is re-assigned to {loop_defs[name].type} in loop! '\
834
+ f'Please make sure that the type stays consistent.'
835
+
836
+ # these are loop-carried values
837
+ names.append(name)
838
+ ret_types.append(loop_defs[name].type)
839
+ init_args.append(liveins[name])
840
+
841
+ self._set_insertion_point_and_loc(ip, last_loc)
842
+ while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
843
+ [arg.handle for arg in init_args])
844
+ # merge the condition region
845
+ before_block = self.builder.create_block_with_parent(while_op.get_before(),
846
+ [ty.to_ir(self.builder) for ty in ret_types])
847
+ self.builder.set_insertion_point_to_start(before_block)
848
+ for i, name in enumerate(names):
849
+ self.lscope[name] = language.core.tensor(before_block.arg(i), ret_types[i])
850
+ self.local_defs[name] = self.lscope[name]
851
+ cond = self.visit(node.test)
852
+ self.builder.set_insertion_point_to_end(before_block)
853
+ # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
854
+ self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
855
+ # merge the loop body
856
+ after_block = self.builder.create_block_with_parent(while_op.get_after(),
857
+ [ty.to_ir(self.builder) for ty in ret_types])
858
+
859
+ # generate loop body
860
+ self.builder.set_insertion_point_to_start(after_block)
861
+ for i, name in enumerate(names):
862
+ self.lscope[name] = language.core.tensor(after_block.arg(i), ret_types[i])
863
+ self.local_defs[name] = self.lscope[name]
864
+ self.scf_stack.append(node)
865
+ self.visit_compound_statement(node.body)
866
+ self.scf_stack.pop()
867
+ loop_defs = self.local_defs
868
+ yields = []
869
+ for name in loop_defs:
870
+ if name in liveins:
871
+ yields.append(loop_defs[name])
872
+ self.builder.create_yield_op([y.handle for y in yields])
873
+
874
+ # WhileOp defines new values, update the symbol table (lscope, local_defs)
875
+ for i, name in enumerate(names):
876
+ new_def = language.core.tensor(while_op.get_result(i), ret_types[i])
877
+ self.lscope[name] = new_def
878
+ self.local_defs[name] = new_def
879
+
880
+ for stmt in node.orelse:
881
+ assert False, "Not implemented"
882
+ ast.NodeVisitor.generic_visit(self, stmt)
883
+
884
+ def visit_Subscript(self, node):
885
+ assert node.ctx.__class__.__name__ == "Load"
886
+ lhs = self.visit(node.value)
887
+ slices = self.visit(node.slice)
888
+ if _is_triton_tensor(lhs):
889
+ return lhs.__getitem__(slices, _builder=self.builder)
890
+ return lhs[slices]
891
+
892
+ def visit_ExtSlice(self, node):
893
+ return [self.visit(dim) for dim in node.dims]
894
+
895
    def visit_For(self, node):
        """Lower a ``for`` loop.

        ``tl.static_range`` is fully unrolled at compile time; ``range`` /
        ``tl.range`` lower to an scf.for op. A dry run of the body
        discovers loop-carried variables, then the body is re-visited for
        real inside the op.
        """
        IteratorClass = self.visit(node.iter.func)
        iter_args = [self.visit(arg) for arg in node.iter.args]
        iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords)
        if IteratorClass == language.static_range:
            # compile-time unrolling: visit the body once per iteration
            iterator = IteratorClass(*iter_args, **iter_kwargs)
            static_range = range(iterator.start.value, iterator.end.value, iterator.step.value)
            for i in static_range:
                self.lscope[node.target.id] = constexpr(i)
                self.visit_compound_statement(node.body)
                for stmt in node.orelse:
                    ast.NodeVisitor.generic_visit(self, stmt)
            return

        num_stages = None
        if IteratorClass is language.range:
            iterator = IteratorClass(*iter_args, **iter_kwargs)
            # visit iterator arguments
            # note: only `range` iterator is supported now
            # collect lower bound (lb), upper bound (ub), and step
            lb = iterator.start
            ub = iterator.end
            step = iterator.step
            num_stages = iterator.num_stages
        elif IteratorClass is range:
            # visit iterator arguments
            # note: only `range` iterator is supported now
            # collect lower bound (lb), upper bound (ub), and step
            lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0))
            ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0])
            step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1))
        else:
            raise RuntimeError('Only `range` and `static_range` iterators are currently supported')
        # handle negative constant step (not supported by scf.for in MLIR)
        negative_step = False
        if _is_constexpr(step) and step.value < 0:
            # iterate upward over the swapped bounds; the induction variable
            # is mapped back to the original direction below
            step = constexpr(-step.value)
            negative_step = True
            lb, ub = ub, lb
        lb = language.core._to_tensor(lb, self.builder)
        ub = language.core._to_tensor(ub, self.builder)
        step = language.core._to_tensor(step, self.builder)
        # induction variable type
        if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int():
            raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})")
        iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype)
        iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype)
        iv_ir_type = iv_type.to_ir(self.builder)
        iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED
        # lb/ub/step might be constexpr, we need to cast them to tensor
        lb = lb.handle
        ub = ub.handle
        step = step.handle
        # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index
        lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed)
        ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed)
        step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed)
        # Create placeholder for the loop induction variable
        iv = self.builder.create_undef(iv_ir_type)
        self.set_value(node.target.id, language.core.tensor(iv, iv_type))

        with enter_sub_region(self) as sr:
            liveins, insert_block = sr
            ip, last_loc = self._get_insertion_point_and_loc()

            # create loop body block
            block = self.builder.create_block()
            self.builder.set_insertion_point_to_start(block)
            # dry visit loop body
            self.scf_stack.append(node)
            self.visit_compound_statement(node.body)
            self.scf_stack.pop()
            block.erase()

            # If a variable (name) is defined in both its parent & itself, then it's
            # a loop-carried variable. (They must be of the same type)
            init_args = []
            yields = []
            names = []
            for name in self.local_defs:
                if name in liveins:
                    assert _is_triton_tensor(self.local_defs[name]), f'{name} is not tensor'
                    assert _is_triton_tensor(liveins[name])
                    assert self.local_defs[name].type == liveins[name].type, \
                        f'Loop-carried variable {name} has initial type {liveins[name].type} '\
                        f'but is re-assigned to {self.local_defs[name].type} in loop! '\
                        f'Please make sure that the type stays consistent.'

                    names.append(name)
                    init_args.append(language.core._to_tensor(liveins[name], self.builder))
                    yields.append(language.core._to_tensor(self.local_defs[name], self.builder))

            # create ForOp
            self._set_insertion_point_and_loc(ip, last_loc)
            for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
            if num_stages is not None:
                for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages))

            self.scf_stack.append(node)
            self.builder.set_insertion_point_to_start(for_op.get_body(0))
            # reset local scope to not pick up local defs from the previous dry run.
            self.lscope = liveins.copy()
            self.local_defs = {}
            for i, name in enumerate(names):
                # arg(0) is the induction variable, so carried values start at 1
                self.set_value(name, language.core.tensor(for_op.get_body(0).arg(i + 1), yields[i].type))
            self.visit_compound_statement(node.body)
            self.scf_stack.pop()
            yields = []
            for name in self.local_defs:
                if name in liveins:
                    yields.append(language.core._to_tensor(self.local_defs[name], self.builder))

            # create YieldOp
            if len(yields) > 0:
                self.builder.create_yield_op([y.handle for y in yields])
            for_op_region = for_op.get_body(0).get_parent()
            assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"

            # update induction variable with actual value, and replace all uses
            self.builder.set_insertion_point_to_start(for_op.get_body(0))
            iv = for_op.get_induction_var()
            if negative_step:
                # map the ascending iv back onto the user's descending range
                iv = self.builder.create_sub(ub, iv)
                iv = self.builder.create_add(iv, lb)
            self.lscope[node.target.id].handle.replace_all_uses_with(iv)
            self.set_value(node.target.id, language.core.tensor(iv, iv_type))

        # update lscope & local_defs (ForOp defines new values)
        for i, name in enumerate(names):
            self.set_value(name, language.core.tensor(for_op.get_result(i), yields[i].type))

        for stmt in node.orelse:
            assert False, "Don't know what to do with else after for"
            ast.NodeVisitor.generic_visit(self, stmt)
1028
+
1029
+ def visit_Slice(self, node):
1030
+ lower = self.visit(node.lower)
1031
+ upper = self.visit(node.upper)
1032
+ step = self.visit(node.step)
1033
+ return slice(lower, upper, step)
1034
+
1035
+ def visit_Index(self, node):
1036
+ return self.visit(node.value)
1037
+
1038
+ def visit_keyword(self, node) -> Tuple[str, Any]:
1039
+ return node.arg, self.visit(node.value)
1040
+
1041
+ def visit_Assert(self, node) -> Any:
1042
+ if not self.debug:
1043
+ return
1044
+ test = self.visit(node.test)
1045
+ msg = self.visit(node.msg) if node.msg is not None else ""
1046
+ # Convert assert to triton's device_assert which happens on the device
1047
+ return language.core.device_assert(test, msg, _builder=self.builder)
1048
+
1049
    def call_JitFunction(self, fn: JITFunction, args, kwargs):
        """Compile (if needed) and emit a call to another @jit function.

        The callee is specialized on its constexpr arguments: its IR symbol
        name is mangled from the non-constexpr argument types and the
        constant values, and the definition is generated once per mangled
        name within this module.
        """
        # bind call-site args to the callee's parameter order
        args = inspect.getcallargs(fn.fn, *args, **kwargs)
        args = [args[name] for name in fn.arg_names]
        args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args]
        # generate function def
        attributes = dict()
        constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
        constants = {i: args[i] for i in constexprs}
        # generate call
        # constexpr args are baked into the specialization, not passed at runtime
        args = [None if i in constexprs else arg for i, arg in enumerate(args)]
        arg_vals = [arg.handle for arg in args if arg is not None]
        arg_types = [arg.type for arg in args if arg is not None]
        fn_name = mangle_fn(fn.__name__, arg_types, constants)
        # generate function def if necessary
        if not self.module.has_function(fn_name):
            prototype = language.function_type([], arg_types)
            gscope = fn.__globals__
            # If the callee is not set, we use the same debug setting as the caller
            file_name, begin_line = _get_fn_file_line(fn)
            debug = self.debug if fn.debug is None else fn.debug
            generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
                                      jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types,
                                      noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
                                      options=self.builder.options, codegen_fns=self.builder.codegen_fns, debug=debug)
            try:
                generator.visit(fn.parse())
            except Exception as e:
                # Wrap the error in the callee with the location of the call.
                raise CompilationError(self.jit_fn.src, self.cur_node, None) from e

            callee_ret_type = generator.ret_type
            self.function_ret_types[fn_name] = callee_ret_type
        else:
            # already specialized: reuse the cached return type
            callee_ret_type = self.function_ret_types[fn_name]
        symbol = self.module.get_function(fn_name)
        call_op = self.builder.call(symbol, arg_vals)
        if call_op.get_num_results() == 0 or callee_ret_type is None:
            return None
        elif call_op.get_num_results() == 1:
            return tensor(call_op.get_result(0), callee_ret_type)
        else:
            # should return a tuple of tl.tensor
            results = []
            for i in range(call_op.get_num_results()):
                results.append(tensor(call_op.get_result(i), callee_ret_type[i]))
            return tuple(results)
1095
+
1096
    def visit_Call(self, node):
        """Dispatch a call expression.

        Order of dispatch: statically-implemented compile-time functions,
        then @jit functions, then triton builtins / tensor methods (which
        receive ``_builder``), then plain Python builtins from the builtin
        namespace.
        """
        fn = _unwrap_if_constexpr(self.visit(node.func))
        static_implementation = self.statically_implemented_functions.get(fn)
        if static_implementation is not None:
            return static_implementation(self, node)

        kws = dict(self.visit(keyword) for keyword in node.keywords)
        args = [self.visit(arg) for arg in node.args]
        if fn is language.core.device_assert:  # TODO: this should not be so hardcoded
            if not self.debug:
                # asserts compile to nothing outside debug mode
                return
        if isinstance(fn, JITFunction):
            _check_fn_args(node, fn, args)
            return self.call_JitFunction(fn, args, kws)
        if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn):
            extra_kwargs = dict(_builder=self.builder)
            sig = inspect.signature(fn)
            if '_generator' in sig.parameters:
                extra_kwargs['_generator'] = self
            try:
                return fn(*args, **extra_kwargs, **kws)
            except Exception as e:
                # Normally when we raise a CompilationError, we raise it as
                # `from None`, because the original fileline from the exception
                # is not relevant (and often points into code_generator.py
                # itself). But when calling a function, we raise as `from e` to
                # preserve the traceback of the original error, which may e.g.
                # be in core.py.
                raise CompilationError(self.jit_fn.src, node, None) from e

        if fn in self.builtin_namespace.values():
            args = map(_unwrap_if_constexpr, args)
        # NOTE(review): fn is called directly here even when it matched none
        # of the cases above — presumably any remaining callable is assumed
        # to be a plain Python function; verify against callers.
        return fn(*args, **kws)
1129
+
1130
+ def visit_Constant(self, node):
1131
+ return constexpr(node.value)
1132
+
1133
+ def visit_BoolOp(self, node: ast.BoolOp):
1134
+ if len(node.values) != 2:
1135
+ raise self._unsupported(
1136
+ node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
1137
+ lhs = self.visit(node.values[0])
1138
+ rhs = self.visit(node.values[1])
1139
+ method_name = self._method_name_for_bool_op.get(type(node.op))
1140
+ if method_name is None:
1141
+ raise self._unsupported(
1142
+ node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
1143
+ return self._apply_binary_method(method_name, lhs, rhs)
1144
+
1145
    # Boolean operators dispatch to elementwise logical ops (no short-circuit).
    _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
1146
+
1147
    if sys.version_info < (3, 8):
        # Python < 3.8 still emits the legacy literal nodes (NameConstant,
        # Num, Str) instead of ast.Constant; route them to the same
        # constexpr handling as visit_Constant.

        def visit_NameConstant(self, node):
            return constexpr(node.value)

        def visit_Num(self, node):
            return constexpr(node.n)

        def visit_Str(self, node):
            return constexpr(ast.literal_eval(node))
1157
+
1158
+ def visit_Attribute(self, node):
1159
+ lhs = self.visit(node.value)
1160
+ if _is_triton_tensor(lhs):
1161
+ if node.attr == "T":
1162
+ return language.semantic.permute(lhs, (1, 0), builder=self.builder)
1163
+ return getattr(lhs, node.attr)
1164
+
1165
+ def visit_Expr(self, node):
1166
+ ast.NodeVisitor.generic_visit(self, node)
1167
+
1168
+ def visit_NoneType(self, node):
1169
+ return None
1170
+
1171
+ def visit_JoinedStr(self, node):
1172
+ values = list(node.values)
1173
+ for i, value in enumerate(values):
1174
+ if isinstance(value, ast.Constant):
1175
+ values[i] = str(value.value)
1176
+ elif isinstance(value, ast.FormattedValue):
1177
+ conversion_code = value.conversion
1178
+ evaluated = self.visit(value.value)
1179
+ if not _is_constexpr(evaluated):
1180
+ raise self._unsupported(
1181
+ node,
1182
+ "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type "
1183
+ + str(type(evaluated)))
1184
+ values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value)
1185
+ else:
1186
+ raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value)))
1187
+ return ''.join(values)
1188
+
1189
    def visit(self, node):
        """Visit *node*, maintaining the builder's source location.

        Saves the current node/location, points the builder at this node's
        file:line:col (if it has one), delegates to the per-type visitor,
        wraps any non-CompilationError failure in a CompilationError
        carrying the @jit source, and finally restores the previous
        node/location.
        """
        if node is None:
            return
        with warnings.catch_warnings():
            # The ast library added visit_Constant and deprecated some other
            # methods but we can't move to that without breaking Python 3.6 and 3.7.
            warnings.simplefilter("ignore", DeprecationWarning)  # python 3.9
            warnings.simplefilter("ignore", PendingDeprecationWarning)  # python 3.8
            last_node = self.cur_node
            last_loc = self.builder.get_loc()
            self.cur_node = node
            if hasattr(node, 'lineno') and hasattr(node, 'col_offset'):
                self.builder.set_loc(self.file_name, self.begin_line + node.lineno, node.col_offset)
                last_loc = self.builder.get_loc()
            try:
                ret = super().visit(node)
            except CompilationError:
                raise
            except Exception as e:
                # Wrap the error in a CompilationError which contains the source
                # of the @jit function.
                raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None

            # Reset the location to the last one before the visit
            if last_loc:
                self.cur_node = last_node
                self.builder.set_loc(last_loc)
            return ret
1217
+
1218
    def generic_visit(self, node):
        # Fallback for AST node types with no visit_* handler: compiling them
        # is unsupported, so fail with a source-located error.
        raise self._unsupported(node, "unsupported AST node type: {}".format(type(node).__name__))
1220
+
1221
+ def execute_static_assert(self, node: ast.Call) -> None:
1222
+ arg_count = len(node.args)
1223
+ if not (0 < arg_count <= 2) or len(node.keywords):
1224
+ raise TypeError("`static_assert` requires one or two positional arguments only")
1225
+
1226
+ passed = _unwrap_if_constexpr(self.visit(node.args[0]))
1227
+ if not isinstance(passed, bool):
1228
+ raise NotImplementedError(
1229
+ "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values"
1230
+ )
1231
+ if not passed:
1232
+ if arg_count == 1:
1233
+ message = ""
1234
+ else:
1235
+ try:
1236
+ message = self.visit(node.args[1])
1237
+ except Exception as e:
1238
+ message = "<failed to evaluate assertion message: " + repr(e) + ">"
1239
+
1240
+ raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message))
1241
+ return None
1242
+
1243
+ def static_executor(python_fn):
1244
+
1245
+ def ret(self, node: ast.Call):
1246
+ kws = {
1247
+ name: _unwrap_if_constexpr(value)
1248
+ for name, value in (self.visit(keyword) for keyword in node.keywords)
1249
+ }
1250
+ args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args]
1251
+ return constexpr(python_fn(*args, **kws))
1252
+
1253
+ return ret
1254
+
1255
    # Dispatch table for builtins that are executed at compile time instead of
    # being lowered to IR: maps the callable object seen in the kernel source
    # to the handler that evaluates its ast.Call node.
    statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {
        language.core.static_assert: execute_static_assert,
        language.core.static_print: static_executor(print),
        int: static_executor(int),
        len: static_executor(len),
    }
1261
+
1262
+
1263
def kernel_suffix(signature, specialization):
    """Return the specialization suffix appended to a kernel's name.

    For each argument index i the suffix contains str(i), followed by 'c'
    when the argument is specialized as equal-to-1 and 'd' when it is known
    to be divisible by 16.
    """
    pieces = []
    for idx, _ in enumerate(signature):
        pieces.append(str(idx))
        if idx in specialization.equal_to_1:
            pieces.append('c')
        if idx in specialization.divisible_by_16:
            pieces.append('d')
    return ''.join(pieces)
1274
+
1275
+
1276
def ast_to_ttir(fn, specialization, context, options, codegen_fns):
    """Lower a @jit function to a Triton-IR module for one specialization.

    Builds the constant/attribute maps implied by the specialization, runs
    CodeGenerator over the function's AST, and returns the resulting module
    (which takes ownership of `context`).
    """
    attrs = specialization.attrs
    # create kernel prototype
    # Constants may be keyed by argument name or by index; normalize to index.
    cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i
    constants = {cst_key(key): value for key, value in specialization.constants.items()}
    # visit kernel AST
    gscope = fn.__globals__.copy()
    function_name = fn.repr(specialization)
    tys = list(specialization.signature.values())
    # NOTE(review): `k in tys` tests membership of an int arg-index in a list
    # of type STRINGS, which is always False, so the value is always 1 and the
    # `tys[k] == "i1"` branch looks unreachable — presumably `k < len(tys)`
    # was intended; confirm against upstream before changing.
    new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs.equal_to_1}
    new_attrs = {k: [("tt.divisibility", 16)] for k in attrs.divisible_by_16}

    all_constants = constants.copy()
    all_constants.update(new_constants)
    # Constant-valued arguments are folded away and excluded from the prototype.
    arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants]
    file_name, begin_line = _get_fn_file_line(fn)

    prototype = language.function_type([], arg_types)
    generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
                              jit_fn=fn, attributes=new_attrs, is_kernel=True, file_name=file_name,
                              begin_line=begin_line, options=options, codegen_fns=codegen_fns)
    generator.visit(fn.parse())

    ret = generator.module
    # module takes ownership of the context
    ret.context = context
    return ret