triton-windows 3.3.1.post19__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic; see the registry's advisory page for this release for more details.

Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
@@ -1,20 +1,22 @@
1
1
  import ast
2
+ import builtins
3
+ import contextlib
4
+ import copy
2
5
  import inspect
3
6
  import re
4
7
  import warnings
5
- import os
6
8
  import textwrap
7
9
  import itertools
10
+ from dataclasses import dataclass
8
11
  from types import ModuleType
9
12
  from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, Iterable, List
10
13
 
11
- from .. import language
12
- from .._C.libtriton import ir
13
- from ..language import constexpr, semantic, str_to_ty, tensor
14
- from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, base_value, base_type
15
- from ..runtime.jit import get_jit_fn_file_line
14
+ from .. import knobs, language
15
+ from .._C.libtriton import ir, gluon_ir
16
+ from ..language import constexpr, str_to_ty, tensor, tuple as tl_tuple
17
+ from ..language.core import _unwrap_if_constexpr, base_value, base_type
16
18
  # ideally we wouldn't need any runtime component
17
- from ..runtime import JITFunction
19
+ from ..runtime.jit import get_jit_fn_file_line, get_full_name, JITCallable, BoundConstexprFunction, ConstexprFunction, JITFunction
18
20
  from .._utils import find_paths_if, get_iterable_path, set_iterable_path
19
21
 
20
22
  from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
@@ -27,35 +29,17 @@ def check_identifier_legality(name, type):
27
29
  return name
28
30
 
29
31
 
30
- def mangle_ty(ty):
31
- if ty.is_tuple():
32
- return 'T' + '_'.join(map(mangle_ty, ty.types)) + 'T'
33
- if ty.is_ptr():
34
- return 'P' + mangle_ty(ty.element_ty)
35
- if ty.is_int():
36
- SIGNED = language.dtype.SIGNEDNESS.SIGNED
37
- prefix = 'i' if ty.int_signedness == SIGNED else 'u'
38
- return prefix + str(ty.int_bitwidth)
39
- if ty.is_floating():
40
- return str(ty)
41
- if ty.is_block():
42
- elt = mangle_ty(ty.scalar)
43
- shape = '_'.join(map(str, ty.shape))
44
- return f'{elt}S{shape}S'
45
- if ty.is_void():
46
- return 'V'
47
- raise TypeError(f'Unsupported type {ty}')
48
-
49
-
50
- def mangle_fn(name, arg_tys, constants):
32
+ def mangle_fn(name, arg_tys, constants, caller_context):
51
33
  # doesn't mangle ret type, which must be a function of arg tys
52
- mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
34
+ mangled_arg_names = '_'.join([ty.mangle() for ty in arg_tys])
53
35
  mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)])
54
36
  mangled_constants = mangled_constants.replace('.', '_d_')
55
37
  mangled_constants = mangled_constants.replace("'", '_sq_')
56
38
  # [ and ] are not allowed in LLVM identifiers
57
39
  mangled_constants = mangled_constants.replace('[', '_').replace(']', '_')
58
40
  ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
41
+ if caller_context is not None:
42
+ ret += caller_context.mangle()
59
43
  return ret
60
44
 
61
45
 
@@ -68,11 +52,11 @@ def _is_triton_tensor(o: Any) -> bool:
68
52
 
69
53
 
70
54
  def _is_constexpr(o: Any) -> bool:
71
- return o is None or isinstance(o, (constexpr, language.core.dtype))
55
+ return o is None or isinstance(o, (constexpr, language.core.dtype, JITCallable))
72
56
 
73
57
 
74
- def _is_triton_scalar(o: Any) -> bool:
75
- return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1)
58
+ def _is_non_scalar_tensor(o: Any) -> bool:
59
+ return _is_triton_tensor(o) and (o.type.is_block() and o.type.numel != 1)
76
60
 
77
61
 
78
62
  def _is_list_like(o: Any) -> bool:
@@ -82,7 +66,7 @@ def _is_list_like(o: Any) -> bool:
82
66
  def _check_fn_args(node, fn, args):
83
67
  if fn.noinline:
84
68
  for idx, arg in enumerate(args):
85
- if not _is_constexpr(arg) and not _is_triton_scalar(arg):
69
+ if not _is_constexpr(arg) and _is_non_scalar_tensor(arg):
86
70
  raise UnsupportedLanguageConstruct(
87
71
  fn.src, node,
88
72
  f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
@@ -102,6 +86,7 @@ def _apply_to_tuple_values(value, fn):
102
86
  assert False, f"Unsupported type {type(value)}"
103
87
 
104
88
  vals = [fn(v) for v in value]
89
+ vals = [constexpr(v) if v is None else v for v in vals]
105
90
  types = [v.type for v in vals]
106
91
  return language.tuple(vals, language.tuple_type(types, fields))
107
92
 
@@ -124,6 +109,17 @@ def unflatten_ir_values(handles: List[ir.value], types: List[base_type]):
124
109
  _condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels
125
110
 
126
111
 
112
+ def _clone_triton_value(val):
113
+ handles = []
114
+ val._flatten_ir(handles)
115
+ clone, _ = val.type._unflatten_ir(handles, 0)
116
+ return clone
117
+
118
+
119
+ def _clone_scope(scope):
120
+ return {name: _clone_triton_value(val) if _is_triton_value(val) else val for name, val in scope.items()}
121
+
122
+
127
123
  class enter_sub_region:
128
124
 
129
125
  def __init__(self, generator):
@@ -131,8 +127,8 @@ class enter_sub_region:
131
127
 
132
128
  def __enter__(self):
133
129
  # record lscope & local_defs in the parent scope
134
- self.liveins = self.generator.lscope.copy()
135
- self.prev_defs = self.generator.local_defs.copy()
130
+ self.liveins = _clone_scope(self.generator.lscope)
131
+ self.prev_defs = _clone_scope(self.generator.local_defs)
136
132
  self.generator.local_defs = {}
137
133
  self.insert_block = self.generator.builder.get_insertion_block()
138
134
  self.insert_point = self.generator.builder.get_insertion_point()
@@ -154,10 +150,9 @@ class ContainsReturnChecker(ast.NodeVisitor):
154
150
  return any(self.visit(s) for s in body)
155
151
 
156
152
  def _visit_function(self, fn) -> bool:
157
- # Currently we only support JITFunctions defined in the global scope
158
- if isinstance(fn, JITFunction) and not fn.noinline:
159
- fn_node = fn.parse()
160
- return ContainsReturnChecker(self.gscope).visit(fn_node)
153
+ # No need to check within the function as it won't cause an early return.
154
+ # If the function itself has unstructured control flow we may not be able to inline it causing poor performance,
155
+ # we should check for this and emit a warning.
161
156
  return False
162
157
 
163
158
  def generic_visit(self, node) -> bool:
@@ -241,26 +236,26 @@ class ASTFunction:
241
236
  self.constants = constants
242
237
  self.attrs = attrs
243
238
 
244
- def return_types_ir(self, builder: ir.builder):
245
- ret_types = []
246
- for ret_ty in self.ret_types:
247
- if ret_ty is None:
239
+ def flatten_ir_types(self, builder: ir.builder, types: List[base_type]) -> List[ir.type]:
240
+ ir_types = []
241
+ for ty in types:
242
+ if ty is None:
248
243
  continue
249
- ir_ty = ret_ty.to_ir(builder)
250
- if isinstance(ir_ty, list):
251
- ret_types.extend(ir_ty)
252
- else:
253
- ret_types.append(ir_ty)
254
- return ret_types
244
+ ty._flatten_ir_types(builder, ir_types)
245
+ return ir_types
246
+
247
+ def return_types_ir(self, builder: ir.builder) -> List[ir.type]:
248
+ return self.flatten_ir_types(builder, self.ret_types)
255
249
 
256
250
  def serialize(self, builder: ir.builder):
257
251
  # fill up IR values in template
258
252
  # > build function
259
253
  is_val = lambda path, _: path not in self.constants and _ is not None
260
254
  val_paths = list(find_paths_if(self.arg_types, is_val))
261
- arg_types = [get_iterable_path(self.arg_types, path).to_ir(builder) for path in val_paths]
262
- ret_types = self.return_types_ir(builder)
263
- return builder.get_function_ty(arg_types, ret_types)
255
+ arg_types = [get_iterable_path(self.arg_types, path) for path in val_paths]
256
+ arg_types_ir = self.flatten_ir_types(builder, arg_types)
257
+ ret_types_ir = self.return_types_ir(builder)
258
+ return builder.get_function_ty(arg_types_ir, ret_types_ir)
264
259
 
265
260
  def deserialize(self, fn):
266
261
  # create "template"
@@ -272,19 +267,18 @@ class ASTFunction:
272
267
  vals = make_template(self.arg_types)
273
268
  is_val = lambda path, _: path not in self.constants and _ is not None
274
269
  val_paths = list(find_paths_if(self.arg_types, is_val))
275
- # > set attributes
276
- for attr_path, attr_specs in self.attrs.items():
277
- for attr_name, attr_val in attr_specs:
278
- if attr_path in val_paths:
279
- fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val)
280
- for i, path in enumerate(val_paths):
281
- ty = get_iterable_path(self.arg_types, path)
282
- if isinstance(ty, nv_tma_desc_type):
283
- fn.set_arg_attr(i, "tt.nv_tma_desc", 1)
284
270
  # > add IR values to the template
285
- for i, path in enumerate(val_paths):
271
+ cursor = 0
272
+ handles = [fn.args(i) for i in range(fn.get_num_args())]
273
+ for path in val_paths:
286
274
  ty = get_iterable_path(self.arg_types, path)
287
- set_iterable_path(vals, path, language.tensor(fn.args(i), ty))
275
+ # > set attributes
276
+ attr_specs = self.attrs.get(path, [])
277
+ for attr_name, attr_val in attr_specs:
278
+ fn.set_arg_attr(cursor, attr_name, attr_val)
279
+ # > build frontend value
280
+ val, cursor = ty._unflatten_ir(handles, cursor)
281
+ set_iterable_path(vals, path, val)
288
282
  # > add constexpr values to the template
289
283
  constants = self.constants
290
284
  for path, val in constants.items():
@@ -292,13 +286,29 @@ class ASTFunction:
292
286
  return vals
293
287
 
294
288
 
289
+ @dataclass(frozen=True)
290
+ class BoundJITMethod:
291
+ __self__: base_value
292
+ __func__: JITFunction
293
+
294
+
295
295
  class CodeGenerator(ast.NodeVisitor):
296
296
 
297
- def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns, module_map,
298
- module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False,
299
- file_name: Optional[str] = None, begin_line=0):
297
+ def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, *, options, codegen_fns,
298
+ module_map, is_gluon, module=None, is_kernel=False, function_types: Optional[Dict] = None,
299
+ noinline=False, caller_context=None, file_name: Optional[str] = None, begin_line=0):
300
300
  self.context = context
301
- self.builder = ir.builder(context)
301
+ self.is_gluon = is_gluon
302
+ if is_gluon:
303
+ from triton.experimental.gluon.language._semantic import GluonSemantic
304
+ self.builder = gluon_ir.GluonOpBuilder(context)
305
+ self.semantic = GluonSemantic(self.builder)
306
+ else:
307
+ from triton.language.semantic import TritonSemantic
308
+ self.builder = ir.builder(context)
309
+ self.semantic = TritonSemantic(self.builder)
310
+
311
+ self.name_loc_as_prefix = None
302
312
  self.file_name = file_name
303
313
  # node.lineno starts from 1, so we need to subtract 1
304
314
  self.begin_line = begin_line - 1
@@ -306,7 +316,7 @@ class CodeGenerator(ast.NodeVisitor):
306
316
  self.builder.options = options
307
317
  # dict of functions provided by the backend. Below are the list of possible functions:
308
318
  # Convert custom types not natively supported on HW.
309
- # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None)
319
+ # convert_custom_types(input_tensor, dtype, fp_downcast_rounding=None, _builder=None)
310
320
  self.builder.codegen_fns = codegen_fns
311
321
  self.builder.module_map = {} if module_map is None else module_map
312
322
  self.module = self.builder.create_module() if module is None else module
@@ -329,11 +339,13 @@ class CodeGenerator(ast.NodeVisitor):
329
339
  self.jit_fn = jit_fn
330
340
  # TODO: we currently generate illegal names for non-kernel functions involving constexprs!
331
341
  if is_kernel:
342
+ function_name = function_name[function_name.rfind('.') + 1:]
332
343
  function_name = check_identifier_legality(function_name, "function")
333
344
  self.function_name = function_name
334
345
  self.is_kernel = is_kernel
335
346
  self.cur_node = None
336
347
  self.noinline = noinline
348
+ self.caller_context = caller_context
337
349
  self.scf_stack = []
338
350
  self.ret_type = None
339
351
  # SSA-construction
@@ -345,7 +357,10 @@ class CodeGenerator(ast.NodeVisitor):
345
357
  # special handling.
346
358
  self.visiting_arg_default_value = False
347
359
 
348
- builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)}
360
+ builtin_namespace: Dict[str, Any] = {
361
+ _.__name__: _
362
+ for _ in (len, list, range, float, int, isinstance, getattr, hasattr)
363
+ }
349
364
  builtin_namespace.update((
350
365
  ('print', language.core.device_print),
351
366
  ('min', language.minimum),
@@ -378,11 +393,14 @@ class CodeGenerator(ast.NodeVisitor):
378
393
  # But actually a bunch of other things, such as module imports, are
379
394
  # technically Python globals. We have to allow these too!
380
395
  if any([
381
- val is absent, name in self.builtin_namespace, #
396
+ val is absent,
397
+ name in self.builtin_namespace, #
382
398
  type(val) is ModuleType, #
383
- isinstance(val, JITFunction), #
399
+ isinstance(val, JITCallable), #
384
400
  getattr(val, "__triton_builtin__", False), #
401
+ getattr(val, "__triton_aggregate__", False), #
385
402
  getattr(val, "__module__", "").startswith("triton.language"), #
403
+ getattr(val, "__module__", "").startswith("triton.experimental.gluon.language"), #
386
404
  isinstance(val, language.dtype), #
387
405
  _is_namedtuple(val),
388
406
  self._is_constexpr_global(name), #
@@ -390,7 +408,7 @@ class CodeGenerator(ast.NodeVisitor):
390
408
  # because you should be able to do
391
409
  # @triton.jit def fn(x: tl.constexpr = GLOBAL): ...
392
410
  self.visiting_arg_default_value, #
393
- os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1"
411
+ knobs.compilation.allow_non_constexpr_globals,
394
412
  ]):
395
413
  return val
396
414
  raise NameError(
@@ -414,6 +432,21 @@ class CodeGenerator(ast.NodeVisitor):
414
432
 
415
433
  return name_lookup
416
434
 
435
+ @contextlib.contextmanager
436
+ def _name_loc_prefix(self, prefix):
437
+ self.name_loc_as_prefix = prefix
438
+ yield
439
+ self.name_loc_as_prefix = None
440
+
441
+ def _maybe_set_loc_to_name(self, val, name):
442
+ if isinstance(val, (ir.value, ir.block_argument)):
443
+ val.set_loc(self.builder.create_name_loc(name, val.get_loc()))
444
+ elif _is_triton_value(val):
445
+ handles = []
446
+ val._flatten_ir(handles)
447
+ for handle in handles:
448
+ handle.set_loc(self.builder.create_name_loc(name, handle.get_loc()))
449
+
417
450
  def set_value(self, name: str, value: Union[base_value, constexpr]) -> None:
418
451
  ''' This function:
419
452
  called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
@@ -435,6 +468,43 @@ class CodeGenerator(ast.NodeVisitor):
435
468
  self.builder.restore_insertion_point(ip)
436
469
  self.builder.set_loc(loc)
437
470
 
471
+ def _find_carries(self, node, liveins):
472
+ # create loop body block
473
+ block = self.builder.create_block()
474
+ self.builder.set_insertion_point_to_start(block)
475
+ # dry visit loop body
476
+ self.scf_stack.append(node)
477
+ self.visit_compound_statement(node.body)
478
+ self.scf_stack.pop()
479
+ block.erase()
480
+
481
+ # If a variable (name) has changed value within the loop, then it's
482
+ # a loop-carried variable. (The new and old value must be of the
483
+ # same type)
484
+ init_tys = []
485
+ init_handles = []
486
+ names = []
487
+
488
+ for name, live_val in liveins.items():
489
+ if _is_triton_value(live_val):
490
+ loop_val = self.lscope[name]
491
+ self._verify_loop_carried_variable(name, loop_val, live_val)
492
+
493
+ live_handles = flatten_values_to_ir([live_val])
494
+ loop_handles = flatten_values_to_ir([loop_val])
495
+ if live_handles != loop_handles:
496
+ names.append(name)
497
+ init_tys.append(live_val.type)
498
+ init_handles.extend(live_handles)
499
+ else:
500
+ assert name not in self.local_defs, f'Loop carried variable {name} is not a triton value'
501
+
502
+ # reset local scope to not pick up local defs from the dry run.
503
+ self.lscope = liveins.copy()
504
+ self.local_defs = {}
505
+
506
+ return names, init_handles, init_tys
507
+
438
508
  #
439
509
  # AST visitor
440
510
  #
@@ -458,6 +528,21 @@ class CodeGenerator(ast.NodeVisitor):
458
528
  elts = language.tuple([self.visit(elt) for elt in node.elts])
459
529
  return elts
460
530
 
531
+ def visit_ListComp(self, node: ast.ListComp):
532
+ if len(node.generators) != 1:
533
+ raise ValueError("nested comprehensions are not supported")
534
+
535
+ comp = node.generators[0]
536
+ iter = self.visit(comp.iter)
537
+ if not isinstance(iter, tl_tuple):
538
+ raise NotImplementedError("only tuple comprehensions are supported")
539
+
540
+ results = []
541
+ for item in iter:
542
+ self.set_value(comp.target.id, item)
543
+ results.append(self.visit(node.elt))
544
+ return tl_tuple(results)
545
+
461
546
  # By design, only non-kernel functions can return
462
547
  def visit_Return(self, node):
463
548
  ret_value = self.visit(node.value)
@@ -467,7 +552,7 @@ class CodeGenerator(ast.NodeVisitor):
467
552
  if isinstance(value, language.tuple):
468
553
  return _apply_to_tuple_values(value, decay)
469
554
  elif isinstance(value, (language.constexpr, int, float)):
470
- return semantic.to_tensor(value, self.builder)
555
+ return self.semantic.to_tensor(value)
471
556
  return value
472
557
 
473
558
  ret_value = decay(ret_value)
@@ -522,8 +607,11 @@ class CodeGenerator(ast.NodeVisitor):
522
607
  self.module.push_back(self.fn)
523
608
  entry = self.fn.add_entry_block()
524
609
  arg_values = self.prototype.deserialize(self.fn)
610
+ if self.caller_context is not None:
611
+ self.caller_context.initialize_callee(self.fn, self.builder)
525
612
  # bind arguments to symbols
526
613
  for arg_name, arg_value in zip(arg_names, arg_values):
614
+ self._maybe_set_loc_to_name(arg_value, arg_name)
527
615
  self.set_value(arg_name, arg_value)
528
616
  insert_pt = self.builder.get_insertion_block()
529
617
  self.builder.set_insertion_point_to_start(entry)
@@ -575,14 +663,15 @@ class CodeGenerator(ast.NodeVisitor):
575
663
  return self.visit_Assign(node)
576
664
 
577
665
  def assignTarget(self, target, value):
666
+ assert isinstance(target.ctx, ast.Store)
578
667
  if isinstance(target, ast.Subscript):
579
- assert target.ctx.__class__.__name__ == "Store"
580
668
  return self.visit_Subscript_Store(target, value)
581
669
  if isinstance(target, ast.Tuple):
582
- assert target.ctx.__class__.__name__ == "Store"
583
- for i, name in enumerate(target.elts):
584
- self.set_value(self.visit(name), value.values[i])
670
+ for i, target in enumerate(target.elts):
671
+ self.assignTarget(target, value.values[i])
585
672
  return
673
+ if isinstance(target, ast.Attribute):
674
+ raise NotImplementedError("Attribute assignment is not supported in triton")
586
675
  assert isinstance(target, ast.Name)
587
676
  self.set_value(self.visit(target), value)
588
677
 
@@ -596,21 +685,26 @@ class CodeGenerator(ast.NodeVisitor):
596
685
  if value is not None and \
597
686
  not _is_triton_value(value) and \
598
687
  not isinstance(value, native_nontensor_types):
599
- value = semantic.to_tensor(value, self.builder)
688
+ value = self.semantic.to_tensor(value)
600
689
  return value
601
690
 
602
- values = _sanitize_value(self.visit(node.value))
603
691
  targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets
604
692
  assert len(targets) == 1
605
- self.assignTarget(targets[0], values)
693
+ target = targets[0]
694
+ if isinstance(target, ast.Name):
695
+ with self._name_loc_prefix(target.id):
696
+ values = _sanitize_value(self.visit(node.value))
697
+ else:
698
+ values = _sanitize_value(self.visit(node.value))
699
+ self.assignTarget(target, values)
606
700
 
607
701
  def visit_AugAssign(self, node):
608
- name = node.target.id
609
- lhs = ast.Name(id=name, ctx=ast.Load())
702
+ lhs = copy.deepcopy(node.target)
703
+ lhs.ctx = ast.Load()
610
704
  rhs = ast.BinOp(lhs, node.op, node.value)
611
705
  assign = ast.Assign(targets=[node.target], value=rhs)
612
706
  self.visit(assign)
613
- return self.dereference_name(name)
707
+ return self.visit(lhs)
614
708
 
615
709
  def visit_Name(self, node):
616
710
  if type(node.ctx) is ast.Store:
@@ -630,10 +724,12 @@ class CodeGenerator(ast.NodeVisitor):
630
724
  def _apply_binary_method(self, method_name, lhs, rhs):
631
725
  # TODO: raise something meaningful if getattr fails below, esp for reverse method
632
726
  if _is_triton_tensor(lhs):
633
- return getattr(lhs, method_name)(rhs, _builder=self.builder)
727
+ return getattr(lhs, method_name)(rhs, _semantic=self.semantic)
634
728
  if _is_triton_tensor(rhs):
635
729
  reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name)
636
- return getattr(rhs, reverse_method_name)(lhs, _builder=self.builder)
730
+ return getattr(rhs, reverse_method_name)(lhs, _semantic=self.semantic)
731
+ if not isinstance(lhs, (constexpr, language.tuple)) and isinstance(rhs, constexpr):
732
+ lhs = constexpr(lhs)
637
733
  return getattr(lhs, method_name)(rhs)
638
734
 
639
735
  def visit_BinOp(self, node):
@@ -666,8 +762,10 @@ class CodeGenerator(ast.NodeVisitor):
666
762
  self.visit_compound_statement(node.body)
667
763
  then_block = self.builder.get_insertion_block()
668
764
  then_defs = self.local_defs.copy()
765
+ then_vals = self.lscope.copy()
669
766
  # else block
670
767
  else_defs = {}
768
+ else_vals = liveins.copy()
671
769
  if node.orelse:
672
770
  self.builder.set_insertion_point_to_start(else_block)
673
771
  self.lscope = liveins.copy()
@@ -675,26 +773,29 @@ class CodeGenerator(ast.NodeVisitor):
675
773
  self.visit_compound_statement(node.orelse)
676
774
  else_defs = self.local_defs.copy()
677
775
  else_block = self.builder.get_insertion_block()
776
+ else_vals = self.lscope.copy()
678
777
 
679
778
  # update block arguments
680
779
  names = []
681
780
  # variables in livein whose value is updated in `if`
682
- for name in liveins:
781
+ for name, value in liveins.items():
782
+ # livein variable changed value in either then or else
783
+ if not _is_triton_value(value):
784
+ continue
785
+ then_handles = flatten_values_to_ir([then_vals[name]])
786
+ else_handles = flatten_values_to_ir([else_vals[name]])
787
+ if then_handles == else_handles:
788
+ continue
789
+ names.append(name)
790
+ then_defs[name] = then_vals[name]
791
+ else_defs[name] = else_vals[name]
683
792
  # check type
684
793
  for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]:
685
- if name in defs:
686
- type_equal = type(defs[name]) == type(liveins[name]) # noqa: E721
687
- assert type_equal and defs[name].type == liveins[name].type, \
688
- f'initial value for `{name}` is of type {liveins[name]}, '\
689
- f'but the {block_name} block redefines it as {defs[name]}'
690
- if name in then_defs or name in else_defs:
691
- names.append(name)
692
- # variable defined in then but not in else
693
- if name in then_defs and name not in else_defs:
694
- else_defs[name] = liveins[name]
695
- # variable defined in else but not in then
696
- if name in else_defs and name not in then_defs:
697
- then_defs[name] = liveins[name]
794
+ type_equal = type(defs[name]) == type(value) # noqa: E721
795
+ assert type_equal and defs[name].type == value.type, \
796
+ f'initial value for `{name}` is of type {value}, '\
797
+ f'but the {block_name} block redefines it as {defs[name]}'
798
+
698
799
  # variables that are both in then and else but not in liveins
699
800
  # TODO: could probably be cleaned up
700
801
  for name in sorted(then_defs.keys() & else_defs.keys()):
@@ -761,6 +862,8 @@ class CodeGenerator(ast.NodeVisitor):
761
862
  self.visit_then_else_blocks(node, liveins, then_block, else_block)
762
863
  # create if op
763
864
  then_handles = flatten_values_to_ir(then_defs[name] for name in names)
865
+ for name, val in zip(names, then_handles):
866
+ self._maybe_set_loc_to_name(val, name)
764
867
  self._set_insertion_point_and_loc(ip, last_loc)
765
868
  if_op = self.builder.create_if_op([h.get_type() for h in then_handles], cond.handle, True)
766
869
  then_block.merge_block_before(if_op.get_then_block())
@@ -774,6 +877,8 @@ class CodeGenerator(ast.NodeVisitor):
774
877
  self.builder.set_insertion_point_to_end(if_op.get_else_block())
775
878
  if len(names) > 0:
776
879
  else_handles = flatten_values_to_ir(else_defs[name] for name in names)
880
+ for name, val in zip(names, else_handles):
881
+ self._maybe_set_loc_to_name(val, name)
777
882
  self.builder.create_yield_op(else_handles)
778
883
  # update values
779
884
  res_handles = [if_op.get_result(i) for i in range(len(then_handles))]
@@ -786,14 +891,18 @@ class CodeGenerator(ast.NodeVisitor):
786
891
  cond = self.visit(node.test)
787
892
 
788
893
  if _is_triton_tensor(cond):
789
- cond = cond.to(language.int1, _builder=self.builder)
790
- contains_return = ContainsReturnChecker(self.gscope).visit(node)
791
- if contains_return:
894
+ if _is_non_scalar_tensor(cond):
895
+ raise self._unsupported(node, "Boolean value of Tensor with more than one value is ambiguous")
896
+ if cond.type.is_block():
897
+ warnings.warn(
898
+ "If conditional called with multidimensional Tensor instead of scalar; please use \"if (%s).item()\" instead"
899
+ % ast.unparse(node.test))
900
+ cond = language.core._unsplat(cond, _semantic=self.semantic, _generator=self)
901
+ cond = cond.to(language.int1, _semantic=self.semantic)
902
+ if ContainsReturnChecker(self.gscope).visit(node):
792
903
  if self.scf_stack:
793
904
  raise self._unsupported(
794
- node, "Cannot have `return` statements inside `while` or `for` statements in triton "
795
- "(note that this also applies to `return` statements that are inside functions "
796
- "transitively called from within `while`/`for` statements)")
905
+ node, "Cannot have `return` statements inside `while` or `for` statements in triton.")
797
906
  self.visit_if_top_level(cond, node)
798
907
  else:
799
908
  self.visit_if_scf(cond, node)
@@ -812,21 +921,21 @@ class CodeGenerator(ast.NodeVisitor):
812
921
  def visit_IfExp(self, node):
813
922
  cond = self.visit(node.test)
814
923
  if _is_triton_tensor(cond):
815
- cond = cond.to(language.int1, _builder=self.builder)
924
+ cond = cond.to(language.int1, _semantic=self.semantic)
816
925
  # TODO: Deal w/ more complicated return types (e.g tuple)
817
926
  with enter_sub_region(self):
818
927
  ip, last_loc = self._get_insertion_point_and_loc()
819
928
 
820
929
  then_block = self.builder.create_block()
821
930
  self.builder.set_insertion_point_to_start(then_block)
822
- then_val = semantic.to_tensor(self.visit(node.body), self.builder)
931
+ then_val = self.semantic.to_tensor(self.visit(node.body))
823
932
  then_block = self.builder.get_insertion_block()
824
933
 
825
934
  else_block = self.builder.create_block()
826
935
  self.builder.set_insertion_point_to_start(else_block)
827
936
  # do not need to reset lscope since
828
937
  # ternary expressions cannot define new variables
829
- else_val = semantic.to_tensor(self.visit(node.orelse), self.builder)
938
+ else_val = self.semantic.to_tensor(self.visit(node.orelse))
830
939
  else_block = self.builder.get_insertion_block()
831
940
 
832
941
  self._set_insertion_point_and_loc(ip, last_loc)
@@ -862,6 +971,37 @@ class CodeGenerator(ast.NodeVisitor):
862
971
  else:
863
972
  return self.visit(node.orelse)
864
973
 
974
+ def visit_With(self, node):
975
+ # Lower `with` statements by constructing context managers and calling their enter/exit hooks
976
+ # Instantiate each context manager with builder injection
977
+ if len(node.items) == 1: # Handle async_task
978
+ context = node.items[0].context_expr
979
+ withitemClass = self.visit(context.func)
980
+ if withitemClass == language.async_task:
981
+ args = [self.visit(arg) for arg in context.args]
982
+ with withitemClass(*args, _builder=self.builder):
983
+ self.visit_compound_statement(node.body)
984
+ return
985
+
986
+ cm_list = []
987
+ for item in node.items:
988
+ call = item.context_expr
989
+ fn = self.visit(call.func)
990
+ args = [self.visit(arg) for arg in call.args]
991
+ kws = dict(self.visit(kw) for kw in call.keywords)
992
+ cm = fn(*args, _semantic=self.semantic, **kws)
993
+ cm_list.append(cm)
994
+ for cm, item in zip(cm_list, node.items):
995
+ res = cm.__enter__()
996
+ if item.optional_vars is not None:
997
+ var_name = self.visit(item.optional_vars)
998
+ self.set_value(var_name, res)
999
+ if ContainsReturnChecker(self.gscope).visit(node):
1000
+ raise self._unsupported(node, "Cannot have `return` statements inside `with` statements in triton ")
1001
+ self.visit_compound_statement(node.body)
1002
+ for cm in reversed(cm_list):
1003
+ cm.__exit__(None, None, None)
1004
+
865
1005
  def visit_Pass(self, node):
866
1006
  pass
867
1007
 
@@ -892,10 +1032,12 @@ class CodeGenerator(ast.NodeVisitor):
892
1032
  if fn is None:
893
1033
  raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.")
894
1034
  if _is_triton_tensor(operand):
895
- return getattr(operand, fn)(_builder=self.builder)
1035
+ return getattr(operand, fn)(_semantic=self.semantic)
896
1036
  try:
897
1037
  return getattr(operand, fn)()
898
1038
  except AttributeError:
1039
+ if fn == "__not__":
1040
+ return constexpr(not operand)
899
1041
  raise self._unsupported(
900
1042
  node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}")
901
1043
 
@@ -904,46 +1046,26 @@ class CodeGenerator(ast.NodeVisitor):
904
1046
  }
905
1047
 
906
1048
  def _verify_loop_carried_variable(self, name, loop_val, live_val):
907
- assert _is_triton_value(loop_val), f'cannot reassign constxpr {name} in the loop'
908
- assert _is_triton_value(live_val), f'cannot reasign constexpr {name} in the loop'
909
- assert type(loop_val) is type(live_val), f'Loop carried variable {name} changed type'
1049
+ assert _is_triton_value(loop_val), f'cannot reassign constexpr {name} in the loop'
1050
+ assert _is_triton_value(live_val), f'cannot reassign constexpr {name} in the loop'
1051
+ assert type(loop_val) is type(live_val), (
1052
+ f'Loop carried variable {name} changed type, was {type(loop_val)} but is now {type(live_val)}')
910
1053
  assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \
911
1054
  f'Loop-carried variable {name} has initial type {live_val.type} '\
912
1055
  f'but is re-assigned to {loop_val.type} in loop! '\
913
1056
  f'Please make sure that the type stays consistent.'
914
1057
 
1058
+ def visit_withitem(self, node):
1059
+ return self.visit(node.context_expr)
1060
+
915
1061
  def visit_While(self, node):
916
1062
  with enter_sub_region(self) as sr:
917
1063
  liveins, insert_block = sr
918
1064
  ip, last_loc = self._get_insertion_point_and_loc()
919
1065
 
920
- # loop body (the after region)
921
- # loop_block = self.builder.create_block()
922
- dummy = self.builder.create_block()
923
- self.builder.set_insertion_point_to_start(dummy)
924
- self.scf_stack.append(node)
925
- self.visit_compound_statement(node.body)
926
- self.scf_stack.pop()
927
- loop_defs = self.local_defs
928
- dummy.erase()
929
-
930
- # collect loop-carried values
931
- names = []
932
- init_args = []
933
- for name in loop_defs:
934
- if name in liveins:
935
- # We should not def new constexpr
936
- loop_val = loop_defs[name]
937
- live_val = liveins[name]
938
- self._verify_loop_carried_variable(name, loop_val, live_val)
939
-
940
- # these are loop-carried values
941
- names.append(name)
942
- init_args.append(live_val)
1066
+ names, init_handles, init_fe_tys = self._find_carries(node, liveins)
943
1067
 
944
- init_handles = flatten_values_to_ir(init_args)
945
1068
  init_tys = [h.get_type() for h in init_handles]
946
- init_fe_tys = [a.type for a in init_args]
947
1069
  self._set_insertion_point_and_loc(ip, last_loc)
948
1070
  while_op = self.builder.create_while_op(init_tys, init_handles)
949
1071
  # merge the condition region
@@ -954,7 +1076,12 @@ class CodeGenerator(ast.NodeVisitor):
954
1076
  for name, val in zip(names, condition_args):
955
1077
  self.lscope[name] = val
956
1078
  self.local_defs[name] = val
1079
+ self._maybe_set_loc_to_name(val, name)
957
1080
  cond = self.visit(node.test)
1081
+ if isinstance(cond, language.condition):
1082
+ if cond.disable_licm:
1083
+ while_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr())
1084
+ cond = cond.condition
958
1085
  self.builder.set_insertion_point_to_end(before_block)
959
1086
  # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
960
1087
  self.builder.create_condition_op(cond.handle, block_args)
@@ -968,16 +1095,13 @@ class CodeGenerator(ast.NodeVisitor):
968
1095
  for name, val in zip(names, body_args):
969
1096
  self.lscope[name] = val
970
1097
  self.local_defs[name] = val
1098
+ self._maybe_set_loc_to_name(val, name)
971
1099
  self.scf_stack.append(node)
972
1100
  self.visit_compound_statement(node.body)
973
1101
  self.scf_stack.pop()
974
- loop_defs = self.local_defs
975
- yields = []
976
- for name in loop_defs:
977
- if name in liveins:
978
- loop_defs[name]._flatten_ir(yields)
979
1102
 
980
- self.builder.create_yield_op(yields)
1103
+ yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)
1104
+ self.builder.create_yield_op(yield_handles)
981
1105
 
982
1106
  # WhileOp defines new values, update the symbol table (lscope, local_defs)
983
1107
  result_handles = [while_op.get_result(i) for i in range(len(init_handles))]
@@ -985,25 +1109,22 @@ class CodeGenerator(ast.NodeVisitor):
985
1109
  for name, new_def in zip(names, result_vals):
986
1110
  self.lscope[name] = new_def
987
1111
  self.local_defs[name] = new_def
1112
+ self._maybe_set_loc_to_name(new_def, name)
988
1113
 
989
1114
  for stmt in node.orelse:
990
1115
  assert False, "Not implemented"
991
1116
  ast.NodeVisitor.generic_visit(self, stmt)
992
1117
 
993
1118
  def visit_Subscript_Load(self, node):
994
- assert node.ctx.__class__.__name__ == "Load"
1119
+ assert isinstance(node.ctx, ast.Load)
995
1120
  lhs = self.visit(node.value)
996
1121
  slices = self.visit(node.slice)
997
- if _is_triton_tensor(lhs):
998
- return lhs.__getitem__(slices, _builder=self.builder)
1122
+ if _is_triton_value(lhs):
1123
+ return self.call_Method(node, lhs.__getitem__, lhs, [slices], {})
999
1124
  return lhs[slices]
1000
1125
 
1001
1126
  def visit_Subscript_Store(self, node, value):
1002
- assert node.ctx.__class__.__name__ == "Store"
1003
- lhs = self.visit(node.value)
1004
- slices = self.visit(node.slice)
1005
- assert isinstance(lhs, language.tuple)
1006
- lhs.__setitem__(slices, value)
1127
+ raise NotImplementedError("__setitem__ is not supported in triton")
1007
1128
 
1008
1129
  def visit_Subscript(self, node):
1009
1130
  return self.visit_Subscript_Load(node)
@@ -1028,6 +1149,8 @@ class CodeGenerator(ast.NodeVisitor):
1028
1149
  loop_unroll_factor = None
1029
1150
  disallow_acc_multi_buffer = False
1030
1151
  flatten = False
1152
+ warp_specialize = False
1153
+ disable_licm = False
1031
1154
  if IteratorClass is language.range:
1032
1155
  iterator = IteratorClass(*iter_args, **iter_kwargs)
1033
1156
  # visit iterator arguments
@@ -1040,6 +1163,8 @@ class CodeGenerator(ast.NodeVisitor):
1040
1163
  loop_unroll_factor = iterator.loop_unroll_factor
1041
1164
  disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer
1042
1165
  flatten = iterator.flatten
1166
+ warp_specialize = iterator.warp_specialize
1167
+ disable_licm = iterator.disable_licm
1043
1168
  elif IteratorClass is range:
1044
1169
  # visit iterator arguments
1045
1170
  # note: only `range` iterator is supported now
@@ -1055,14 +1180,14 @@ class CodeGenerator(ast.NodeVisitor):
1055
1180
  step = constexpr(-step.value)
1056
1181
  negative_step = True
1057
1182
  lb, ub = ub, lb
1058
- lb = semantic.to_tensor(lb, self.builder)
1059
- ub = semantic.to_tensor(ub, self.builder)
1060
- step = semantic.to_tensor(step, self.builder)
1183
+ lb = self.semantic.to_tensor(lb)
1184
+ ub = self.semantic.to_tensor(ub)
1185
+ step = self.semantic.to_tensor(step)
1061
1186
  # induction variable type
1062
1187
  if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int():
1063
1188
  raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})")
1064
- iv_type = semantic.integer_promote_impl(lb.dtype, ub.dtype)
1065
- iv_type = semantic.integer_promote_impl(iv_type, step.dtype)
1189
+ iv_type = self.semantic.integer_promote_impl(lb.dtype, ub.dtype)
1190
+ iv_type = self.semantic.integer_promote_impl(iv_type, step.dtype)
1066
1191
  iv_ir_type = iv_type.to_ir(self.builder)
1067
1192
  iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED
1068
1193
  # lb/ub/step might be constexpr, we need to cast them to tensor
@@ -1081,34 +1206,10 @@ class CodeGenerator(ast.NodeVisitor):
1081
1206
  liveins, insert_block = sr
1082
1207
  ip, last_loc = self._get_insertion_point_and_loc()
1083
1208
 
1084
- # create loop body block
1085
- block = self.builder.create_block()
1086
- self.builder.set_insertion_point_to_start(block)
1087
- # dry visit loop body
1088
- self.scf_stack.append(node)
1089
- self.visit_compound_statement(node.body)
1090
- self.scf_stack.pop()
1091
- block.erase()
1092
-
1093
- # If a variable (name) is defined in both its parent & itself, then it's
1094
- # a loop-carried variable. (They must be of the same type)
1095
- init_args = []
1096
- yields = []
1097
- names = []
1098
- for name in self.local_defs:
1099
- if name in liveins:
1100
- loop_val = self.local_defs[name]
1101
- live_val = liveins[name]
1102
- self._verify_loop_carried_variable(name, loop_val, live_val)
1103
-
1104
- names.append(name)
1105
- init_args.append(live_val)
1106
- yields.append(loop_val)
1209
+ names, init_handles, init_tys = self._find_carries(node, liveins)
1107
1210
 
1108
1211
  # create ForOp
1109
1212
  self._set_insertion_point_and_loc(ip, last_loc)
1110
- init_handles = flatten_values_to_ir(init_args)
1111
- init_tys = [v.type for v in init_args]
1112
1213
  for_op = self.builder.create_for_op(lb, ub, step, init_handles)
1113
1214
  if _unwrap_if_constexpr(num_stages) is not None:
1114
1215
  for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages))
@@ -1118,30 +1219,25 @@ class CodeGenerator(ast.NodeVisitor):
1118
1219
  for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr())
1119
1220
  if flatten:
1120
1221
  for_op.set_attr("tt.flatten", self.builder.get_unit_attr())
1222
+ if warp_specialize:
1223
+ for_op.set_attr("tt.warp_specialize", self.builder.get_unit_attr())
1224
+ if disable_licm:
1225
+ for_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr())
1121
1226
 
1122
1227
  self.scf_stack.append(node)
1123
1228
  for_op_body = for_op.get_body(0)
1124
1229
  self.builder.set_insertion_point_to_start(for_op_body)
1125
- # reset local scope to not pick up local defs from the previous dry run.
1126
- self.lscope = liveins.copy()
1127
- self.local_defs = {}
1128
1230
  block_handles = [for_op_body.arg(i + 1) for i in range(len(init_handles))]
1129
1231
  block_args = unflatten_ir_values(block_handles, init_tys)
1130
1232
  for name, val in zip(names, block_args):
1233
+ self._maybe_set_loc_to_name(val, name)
1131
1234
  self.set_value(name, val)
1132
1235
  self.visit_compound_statement(node.body)
1133
1236
  self.scf_stack.pop()
1134
- yields = []
1135
- for name in self.local_defs:
1136
- if name in liveins:
1137
- local = self.local_defs[name]
1138
- if isinstance(local, constexpr):
1139
- local = semantic.to_tensor(local, self.builder)
1140
- yields.append(local)
1237
+ yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)
1141
1238
 
1142
1239
  # create YieldOp
1143
- if len(yields) > 0:
1144
- yield_handles = flatten_values_to_ir(yields)
1240
+ if len(yield_handles) > 0:
1145
1241
  self.builder.create_yield_op(yield_handles)
1146
1242
  for_op_region = for_op_body.get_parent()
1147
1243
  assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
@@ -1154,12 +1250,14 @@ class CodeGenerator(ast.NodeVisitor):
1154
1250
  iv = self.builder.create_add(iv, lb)
1155
1251
  self.lscope[node.target.id].handle.replace_all_uses_with(iv)
1156
1252
  self.set_value(node.target.id, language.core.tensor(iv, iv_type))
1253
+ self._maybe_set_loc_to_name(iv, node.target.id)
1157
1254
 
1158
1255
  # update lscope & local_defs (ForOp defines new values)
1159
1256
  result_handles = [for_op.get_result(i) for i in range(len(init_handles))]
1160
1257
  result_values = unflatten_ir_values(result_handles, init_tys)
1161
1258
  for name, val in zip(names, result_values):
1162
1259
  self.set_value(name, val)
1260
+ self._maybe_set_loc_to_name(val, name)
1163
1261
 
1164
1262
  for stmt in node.orelse:
1165
1263
  assert False, "Don't know what to do with else after for"
@@ -1180,9 +1278,9 @@ class CodeGenerator(ast.NodeVisitor):
1180
1278
  def visit_Assert(self, node) -> Any:
1181
1279
  test = self.visit(node.test)
1182
1280
  msg = self.visit(node.msg) if node.msg is not None else ""
1183
- return language.core.device_assert(test, msg, _builder=self.builder)
1281
+ return language.core.device_assert(test, msg, _semantic=self.semantic)
1184
1282
 
1185
- def call_JitFunction(self, fn: JITFunction, args, kwargs):
1283
+ def call_JitFunction(self, fn: JITFunction, args, kwargs, caller_context=None):
1186
1284
  args = inspect.getcallargs(fn.fn, *args, **kwargs)
1187
1285
  args = [args[name] for name in fn.arg_names]
1188
1286
  for i, arg in enumerate(args):
@@ -1193,10 +1291,10 @@ class CodeGenerator(ast.NodeVisitor):
1193
1291
  args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
1194
1292
  args_val = [get_iterable_path(args, path) for path in args_path]
1195
1293
  # mangle
1196
- fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst)
1294
+ caller_context = caller_context or self.caller_context
1295
+ fn_name = mangle_fn(get_full_name(fn), [arg.type for arg in args_val], args_cst, caller_context)
1197
1296
  # generate function def if necessary
1198
1297
  if not self.module.has_function(fn_name):
1199
- gscope = fn.__globals__
1200
1298
  # If the callee is not set, we use the same debug setting as the caller
1201
1299
  file_name, begin_line = get_jit_fn_file_line(fn)
1202
1300
  arg_types = [
@@ -1205,15 +1303,18 @@ class CodeGenerator(ast.NodeVisitor):
1205
1303
  for arg in args
1206
1304
  ]
1207
1305
  prototype = ASTFunction([], arg_types, args_cst, dict())
1208
- generator = CodeGenerator(self.context, prototype, gscope, module=self.module, jit_fn=fn,
1306
+ generator = CodeGenerator(self.context, prototype, fn.get_capture_scope(), module=self.module, jit_fn=fn,
1209
1307
  function_name=fn_name, function_types=self.function_ret_types,
1210
1308
  noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
1211
1309
  options=self.builder.options, codegen_fns=self.builder.codegen_fns,
1212
- module_map=self.builder.module_map)
1310
+ module_map=self.builder.module_map, caller_context=caller_context,
1311
+ is_gluon=self.is_gluon)
1213
1312
  try:
1214
1313
  generator.visit(fn.parse())
1215
1314
  except Exception as e:
1216
1315
  # Wrap the error in the callee with the location of the call.
1316
+ if knobs.compilation.front_end_debugging:
1317
+ raise
1217
1318
  raise CompilationError(self.jit_fn.src, self.cur_node, None) from e
1218
1319
 
1219
1320
  callee_ret_type = generator.ret_type
@@ -1221,28 +1322,30 @@ class CodeGenerator(ast.NodeVisitor):
1221
1322
  else:
1222
1323
  callee_ret_type = self.function_ret_types[fn_name]
1223
1324
  symbol = self.module.get_function(fn_name)
1224
- args_val = [arg.handle for arg in args_val]
1325
+ args_val = flatten_values_to_ir(args_val)
1225
1326
  call_op = self.builder.call(symbol, args_val)
1226
1327
  if callee_ret_type == language.void:
1227
1328
  return None
1228
1329
  handles = [call_op.get_result(i) for i in range(call_op.get_num_results())]
1229
1330
  return next(unflatten_ir_values(handles, [callee_ret_type]))
1230
1331
 
1231
- def visit_Call(self, node):
1232
- fn = _unwrap_if_constexpr(self.visit(node.func))
1233
- static_implementation = self.statically_implemented_functions.get(fn)
1234
- if static_implementation is not None:
1235
- return static_implementation(self, node)
1236
-
1237
- kws = dict(self.visit(keyword) for keyword in node.keywords)
1238
- args = [self.visit(arg) for arg in node.args]
1239
- args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
1332
+ def call_Function(self, node, fn, args, kws):
1333
+ if isinstance(fn, (BoundJITMethod, BoundConstexprFunction)):
1334
+ args.insert(0, fn.__self__)
1335
+ fn = fn.__func__
1240
1336
  if isinstance(fn, JITFunction):
1241
1337
  _check_fn_args(node, fn, args)
1242
1338
  return self.call_JitFunction(fn, args, kws)
1243
- if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn):
1244
- extra_kwargs = {"_builder": self.builder}
1245
- sig = inspect.signature(fn)
1339
+ if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn) or isinstance(
1340
+ fn, ConstexprFunction):
1341
+ extra_kwargs = dict()
1342
+
1343
+ if isinstance(fn, ConstexprFunction):
1344
+ sig = inspect.signature(fn.__call__)
1345
+ else:
1346
+ sig = inspect.signature(fn)
1347
+ if '_semantic' in sig.parameters:
1348
+ extra_kwargs["_semantic"] = self.semantic
1246
1349
  if '_generator' in sig.parameters:
1247
1350
  extra_kwargs['_generator'] = self
1248
1351
  try:
@@ -1252,43 +1355,125 @@ class CodeGenerator(ast.NodeVisitor):
1252
1355
  ret = language.tuple(ret)
1253
1356
  return ret
1254
1357
  except Exception as e:
1358
+ if knobs.compilation.front_end_debugging:
1359
+ raise
1255
1360
  # Normally when we raise a CompilationError, we raise it as
1256
1361
  # `from None`, because the original fileline from the exception
1257
1362
  # is not relevant (and often points into code_generator.py
1258
1363
  # itself). But when calling a function, we raise as `from e` to
1259
1364
  # preserve the traceback of the original error, which may e.g.
1260
1365
  # be in core.py.
1261
- raise CompilationError(self.jit_fn.src, node, None) from e
1366
+ raise CompilationError(self.jit_fn.src, node, str(e)) from e
1262
1367
 
1263
1368
  if fn in self.builtin_namespace.values():
1264
1369
  args = map(_unwrap_if_constexpr, args)
1265
1370
  ret = fn(*args, **kws)
1266
- return _apply_to_tuple_values(ret, lambda x: x) if _is_namedtuple(type(ret)) else ret
1371
+
1372
+ def wrap_constexpr(x):
1373
+ if _is_triton_value(x):
1374
+ return x
1375
+ return constexpr(x)
1376
+
1377
+ if isinstance(ret, (builtins.tuple, language.tuple)):
1378
+ return _apply_to_tuple_values(ret, wrap_constexpr)
1379
+ return wrap_constexpr(ret)
1380
+
1381
+ def call_Method(self, node, fn, fn_self, args, kws):
1382
+ if isinstance(fn, JITFunction):
1383
+ args.insert(0, fn_self)
1384
+ return self.call_Function(node, fn, args, kws)
1385
+
1386
+ def visit_Call(self, node):
1387
+ fn = _unwrap_if_constexpr(self.visit(node.func))
1388
+ if not isinstance(fn, BoundJITMethod):
1389
+ static_implementation = self.statically_implemented_functions.get(fn)
1390
+ if static_implementation is not None:
1391
+ return static_implementation(self, node)
1392
+
1393
+ mur = getattr(fn, '_must_use_result', False)
1394
+ if mur and getattr(node, '_is_unused', False):
1395
+ error_message = ["The result of %s is not being used." % ast.unparse(node.func)]
1396
+ if isinstance(mur, str):
1397
+ error_message.append(mur)
1398
+ raise CompilationError(self.jit_fn.src, node, " ".join(error_message))
1399
+
1400
+ kws = dict(self.visit(keyword) for keyword in node.keywords)
1401
+ args = [self.visit(arg) for arg in node.args]
1402
+ args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
1403
+
1404
+ return self.call_Function(node, fn, args, kws)
1267
1405
 
1268
1406
  def visit_Constant(self, node):
1269
1407
  return constexpr(node.value)
1270
1408
 
1271
1409
  def visit_BoolOp(self, node: ast.BoolOp):
1272
- if len(node.values) != 2:
1273
- raise self._unsupported(
1274
- node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
1275
- lhs = self.visit(node.values[0])
1276
- rhs = self.visit(node.values[1])
1277
1410
  method_name = self._method_name_for_bool_op.get(type(node.op))
1278
1411
  if method_name is None:
1279
1412
  raise self._unsupported(
1280
1413
  node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
1281
- return self._apply_binary_method(method_name, lhs, rhs)
1414
+
1415
+ nontrivial_values = []
1416
+
1417
+ for subnode in node.values:
1418
+ # we visit the values in order, executing their side-effects
1419
+ # and possibly early-exiting:
1420
+ value = self.visit(subnode)
1421
+ if not _is_triton_tensor(value):
1422
+ # this is a constexpr, so we might be able to short-circuit:
1423
+ bv = bool(value)
1424
+ if (bv is False) and (method_name == "logical_and"):
1425
+ # value is falsey so return that:
1426
+ return value
1427
+ if (bv is True) and (method_name == "logical_or"):
1428
+ # value is truthy so return that:
1429
+ return value
1430
+ # otherwise, our constexpr has no effect on the output of the
1431
+ # expression so we do not append it to nontrivial_values.
1432
+ else:
1433
+ if value.type.is_block():
1434
+ lineno = getattr(node, "lineno", None)
1435
+ if lineno is not None:
1436
+ lineno += self.begin_line
1437
+ warnings.warn_explicit(
1438
+ "Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead",
1439
+ category=UserWarning,
1440
+ filename=self.file_name,
1441
+ lineno=lineno,
1442
+ source=ast.unparse(node),
1443
+ )
1444
+ # not a constexpr so we must append it:
1445
+ nontrivial_values.append(value)
1446
+
1447
+ if len(nontrivial_values) == 0:
1448
+ # the semantics of a disjunction of falsey values or conjunction
1449
+ # of truthy values is to return the final value:
1450
+ nontrivial_values.append(value)
1451
+
1452
+ while len(nontrivial_values) >= 2:
1453
+ rhs = nontrivial_values.pop()
1454
+ lhs = nontrivial_values.pop()
1455
+ res = self._apply_binary_method(method_name, lhs, rhs)
1456
+ nontrivial_values.append(res)
1457
+
1458
+ assert len(nontrivial_values) == 1
1459
+ return nontrivial_values[0]
1282
1460
 
1283
1461
  _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
1284
1462
 
1285
1463
  def visit_Attribute(self, node):
1286
1464
  lhs = self.visit(node.value)
1287
1465
  if _is_triton_tensor(lhs) and node.attr == "T":
1288
- return semantic.permute(lhs, (1, 0), builder=self.builder)
1289
- return getattr(lhs, node.attr)
1466
+ return self.semantic.permute(lhs, (1, 0))
1467
+ # NOTE: special case ".value" for BC
1468
+ if isinstance(lhs, constexpr) and node.attr not in ("value", "type"):
1469
+ lhs = lhs.value
1470
+ attr = getattr(lhs, node.attr)
1471
+ if _is_triton_value(lhs) and isinstance(attr, JITFunction):
1472
+ return BoundJITMethod(lhs, attr)
1473
+ return attr
1290
1474
 
1291
1475
  def visit_Expr(self, node):
1476
+ node.value._is_unused = True
1292
1477
  ast.NodeVisitor.generic_visit(self, node)
1293
1478
 
1294
1479
  def visit_NoneType(self, node):
@@ -1324,13 +1509,19 @@ class CodeGenerator(ast.NodeVisitor):
1324
1509
  last_loc = self.builder.get_loc()
1325
1510
  self.cur_node = node
1326
1511
  if hasattr(node, 'lineno') and hasattr(node, 'col_offset'):
1327
- self.builder.set_loc(self.file_name, self.begin_line + node.lineno, node.col_offset)
1512
+ here_loc = self.builder.create_loc(self.file_name, self.begin_line + node.lineno, node.col_offset)
1513
+ if self.name_loc_as_prefix is not None:
1514
+ self.builder.set_loc(self.builder.create_name_loc(self.name_loc_as_prefix, here_loc))
1515
+ else:
1516
+ self.builder.set_loc(here_loc)
1328
1517
  last_loc = self.builder.get_loc()
1329
1518
  try:
1330
1519
  ret = super().visit(node)
1331
1520
  except CompilationError:
1332
1521
  raise
1333
1522
  except Exception as e:
1523
+ if knobs.compilation.front_end_debugging:
1524
+ raise
1334
1525
  # Wrap the error in a CompilationError which contains the source
1335
1526
  # of the @jit function.
1336
1527
  raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None
@@ -1378,16 +1569,29 @@ class CodeGenerator(ast.NodeVisitor):
1378
1569
 
1379
1570
  return ret
1380
1571
 
1572
+ from ..experimental.gluon import language as ttgl
1381
1573
  statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {
1382
1574
  language.core.static_assert: execute_static_assert,
1383
1575
  language.core.static_print: static_executor(print),
1576
+ ttgl.static_assert: execute_static_assert,
1577
+ ttgl.static_print: static_executor(print),
1384
1578
  int: static_executor(int),
1385
1579
  len: static_executor(len),
1386
1580
  }
1387
1581
 
1388
1582
 
1389
- def ast_to_ttir(fn, src, context, options, codegen_fns, module_map):
1390
- arg_types = list(map(str_to_ty, src.signature.values()))
1583
+ def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None):
1584
+ arg_types = [None] * len(fn.arg_names)
1585
+ const_iter = iter(src.constants.items())
1586
+ kc, vc = next(const_iter, (None, None))
1587
+
1588
+ for i, (ks, v) in enumerate(src.signature.items()):
1589
+ idx = fn.arg_names.index(ks)
1590
+ cexpr = None
1591
+ if kc is not None and kc[0] == i:
1592
+ cexpr = vc
1593
+ kc, vc = next(const_iter, (None, None))
1594
+ arg_types[idx] = str_to_ty(v, cexpr)
1391
1595
  prototype = ASTFunction([], arg_types, src.constants, src.attrs)
1392
1596
  file_name, begin_line = get_jit_fn_file_line(fn)
1393
1597
  # query function representation
@@ -1396,11 +1600,15 @@ def ast_to_ttir(fn, src, context, options, codegen_fns, module_map):
1396
1600
  constants = {fn.arg_names[i[0]]: src.constants[i] for i in leaves}
1397
1601
  signature = src.signature
1398
1602
  proxy = namedtuple("SpecializationProxy", ["constants", "signature"])(constants, signature)
1399
- generator = CodeGenerator(context, prototype, gscope=fn.__globals__.copy(), function_name=fn.repr(proxy), jit_fn=fn,
1400
- is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
1401
- codegen_fns=codegen_fns, module_map=module_map)
1603
+ generator = CodeGenerator(context, prototype, gscope=fn.get_capture_scope(), function_name=fn.repr(proxy),
1604
+ jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
1605
+ codegen_fns=codegen_fns, module_map=module_map, module=module, is_gluon=fn.is_gluon())
1402
1606
  generator.visit(fn.parse())
1403
- ret = generator.module
1607
+ module = generator.module
1404
1608
  # module takes ownership of the context
1405
- ret.context = context
1406
- return ret
1609
+ module.context = context
1610
+ if not module.verify_with_diagnostics():
1611
+ if not fn.is_gluon():
1612
+ print(module)
1613
+ raise RuntimeError("error encountered during parsing")
1614
+ return module