triton-windows 3.2.0.post11__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows has been flagged as potentially problematic; see the registry's advisory page for details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +85 -0
- triton/_internal_testing.py +123 -0
- triton/backends/__init__.py +50 -0
- triton/backends/amd/compiler.py +368 -0
- triton/backends/amd/driver.c +211 -0
- triton/backends/amd/driver.py +512 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
- triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
- triton/backends/amd/include/hip/channel_descriptor.h +39 -0
- triton/backends/amd/include/hip/device_functions.h +38 -0
- triton/backends/amd/include/hip/driver_types.h +468 -0
- triton/backends/amd/include/hip/hip_bf16.h +36 -0
- triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
- triton/backends/amd/include/hip/hip_common.h +100 -0
- triton/backends/amd/include/hip/hip_complex.h +38 -0
- triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
- triton/backends/amd/include/hip/hip_deprecated.h +95 -0
- triton/backends/amd/include/hip/hip_ext.h +159 -0
- triton/backends/amd/include/hip/hip_fp16.h +36 -0
- triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
- triton/backends/amd/include/hip/hip_hcc.h +24 -0
- triton/backends/amd/include/hip/hip_math_constants.h +36 -0
- triton/backends/amd/include/hip/hip_profile.h +27 -0
- triton/backends/amd/include/hip/hip_runtime.h +75 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
- triton/backends/amd/include/hip/hip_texture_types.h +29 -0
- triton/backends/amd/include/hip/hip_vector_types.h +41 -0
- triton/backends/amd/include/hip/hip_version.h +17 -0
- triton/backends/amd/include/hip/hiprtc.h +421 -0
- triton/backends/amd/include/hip/library_types.h +78 -0
- triton/backends/amd/include/hip/math_functions.h +42 -0
- triton/backends/amd/include/hip/surface_types.h +63 -0
- triton/backends/amd/include/hip/texture_types.h +194 -0
- triton/backends/amd/include/hsa/Brig.h +1131 -0
- triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
- triton/backends/amd/include/hsa/amd_hsa_elf.h +436 -0
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
- triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
- triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
- triton/backends/amd/include/hsa/hsa.h +5729 -0
- triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
- triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
- triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
- triton/backends/amd/include/roctracer/roctracer.h +779 -0
- triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
- triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
- triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
- triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
- triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
- triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
- triton/backends/amd/include/roctracer/roctx.h +229 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +304 -0
- triton/backends/driver.py +48 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +410 -0
- triton/backends/nvidia/driver.c +451 -0
- triton/backends/nvidia/driver.py +524 -0
- triton/backends/nvidia/include/cuda.h +24359 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +4 -0
- triton/compiler/code_generator.py +1303 -0
- triton/compiler/compiler.py +430 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/language/__init__.py +294 -0
- triton/language/_utils.py +21 -0
- triton/language/core.py +2694 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +13 -0
- triton/language/extra/cuda/_experimental_tma.py +108 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +3 -0
- triton/language/extra/hip/libdevice.py +475 -0
- triton/language/extra/libdevice.py +786 -0
- triton/language/math.py +250 -0
- triton/language/random.py +207 -0
- triton/language/semantic.py +1796 -0
- triton/language/standard.py +452 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/autotuner.py +408 -0
- triton/runtime/build.py +111 -0
- triton/runtime/cache.py +295 -0
- triton/runtime/driver.py +60 -0
- triton/runtime/errors.py +26 -0
- triton/runtime/interpreter.py +1235 -0
- triton/runtime/jit.py +951 -0
- triton/testing.py +511 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.c +67 -0
- triton/tools/compile.h +14 -0
- triton/tools/compile.py +155 -0
- triton/tools/disasm.py +144 -0
- triton/tools/experimental_descriptor.py +32 -0
- triton/tools/link.py +322 -0
- triton/windows_utils.py +375 -0
- triton_windows-3.2.0.post11.dist-info/METADATA +39 -0
- triton_windows-3.2.0.post11.dist-info/RECORD +154 -0
- triton_windows-3.2.0.post11.dist-info/WHEEL +5 -0
- triton_windows-3.2.0.post11.dist-info/top_level.txt +12 -0
|
@@ -0,0 +1,1303 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import inspect
|
|
3
|
+
import re
|
|
4
|
+
import sys
|
|
5
|
+
import warnings
|
|
6
|
+
import os
|
|
7
|
+
import textwrap
|
|
8
|
+
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
|
9
|
+
from .. import language
|
|
10
|
+
from .._C.libtriton import ir
|
|
11
|
+
from ..language import constexpr, tensor, str_to_ty
|
|
12
|
+
from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, _value
|
|
13
|
+
from ..runtime.jit import _normalize_ty, get_jit_fn_file_line
|
|
14
|
+
# ideally we wouldn't need any runtime component
|
|
15
|
+
from ..runtime import JITFunction
|
|
16
|
+
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
|
|
17
|
+
from types import ModuleType
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def mangle_ty(ty):
    """Encode a triton type as a short, LLVM-identifier-safe fragment."""
    if ty.is_ptr():
        # pointers: 'P' followed by the mangled pointee type
        return 'P' + mangle_ty(ty.element_ty)
    if ty.is_int():
        # integers: 'i' (signed) or 'u' (unsigned) followed by the bit width
        sign_prefix = 'i' if ty.int_signedness == language.dtype.SIGNEDNESS.SIGNED else 'u'
        return sign_prefix + str(ty.int_bitwidth)
    if ty.is_floating():
        return str(ty)
    if ty.is_block():
        # blocks: scalar mangling wrapped in 'S<d0>_<d1>_...S'
        dims = '_'.join(str(dim) for dim in ty.shape)
        return f'{mangle_ty(ty.scalar)}S{dims}S'
    if ty.is_void():
        return 'V'
    raise TypeError(f'Unsupported type {ty}')
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def mangle_fn(name, arg_tys, constants):
    """Build a unique symbol name from argument types and constexpr values.

    The return type is deliberately not mangled: it must be a function of
    the argument types.
    """
    arg_part = '_'.join(mangle_ty(ty) for ty in arg_tys)
    const_part = '_'.join(f'{i}c{repr(constants[i])}' for i in sorted(constants))
    # sanitize characters that are not valid in LLVM identifiers
    const_part = const_part.replace('.', '_d_').replace("'", '_sq_')
    # [ and ] are not allowed in LLVM identifiers
    const_part = const_part.replace('[', '_').replace(']', '_')
    return f'{name}__{arg_part}__{const_part}'
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _is_triton_value(o: Any) -> bool:
    """Return True if *o* is a triton language value (instance of `_value`)."""
    return isinstance(o, _value)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _is_triton_tensor(o: Any) -> bool:
    """Return True if *o* is a triton `tensor`."""
    return isinstance(o, tensor)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _is_constexpr(o: Any) -> bool:
    """Return True if *o* is a triton `constexpr` wrapper."""
    return isinstance(o, constexpr)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _is_triton_scalar(o: Any) -> bool:
    """Return True for triton tensors that hold exactly one element."""
    if not _is_triton_tensor(o):
        return False
    # non-block types are scalars by definition; 1-element blocks also count
    return (not o.type.is_block()) or o.type.numel == 1
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _is_list_like(o: Any) -> bool:
|
|
67
|
+
return isinstance(o, (list, tuple))
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _check_fn_args(node, fn, args):
    """Validate a call to a `noinline` function.

    Noinline functions are compiled as real calls, so every argument must be
    either a constexpr or a scalar tensor. Raises
    UnsupportedLanguageConstruct for the first offending argument.
    """
    if not fn.noinline:
        return
    for position, argument in enumerate(args):
        if _is_constexpr(argument) or _is_triton_scalar(argument):
            continue
        raise UnsupportedLanguageConstruct(
            fn.src, node,
            f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[position]}:{argument}'
        )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# Plain Python types that may appear as a truth-value condition inside a kernel.
_condition_types = {bool, int, type(None)}  # Python types accepted for conditionals inside kernels
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class enter_sub_region:
    """Context manager that snapshots the generator's scopes and builder
    insertion point on entry and restores them on exit.

    Used when code generation descends into a nested region (control flow)
    so the nested body can redefine names and move the builder without
    leaking those changes into the parent scope.
    """

    def __init__(self, generator):
        # generator: the CodeGenerator whose state is being bracketed
        self.generator = generator

    def __enter__(self):
        # record lscope & local_defs in the parent scope
        self.liveins = self.generator.lscope.copy()
        self.prev_defs = self.generator.local_defs.copy()
        # the nested region starts with a fresh set of local definitions
        self.generator.local_defs = {}
        self.insert_block = self.generator.builder.get_insertion_block()
        self.insert_point = self.generator.builder.get_insertion_point()
        # callers typically unpack: `liveins, insert_block = ...`
        return self.liveins, self.insert_block

    def __exit__(self, *args, **kwargs):
        # restore the builder position and parent scopes captured on entry
        self.generator.builder.restore_insertion_point(self.insert_point)
        self.generator.lscope = self.liveins
        self.generator.local_defs = self.prev_defs
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# Check if the given syntax node has an "early" return
|
|
104
|
+
class ContainsReturnChecker(ast.NodeVisitor):
    """AST visitor answering: can executing this node reach a `return`?

    `gscope` is the global scope of the function being compiled; it is used
    to resolve names to JIT functions whose (inlined) bodies must be
    checked recursively. Every visit method returns a bool.
    """

    def __init__(self, gscope):
        self.gscope = gscope

    def _visit_stmts(self, body) -> bool:
        # True if any statement in the list can return
        return any(self.visit(s) for s in body)

    def _visit_function(self, fn) -> bool:
        # Currently we only support JITFunctions defined in the global scope
        if isinstance(fn, JITFunction) and not fn.noinline:
            fn_node = fn.parse()
            return ContainsReturnChecker(self.gscope).visit(fn_node)
        return False

    def generic_visit(self, node) -> bool:
        # Default rule: a node can return if any of its AST children can.
        ret = False
        for _, value in ast.iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        ret = ret or self.visit(item)
            elif isinstance(value, ast.AST):
                ret = ret or self.visit(value)
        return ret

    def visit_Attribute(self, node: ast.Attribute) -> bool:
        # If the left part is a name, it's possible that
        # we call triton native function or a jit function from another module.
        # If the left part is not a name, it must return a tensor or a constexpr
        # whose methods do not contain return statements
        # e.g., (tl.load(x)).to(y)
        # So we only check if the expressions within value have return or not
        if isinstance(node.value, ast.Name):
            if node.value.id in self.gscope:
                value = self.gscope[node.value.id]
                fn = getattr(value, node.attr)
                return self._visit_function(fn)
            return False
        return self.visit(node.value)

    def visit_Name(self, node: ast.Name) -> bool:
        if type(node.ctx) is ast.Store:
            return False
        if node.id in self.gscope:
            fn = self.gscope[node.id]
            return self._visit_function(fn)
        return False

    def visit_Return(self, node: ast.Return) -> bool:
        return True

    def visit_Assign(self, node: ast.Assign) -> bool:
        # There couldn't be an early return
        # x = ...
        return False

    def visit_AugAssign(self, node: ast.AugAssign) -> bool:
        # There couldn't be an early return
        # x += ...
        return False

    def visit_Module(self, node: ast.Module) -> bool:
        return self._visit_stmts(node.body)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> bool:
        return self._visit_stmts(node.body)

    def visit_If(self, node: ast.If) -> bool:
        # TODO: optimize the following case in which we actually don't have
        # a return when static_cond is false:
        # if dynamic_cond
        #   if static_cond
        #     func_with_return
        #   else
        #     func_without_return
        ret = self._visit_stmts(node.body)
        if node.orelse:
            ret = ret or self._visit_stmts(node.orelse)
        return ret

    def visit_IfExp(self, node: ast.IfExp) -> bool:
        return self.visit(node.body) or self.visit(node.orelse)

    def visit_Call(self, node: ast.Call) -> bool:
        # A call can return "early" only if the callee itself can
        return self.visit(node.func)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class CodeGenerator(ast.NodeVisitor):
|
|
193
|
+
|
|
194
|
+
def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options,
             codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None,
             noinline=False, file_name: Optional[str] = None, begin_line=0):
    """Set up IR builder, scopes, and bookkeeping for one function's codegen.

    context/prototype describe the IR module and function signature;
    gscope/constants/attributes come from the JIT front end. `module_map`
    remaps modules (and functions from those modules) to alternative
    implementations, e.g. for the interpreter — TODO confirm exact use.
    """
    self.context = context
    self.builder = ir.builder(context)
    self.file_name = file_name
    # node.lineno starts from 1, so we need to subtract 1
    self.begin_line = begin_line - 1
    self.builder.set_loc(file_name, begin_line, 0)
    self.builder.options = options
    # dict of functions provided by the backend. Below are the list of possible functions:
    # Convert custom types not natively supported on HW.
    # convert_custom_types(input_tensor, dtype, fp_downcast_rounding=None, _builder=None)
    self.builder.codegen_fns = codegen_fns
    self.builder.module_map = {} if module_map is None else module_map
    self.module = self.builder.create_module() if module is None else module
    self.function_ret_types = {} if function_types is None else function_types
    self.prototype = prototype

    # Build the effective global scope, applying module_map remappings both
    # to whole modules and to objects whose __module__ is remapped.
    self.gscope = {}
    for k, v in gscope.items():
        if isinstance(v, ModuleType):
            self.gscope[k] = module_map.get(v.__name__, v)
            continue

        module_name = getattr(v, "__module__", "")
        if module_name in module_map:
            self.gscope[k] = getattr(module_map[module_name], v.__name__)
        else:
            self.gscope[k] = v

    self.lscope = {}
    self.attributes = attributes
    self.constants = constants
    self.jit_fn = jit_fn
    self.function_name = function_name
    self.is_kernel = is_kernel
    self.cur_node = None
    self.noinline = noinline
    self.scf_stack = []
    # inferred from `return` statements while visiting the body
    self.ret_type = None
    # SSA-construction
    # name => language.tensor
    self.local_defs: Dict[str, tensor] = {}
    self.dereference_name: Callable[[str], Any] = self._define_name_lookup()
    self.fn = None
    # Are we currently visiting an ast.arg's default value? These have some
    # special handling.
    self.visiting_arg_default_value = False
|
|
243
|
+
|
|
244
|
+
# Python builtins usable inside @jit'ed code. print/min/max are remapped to
# their triton device-side equivalents.
builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)}
builtin_namespace.update((
    ('print', language.core.device_print),
    ('min', language.minimum),
    ('max', language.maximum),
))
|
|
250
|
+
|
|
251
|
+
def _unsupported(self, node, message):
    """Build (not raise) an UnsupportedLanguageConstruct anchored at *node*."""
    return UnsupportedLanguageConstruct(self.jit_fn.src, node, message)
|
|
253
|
+
|
|
254
|
+
def _is_constexpr_global(self, name):
    """Return True if global *name* is a constexpr value, or is annotated
    as `constexpr` in the module's __annotations__."""
    missing = object()
    value = self.gscope.get(name, missing)
    if value is missing:
        return False

    if _is_constexpr(value):
        return True

    # `x: tl.constexpr = ...` at module level leaves a marker in
    # __annotations__ even though the bound value is a plain Python object
    annotation = self.gscope.get("__annotations__", {}).get(name)
    if annotation:
        return _normalize_ty(annotation) == "constexpr"

    return False
|
|
267
|
+
|
|
268
|
+
def _define_name_lookup(self):
    """Return a closure resolving a name to its value inside the kernel.

    Lookup order: local scope, then (restricted) global scope, then the
    builtin namespace. Raises NameError if the name is undefined or is a
    disallowed (non-constexpr) global.
    """

    def local_lookup(name: str, absent):
        # this needs to be re-fetched from `self` every time, because it gets switched occasionally
        return self.lscope.get(name, absent)

    def global_lookup(name: str, absent):
        val = self.gscope.get(name, absent)
        # The high-level rule is that only constexpr globals are allowed.
        # But actually a bunch of other things, such as module imports, are
        # technically Python globals. We have to allow these too!
        if any([
                val is absent, name in self.builtin_namespace,  #
                type(val) is ModuleType,  #
                isinstance(val, JITFunction),  #
                getattr(val, "__triton_builtin__", False),  #
                getattr(val, "__module__", "").startswith("triton.language"),  #
                isinstance(val, language.dtype),  #
                self._is_constexpr_global(name),  #
                # Allow accesses to globals while visiting an ast.arg
                # because you should be able to do
                #   @triton.jit def fn(x: tl.constexpr = GLOBAL): ...
                self.visiting_arg_default_value,  #
                os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1"
        ]):
            return val
        raise NameError(
            textwrap.dedent(f"""\
            Cannot access global variable {name} from within @jit'ed
            function. Triton kernels can only access global variables that
            are annotated as constexpr (`x: triton.language.constexpr = 42`
            or `x = triton.language.constexpr(42)`). Alternatively, set the
            envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not
            promise to support this forever.""").replace("\n", " "))

    absent_marker = object()

    def name_lookup(name: str) -> Any:
        absent = absent_marker
        for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get:
            value = lookup_function(name, absent)
            if value is not absent:
                return value
        raise NameError(f'{name} is not defined')

    return name_lookup
|
|
314
|
+
|
|
315
|
+
def set_value(self, name: str, value: Union[tensor, constexpr]) -> None:
    """Bind *name* to *value* in the current scope.

    Called by visit_Assign() / visit_FunctionDef() for every stored lvalue:
    the binding is recorded both in `lscope` (what the name currently
    resolves to) and in `local_defs` (names defined in the current region;
    FIXME: should consider control flow).
    """
    self.local_defs[name] = value
    self.lscope[name] = value
|
|
323
|
+
|
|
324
|
+
def _get_insertion_point_and_loc(self):
    # XXX: the insertion point's own location can be invalid sometimes, so
    # the current source location is captured explicitly alongside it.
    current_loc = self.builder.get_loc()
    current_ip = self.builder.get_insertion_point()
    return current_ip, current_loc
|
|
331
|
+
|
|
332
|
+
def _set_insertion_point_and_loc(self, ip, loc):
    # Restore both the builder position and the explicit source location
    # (the insertion point alone may carry an invalid location).
    self.builder.restore_insertion_point(ip)
    self.builder.set_loc(loc)
|
|
335
|
+
|
|
336
|
+
#
|
|
337
|
+
# AST visitor
|
|
338
|
+
#
|
|
339
|
+
def visit_compound_statement(self, stmts):
    """Visit a statement or list of statements, stopping after a `return`."""
    # Ensure we always iterate over a list
    statements = stmts if _is_list_like(stmts) else [stmts]
    for statement in statements:
        self.visit(statement)
        # Everything after a `return` is dead code; stop parsing there.
        if isinstance(statement, ast.Return):
            break
|
|
350
|
+
|
|
351
|
+
def visit_Module(self, node):
    # A module is just a container: visit its children.
    ast.NodeVisitor.generic_visit(self, node)
|
|
353
|
+
|
|
354
|
+
def visit_List(self, node):
    """Evaluate a list display; the elements stay a Python list at codegen time."""
    list_ctx = self.visit(node.ctx)
    assert list_ctx is None
    return [self.visit(element) for element in node.elts]
|
|
359
|
+
|
|
360
|
+
# By design, only non-kernel functions can return
def visit_Return(self, node):
    """Lower a `return`: emit the ret op, record/validate the return type,
    and move the builder to a fresh dead block for any trailing ops."""
    ret_value = self.visit(node.value)
    if ret_value is None:
        # bare `return` (or `return None`) -> void
        self.builder.ret([])
        ret_ty = language.void
    elif isinstance(ret_value, tuple):
        # multiple return values: convert each to a tensor and return all
        ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value]
        ret_types = [v.type for v in ret_values]
        self.builder.ret([v.handle for v in ret_values])
        ret_ty = tuple(ret_types)
    else:
        ret = language.semantic.to_tensor(ret_value, self.builder)
        self.builder.ret([ret.handle])
        ret_ty = ret.type

    # all return statements within one function must agree on the type
    if self.ret_type is None:
        self.ret_type = ret_ty
    elif self.ret_type != ret_ty:
        raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}')

    # A return op must always terminate the basic block, so we create a dead
    # basic block in case there are any ops after the return.
    post_ret_block = self.builder.create_block()
    self.builder.set_insertion_point_to_end(post_ret_block)
|
|
385
|
+
|
|
386
|
+
def visit_FunctionDef(self, node):
    """Generate the IR function for *node*.

    Materializes argument default values, creates the IR function with the
    right visibility, binds constexpr/tensor arguments into the local
    scope, visits the body, then finalizes the return type and terminator.
    """
    arg_names, kwarg_names = self.visit(node.args)
    if self.fn:
        raise self._unsupported(node, "nested function definition is not supported.")
    # initialize defaults
    for i, default_value in enumerate(node.args.defaults[::-1]):
        # defaults align with the *last* len(defaults) positional args
        arg_node = node.args.args[-i - 1]
        annotation = arg_node.annotation
        name = arg_node.arg
        st_target = ast.Name(id=name, ctx=ast.Store())
        if annotation is None:
            init_node = ast.Assign(targets=[st_target], value=default_value)
        else:
            init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)

        try:
            # defaults may reference globals; see global_lookup
            assert not self.visiting_arg_default_value
            self.visiting_arg_default_value = True
            self.visit(init_node)
        finally:
            self.visiting_arg_default_value = False

    # initialize function
    visibility = "public" if self.is_kernel else "private"
    self.fn = self.builder.get_or_insert_function(self.module, self.function_name,
                                                  self.prototype.to_ir(self.builder), visibility, self.noinline)
    self.module.push_back(self.fn)
    entry = self.fn.add_entry_block()
    arg_values = []
    # idx counts only non-constexpr (materialized) IR arguments
    idx = 0
    for i in range(len(arg_names)):
        if i in self.constants:
            # constexpr arguments never become IR function arguments
            cst = self.constants[i]
            if not _is_constexpr(cst):
                cst = constexpr(self.constants[i])
            arg_values.append(cst)
            continue
        else:
            if i in self.attributes:
                for name, value in self.attributes[i]:
                    self.fn.set_arg_attr(idx, name, value)

            # Mark this argument as a pass-by-value TMA descriptor (nvidia)
            if isinstance(self.prototype.param_types[idx], nv_tma_desc_type):
                self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1)

            arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx]))
            idx += 1

    insert_pt = self.builder.get_insertion_block()
    for arg_name, arg_value in zip(arg_names, arg_values):
        self.set_value(arg_name, arg_value)
    self.builder.set_insertion_point_to_start(entry)
    # visit function body
    self.visit_compound_statement(node.body)

    # finalize function
    assert not self.builder.get_insertion_block().has_terminator()
    if self.ret_type is None or self.ret_type == language.void:
        self.ret_type = language.void
        self.builder.ret([])
    else:
        self.prototype.ret_types = list(self.ret_type) if isinstance(self.ret_type, tuple) else [self.ret_type]
        self.fn.reset_type(self.prototype.to_ir(self.builder))
        # this block is dead code after the real returns; emit poison values
        self.builder.ret([
            self.builder.create_poison(ty.to_ir(self.builder))
            for ty in self.prototype.ret_types
            if self.ret_type is not None
        ])
    self.fn.finalize()

    # restore the caller's insertion point (when compiling a called function)
    if insert_pt:
        self.builder.set_insertion_point_to_end(insert_pt)
|
|
459
|
+
|
|
460
|
+
def visit_arguments(self, node):
    """Collect positional argument names and the **kwargs name (if any)."""
    positional_names = [self.visit(arg) for arg in node.args]
    keyword_name = self.visit(node.kwarg)
    return positional_names, keyword_name
|
|
466
|
+
|
|
467
|
+
def visit_arg(self, node):
    """Visit an ast.arg's children for side effects, then return its name."""
    super().generic_visit(node)
    return node.arg
|
|
470
|
+
|
|
471
|
+
def visit_AnnAssign(self, node):
    """Handle `x: T = v`. A `constexpr` annotation creates an immutable
    compile-time binding; anything else behaves like a plain assignment."""
    # extract attributes
    annotation = self.visit(node.annotation)
    target = self.visit(node.target)
    value = self.visit(node.value)
    # constexpr
    if annotation == constexpr:
        if target in self.lscope:
            raise ValueError(f'{target} is already defined.'
                             f' constexpr cannot be reassigned.')
        if not _is_constexpr(value):
            value = constexpr(value)
        # bypass set_value: constexprs are not SSA-tracked local_defs
        self.lscope[target] = value
        return self.lscope[target]
    # default: call visit_Assign
    return self.visit_Assign(node)
|
|
487
|
+
|
|
488
|
+
def visit_Assign(self, node):
    """Handle `x = v` (and the non-constexpr path of `x: T = v`).

    Supports tuple unpacking (`a, b = ...`) but not chained multiple
    assignment (`a = b = ...`). Non-triton Python values are promoted to
    tensors unless they are None or a native non-tensor type.
    """
    _names = []
    if isinstance(node, ast.AnnAssign):
        # delegated here by visit_AnnAssign: a single annotated target
        _names += [self.visit(node.target)]
    else:
        for target in node.targets:
            _names += [self.visit(target)]
    if len(_names) > 1:
        raise self._unsupported(node, "simultaneous multiple assignment is not supported.")
    names = _names[0]
    values = self.visit(node.value)
    if not _is_list_like(names):
        names = [names]
    if not _is_list_like(values):
        values = [values]
    # types that stay as plain Python values rather than becoming tensors
    native_nontensor_types = (language.dtype, )
    for name, value in zip(names, values):
        # by default, constexpr are assigned into python variable
        value = _unwrap_if_constexpr(value)
        if value is not None and \
                not _is_triton_value(value) and \
                not isinstance(value, native_nontensor_types):
            value = language.semantic.to_tensor(value, self.builder)
        self.set_value(name, value)
|
|
512
|
+
|
|
513
|
+
def visit_AugAssign(self, node):
    """Desugar `x op= v` into `x = x op v` and visit the result."""
    target_name = node.target.id
    load_target = ast.Name(id=target_name, ctx=ast.Load())
    combined = ast.BinOp(load_target, node.op, node.value)
    self.visit(ast.Assign(targets=[node.target], value=combined))
    return self.dereference_name(target_name)
|
|
520
|
+
|
|
521
|
+
def visit_Name(self, node):
    # In store position the caller needs the bare name; in load position
    # resolve it through local/global/builtin scopes.
    return node.id if type(node.ctx) is ast.Store else self.dereference_name(node.id)
|
|
525
|
+
|
|
526
|
+
def visit_Store(self, node):
    # Store context nodes carry no value; just traverse children.
    ast.NodeVisitor.generic_visit(self, node)
|
|
528
|
+
|
|
529
|
+
def visit_Load(self, node):
    # Load context nodes carry no value; just traverse children.
    ast.NodeVisitor.generic_visit(self, node)
|
|
531
|
+
|
|
532
|
+
def visit_Tuple(self, node):
    """Evaluate a tuple display; kept as a Python tuple at codegen time."""
    return tuple(self.visit(element) for element in node.elts)
|
|
535
|
+
|
|
536
|
+
def _apply_binary_method(self, method_name, lhs, rhs):
    """Dispatch a binary dunder (`__add__`, `__lt__`, ...) on the operands.

    A triton tensor on the left wins; a tensor only on the right uses the
    reflected method (`__radd__`, ...). Otherwise plain python dispatch.
    """
    # TODO: raise something meaningful if getattr fails below, esp for reverse method
    if _is_triton_tensor(lhs):
        bound = getattr(lhs, method_name)
        return bound(rhs, _builder=self.builder)
    if _is_triton_tensor(rhs):
        reflected_name = re.sub(r"__(.*)__", r"__r\1__", method_name)
        bound = getattr(rhs, reflected_name)
        return bound(lhs, _builder=self.builder)
    return getattr(lhs, method_name)(rhs)
|
|
544
|
+
|
|
545
|
+
def visit_BinOp(self, node):
    """Lower a binary expression via the matching dunder method.

    Raises (through _unsupported) for operators with no entry in
    _method_name_for_bin_op — e.g. `@` (ast.MatMult).
    """
    lhs = self.visit(node.left)
    rhs = self.visit(node.right)
    method_name = self._method_name_for_bin_op.get(type(node.op))
    if method_name is None:
        # Fix: `node.op` is an AST node *instance*, which has no `__name__`
        # attribute — the original `node.op.__name__` raised AttributeError
        # instead of the intended diagnostic. Use the class name.
        raise self._unsupported(
            node, "AST binary operator '{}' is not (currently) implemented.".format(type(node.op).__name__))
    return self._apply_binary_method(method_name, lhs, rhs)
|
|
553
|
+
|
|
554
|
+
# Maps python AST binary-operator classes to the dunder method implementing
# them; consumed by visit_BinOp. Note: ast.MatMult (`@`) is deliberately
# absent, so it falls into visit_BinOp's "not implemented" path.
_method_name_for_bin_op: Dict[Type[ast.operator], str] = {
    ast.Add: '__add__',
    ast.Sub: '__sub__',
    ast.Mult: '__mul__',
    ast.Div: '__truediv__',
    ast.FloorDiv: '__floordiv__',
    ast.Mod: '__mod__',
    ast.Pow: '__pow__',
    ast.LShift: '__lshift__',
    ast.RShift: '__rshift__',
    ast.BitAnd: '__and__',
    ast.BitOr: '__or__',
    ast.BitXor: '__xor__',
}
|
|
568
|
+
|
|
569
|
+
def visit_then_else_blocks(self, node, liveins, then_block, else_block):
    """Populate the then/else blocks of an `if` and reconcile their scopes.

    Visits `node.body` into `then_block` and (if present) `node.orelse`
    into `else_block`, then computes which names must become block
    results: live-in names redefined by either branch, plus names defined
    by *both* branches. Returns the per-branch definitions, the (possibly
    updated) blocks, and the result names with their triton and IR types.

    Raises AssertionError when a branch redefines a variable with a
    different type than its live-in / sibling-branch type.
    """
    # then block
    self.builder.set_insertion_point_to_start(then_block)
    self.visit_compound_statement(node.body)
    # re-read the block: visiting the body may have moved the insertion block
    then_block = self.builder.get_insertion_block()
    then_defs = self.local_defs.copy()
    # else block
    else_defs = {}
    if node.orelse:
        self.builder.set_insertion_point_to_start(else_block)
        # reset the scope so the else branch does not see then-branch defs
        self.lscope = liveins.copy()
        self.local_defs = {}
        self.visit_compound_statement(node.orelse)
        else_defs = self.local_defs.copy()
        else_block = self.builder.get_insertion_block()

    # update block arguments
    names = []
    ret_types = []
    ir_ret_types = []
    # variables in livein whose value is updated in `if`
    for name in liveins:
        # check type
        for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]:
            if name in defs:
                assert defs[name].type == liveins[name].type, \
                    f'initial value for `{name}` is of type {liveins[name].type}, '\
                    f'but the {block_name} block redefines it as {defs[name].type}'
        if name in then_defs or name in else_defs:
            names.append(name)
            ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type)
            ir_ret_types.append(then_defs[name].handle.get_type() if name in
                                then_defs else else_defs[name].handle.get_type())
        # variable defined in then but not in else
        if name in then_defs and name not in else_defs:
            # the else branch passes the live-in value through unchanged
            else_defs[name] = liveins[name]
        # variable defined in else but not in then
        if name in else_defs and name not in then_defs:
            then_defs[name] = liveins[name]
    # variables that are both in then and else but not in liveins
    # TODO: could probably be cleaned up
    # sorted() keeps the result-name order deterministic across runs
    for name in sorted(then_defs.keys() & else_defs.keys()):
        if name in names:
            continue
        then_ty = then_defs[name].type
        else_ty = else_defs[name].type
        assert then_ty == else_ty, \
            f'Mismatched type for {name} between then block ({then_ty}) '\
            f'and else block ({else_ty})'
        names.append(name)
        ret_types.append(then_ty)
        ir_ret_types.append(then_defs[name].handle.get_type())

    return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types
|
|
623
|
+
|
|
624
|
+
def visit_if_top_level(self, cond, node):
    """Lower an `if` with unstructured control flow (cond_branch + blocks).

    Used (instead of scf.if) when the `if` body may `return`: builds
    explicit then/else blocks, branches both into a merge block whose
    block arguments carry the values updated by either branch, then
    rebinds those names to the merge-block arguments.
    """
    with enter_sub_region(self) as sr:
        liveins, ip_block = sr
        then_block = self.builder.create_block()
        else_block = self.builder.create_block()
        # create branch
        self.builder.set_insertion_point_to_end(ip_block)
        self.builder.create_cond_branch(cond.handle, then_block, else_block)
        # visit then and else blocks
        then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types = \
            self.visit_then_else_blocks(node, liveins, then_block, else_block)
        # create basic-block after conditional
        endif_block = self.builder.create_block()
        # then terminator
        self.builder.set_insertion_point_to_end(then_block)
        assert not then_block.has_terminator(), f"{then_block}"
        self.builder.create_branch(endif_block, [then_defs[n].handle for n in names])
        # else terminator
        self.builder.set_insertion_point_to_end(else_block)
        assert not else_block.has_terminator(), f"{else_block}"
        self.builder.create_branch(endif_block, [else_defs[n].handle for n in names])
        # the merge block receives one argument per branch-updated name
        for ty in ir_ret_types:
            endif_block.add_argument(ty)

    # change block (outside the sub-region so the insertion point sticks)
    self.builder.set_insertion_point_to_start(endif_block)
    # update value
    for i, name in enumerate(names):
        new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i])
        self.set_value(name, new_tensor)
|
|
654
|
+
|
|
655
|
+
# TODO: refactor
|
|
656
|
+
# TODO: refactor
def visit_if_scf(self, cond, node):
    """Lower an `if` to a structured scf.if op.

    Visits both branches into temporary blocks, merges them into the
    scf.if regions, yields the branch-updated values, and rebinds those
    names to the op's results. `cond` must be an i1 tensor.
    """
    with enter_sub_region(self) as sr:
        liveins, _ = sr
        ip, last_loc = self._get_insertion_point_and_loc()
        then_block = self.builder.create_block()
        else_block = self.builder.create_block() if node.orelse else None
        then_defs, else_defs, then_block, else_block, names, ret_types, _ = \
            self.visit_then_else_blocks(node, liveins, then_block, else_block)
        # create if op
        self._set_insertion_point_and_loc(ip, last_loc)
        if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
        then_block.merge_block_before(if_op.get_then_block())
        self.builder.set_insertion_point_to_end(if_op.get_then_block())
        if len(names) > 0:
            self.builder.create_yield_op([then_defs[n].handle for n in names])
        if not node.orelse:
            # no user else-branch: yield straight from the op's else block
            else_block = if_op.get_else_block()
        else:
            else_block.merge_block_before(if_op.get_else_block())
        self.builder.set_insertion_point_to_end(if_op.get_else_block())
        if len(names) > 0:
            self.builder.create_yield_op([else_defs[n].handle for n in names])
    # update values (outside the sub-region so the bindings persist)
    for i, name in enumerate(names):
        new_tensor = language.core.tensor(if_op.get_result(i), ret_types[i])
        self.set_value(name, new_tensor)
|
|
682
|
+
|
|
683
|
+
def visit_If(self, node):
    """Dispatch an `if` statement.

    Tensor condition: lowered to IR — unstructured blocks when the body
    contains `return` (illegal inside scf loops), scf.if otherwise.
    Python condition: evaluated at compile time; only the taken branch
    is visited.
    """
    cond = self.visit(node.test)

    if _is_triton_tensor(cond):
        # runtime condition: normalize to i1 and emit IR
        cond = cond.to(language.int1, _builder=self.builder)
        contains_return = ContainsReturnChecker(self.gscope).visit(node)
        if contains_return:
            if self.scf_stack:
                raise self._unsupported(
                    node, "Cannot have `return` statements inside `while` or `for` statements in triton "
                    "(note that this also applies to `return` statements that are inside functions "
                    "transitively called from within `while`/`for` statements)")
            self.visit_if_top_level(cond, node)
        else:
            self.visit_if_scf(cond, node)
    else:
        # compile-time condition: must be a plain bool/int
        cond = _unwrap_if_constexpr(cond)
        # not isinstance - we insist the real thing, no subclasses and no ducks
        if type(cond) not in _condition_types:
            raise self._unsupported(
                node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
                    ', '.join(_.__name__ for _ in _condition_types),
                    type(cond).__name__))

        active_block = node.body if cond else node.orelse
        self.visit_compound_statement(active_block)
|
|
709
|
+
|
|
710
|
+
def visit_IfExp(self, node):
    """Lower a ternary `a if cond else b`.

    Tensor condition: both arms are evaluated into temporary blocks which
    are merged into an scf.if; both arms must have the same type, and the
    op's single result (or None for void) is returned. Python condition:
    only the selected arm is visited.
    """
    cond = self.visit(node.test)
    if _is_triton_tensor(cond):
        cond = cond.to(language.int1, _builder=self.builder)
        # TODO: Deal w/ more complicated return types (e.g tuple)
        with enter_sub_region(self):
            ip, last_loc = self._get_insertion_point_and_loc()

            then_block = self.builder.create_block()
            self.builder.set_insertion_point_to_start(then_block)
            then_val = language.semantic.to_tensor(self.visit(node.body), self.builder)
            then_block = self.builder.get_insertion_block()

            else_block = self.builder.create_block()
            self.builder.set_insertion_point_to_start(else_block)
            # do not need to reset lscope since
            # ternary expressions cannot define new variables
            else_val = language.semantic.to_tensor(self.visit(node.orelse), self.builder)
            else_block = self.builder.get_insertion_block()

            self._set_insertion_point_and_loc(ip, last_loc)

            assert then_val.type == else_val.type, \
                f'Ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}'
            ret_type = then_val.type

            # void expressions yield nothing from the scf.if
            ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else []
            if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True)
            then_block.merge_block_before(if_op.get_then_block())
            if ret_type_ir:
                self.builder.set_insertion_point_to_end(if_op.get_then_block())
                self.builder.create_yield_op([then_val.handle])

            self.builder.set_insertion_point_to_end(if_op.get_then_block())
            else_block.merge_block_before(if_op.get_else_block())
            if ret_type_ir:
                self.builder.set_insertion_point_to_end(if_op.get_else_block())
                self.builder.create_yield_op([else_val.handle])
            return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None
    else:
        cond = _unwrap_if_constexpr(cond)

        # not isinstance - we insist the real thing, no subclasses and no ducks
        if type(cond) not in _condition_types:
            raise self._unsupported(
                node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
                    ', '.join(_.__name__ for _ in _condition_types),
                    type(cond).__name__))
        if cond:
            return self.visit(node.body)
        else:
            return self.visit(node.orelse)
|
|
762
|
+
|
|
763
|
+
def visit_Pass(self, node):
    """`pass` generates no IR and yields no value."""
    return None
|
|
765
|
+
|
|
766
|
+
def visit_Compare(self, node):
    """Lower a comparison expression.

    Only single comparisons are supported (no `a < b < c`). `is` /
    `is not` are folded at compile time on the unwrapped values; the
    remaining operators dispatch through _method_name_for_comp_op.
    Raises (through _unsupported) for unmapped operators such as
    `in` / `not in`.
    """
    if not (len(node.comparators) == 1 and len(node.ops) == 1):
        raise self._unsupported(node, "simultaneous multiple comparison is not supported")
    lhs = self.visit(node.left)
    rhs = self.visit(node.comparators[0])
    lhs_value = _unwrap_if_constexpr(lhs)
    rhs_value = _unwrap_if_constexpr(rhs)
    if type(node.ops[0]) is ast.Is:
        return constexpr(lhs_value is rhs_value)
    if type(node.ops[0]) is ast.IsNot:
        return constexpr(lhs_value is not rhs_value)
    method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
    if method_name is None:
        # Fix: `node.ops[0]` is an AST node *instance* with no `__name__`;
        # the original raised AttributeError (reachable via `in`/`not in`)
        # instead of the intended diagnostic. Use the class name.
        raise self._unsupported(
            node, "AST comparison operator '{}' is not (currently) implemented.".format(type(node.ops[0]).__name__))
    return self._apply_binary_method(method_name, lhs, rhs)
|
|
782
|
+
|
|
783
|
+
# Maps python AST comparison-operator classes to their dunder methods; used
# by visit_Compare. `is`/`is not` are handled separately, and `in`/`not in`
# are intentionally unmapped.
_method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {
    ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__'
}
|
|
786
|
+
|
|
787
|
+
def visit_UnaryOp(self, node):
    """Lower a unary expression (`-x`, `+x`, `not x`, `~x`) via its dunder.

    Triton tensors receive the builder; other operands dispatch through
    plain python, with a CompilationError if the operand lacks the method.
    """
    operand = self.visit(node.operand)
    fn = self._method_name_for_unary_op.get(type(node.op))
    if fn is None:
        # Fix: `node.op` is an AST node *instance* with no `__name__` —
        # the original f-string raised AttributeError instead of the
        # intended diagnostic. Use the class name. (Currently all four
        # ast.unaryop classes are mapped, so this is a latent path.)
        raise self._unsupported(
            node, f"AST unary operator '{type(node.op).__name__}' is not (currently) implemented.")
    if _is_triton_tensor(operand):
        return getattr(operand, fn)(_builder=self.builder)
    try:
        return getattr(operand, fn)()
    except AttributeError:
        raise self._unsupported(
            node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}")
|
|
799
|
+
|
|
800
|
+
# Maps every python AST unary-operator class to its dunder method; used by
# visit_UnaryOp. Note `not` maps to the non-standard `__not__`.
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {
    ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'
}
|
|
803
|
+
|
|
804
|
+
def _verify_loop_carried_variable(self, name, loop_val, live_val):
    """Validate a loop-carried variable before building an scf loop.

    Asserts that `name` is a triton value both before (live_val) and
    inside (loop_val) the loop — i.e. not a constexpr being reassigned —
    that its python type is unchanged, and that tensor types match.

    Fixes: corrected error-message typos ("constxpr", "reasign") and
    removed a spurious `f` prefix on a placeholder-free string.
    """
    assert _is_triton_value(loop_val), f'cannot reassign constexpr {name} in the loop'
    assert _is_triton_value(live_val), f'cannot reassign constexpr {name} in the loop'
    # exact python-type match, not isinstance: the kind of value must not change
    assert type(loop_val) == type(live_val), f'Loop carried variable {name} changed type'
    assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \
        f'Loop-carried variable {name} has initial type {live_val.type} '\
        f'but is re-assigned to {loop_val.type} in loop! '\
        'Please make sure that the type stays consistent.'
|
|
812
|
+
|
|
813
|
+
def visit_While(self, node):
    """Lower a `while` loop to a structured scf.while op.

    First dry-runs the body in a throwaway block to discover loop-carried
    variables (names defined in the body that were already live-in), then
    builds the op: the "before" region evaluates the condition, the
    "after" region holds the real body and yields the carried values.
    Finally rebinds the carried names to the op's results.
    `while ... else` is not supported.
    """
    with enter_sub_region(self) as sr:
        liveins, insert_block = sr
        ip, last_loc = self._get_insertion_point_and_loc()

        # loop body (the after region)
        # loop_block = self.builder.create_block()
        # dry run: visit the body once in a scratch block purely to
        # collect self.local_defs, then discard the generated IR
        dummy = self.builder.create_block()
        self.builder.set_insertion_point_to_start(dummy)
        self.scf_stack.append(node)
        self.visit_compound_statement(node.body)
        self.scf_stack.pop()
        loop_defs = self.local_defs
        dummy.erase()

        # collect loop-carried values
        names = []
        ret_types = []
        init_args = []
        for name in loop_defs:
            if name in liveins:
                # We should not def new constexpr
                loop_val = loop_defs[name]
                live_val = liveins[name]
                self._verify_loop_carried_variable(name, loop_val, live_val)

                # these are loop-carried values
                names.append(name)
                ret_types.append(loop_val.type)
                init_args.append(live_val)

        self._set_insertion_point_and_loc(ip, last_loc)
        while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
                                                [arg.handle for arg in init_args])
        # merge the condition region
        before_block = self.builder.create_block_with_parent(while_op.get_before(),
                                                             [ty.to_ir(self.builder) for ty in ret_types])
        self.builder.set_insertion_point_to_start(before_block)
        # bind carried names to the condition-region block arguments
        for i, name in enumerate(names):
            self.lscope[name] = language.core.tensor(before_block.arg(i), ret_types[i])
            self.local_defs[name] = self.lscope[name]
        cond = self.visit(node.test)
        self.builder.set_insertion_point_to_end(before_block)
        # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
        self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
        # merge the loop body
        after_block = self.builder.create_block_with_parent(while_op.get_after(),
                                                            [ty.to_ir(self.builder) for ty in ret_types])

        # generate loop body
        self.builder.set_insertion_point_to_start(after_block)
        for i, name in enumerate(names):
            self.lscope[name] = language.core.tensor(after_block.arg(i), ret_types[i])
            self.local_defs[name] = self.lscope[name]
        self.scf_stack.append(node)
        self.visit_compound_statement(node.body)
        self.scf_stack.pop()
        loop_defs = self.local_defs
        yields = []
        for name in loop_defs:
            if name in liveins:
                yields.append(loop_defs[name])
        self.builder.create_yield_op([y.handle for y in yields])

    # WhileOp defines new values, update the symbol table (lscope, local_defs)
    for i, name in enumerate(names):
        new_def = language.core.tensor(while_op.get_result(i), ret_types[i])
        self.lscope[name] = new_def
        self.local_defs[name] = new_def

    for stmt in node.orelse:
        assert False, "Not implemented"
        ast.NodeVisitor.generic_visit(self, stmt)
|
|
886
|
+
|
|
887
|
+
def visit_Subscript(self, node):
    """Evaluate `value[slice]`; tensor bases go through the builder."""
    assert node.ctx.__class__.__name__ == "Load"
    base = self.visit(node.value)
    index = self.visit(node.slice)
    if _is_triton_tensor(base):
        return base.__getitem__(index, _builder=self.builder)
    return base[index]
|
|
894
|
+
|
|
895
|
+
def visit_ExtSlice(self, node):
    """Visit each dimension of an extended slice (legacy pre-3.9 AST node)."""
    return list(map(self.visit, node.dims))
|
|
897
|
+
|
|
898
|
+
def visit_For(self, node):
    """Lower a `for` loop.

    `tl.static_range` is fully unrolled at compile time. `tl.range` and
    builtin `range` become an scf.for: a dry run of the body discovers
    loop-carried variables, then the real body is generated inside the op
    and the carried names are rebound to the op's results. Negative
    constant steps are rewritten (scf.for requires a positive step) and
    the induction variable is recovered as `ub - iv + lb`.
    `for ... else` is not supported.
    """
    IteratorClass = self.visit(node.iter.func)
    iter_args = [self.visit(arg) for arg in node.iter.args]
    iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords)
    if IteratorClass == language.static_range:
        # compile-time unroll: bind the index as a constexpr per iteration
        iterator = IteratorClass(*iter_args, **iter_kwargs)
        static_range = range(iterator.start.value, iterator.end.value, iterator.step.value)
        for i in static_range:
            self.lscope[node.target.id] = constexpr(i)
            self.visit_compound_statement(node.body)
            for stmt in node.orelse:
                ast.NodeVisitor.generic_visit(self, stmt)
        return
    num_stages = None
    loop_unroll_factor = None
    if IteratorClass is language.range:
        iterator = IteratorClass(*iter_args, **iter_kwargs)
        # visit iterator arguments
        # note: only `range` iterator is supported now
        # collect lower bound (lb), upper bound (ub), and step
        lb = iterator.start
        ub = iterator.end
        step = iterator.step
        num_stages = iterator.num_stages
        loop_unroll_factor = iterator.loop_unroll_factor
    elif IteratorClass is range:
        # visit iterator arguments
        # note: only `range` iterator is supported now
        # collect lower bound (lb), upper bound (ub), and step
        lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0))
        ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0])
        step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1))
    else:
        raise RuntimeError('Only `range` and `static_range` iterators are currently supported')
    # handle negative constant step (not supported by scf.for in MLIR)
    negative_step = False
    if _is_constexpr(step) and step.value < 0:
        step = constexpr(-step.value)
        negative_step = True
        lb, ub = ub, lb
    lb = language.semantic.to_tensor(lb, self.builder)
    ub = language.semantic.to_tensor(ub, self.builder)
    step = language.semantic.to_tensor(step, self.builder)
    # induction variable type
    if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int():
        raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})")
    iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype)
    iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype)
    iv_ir_type = iv_type.to_ir(self.builder)
    iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED
    # lb/ub/step might be constexpr, we need to cast them to tensor
    lb = lb.handle
    ub = ub.handle
    step = step.handle
    # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index
    lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed)
    ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed)
    step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed)
    # Create placeholder for the loop induction variable
    iv = self.builder.create_poison(iv_ir_type)
    self.set_value(node.target.id, language.core.tensor(iv, iv_type))

    with enter_sub_region(self) as sr:
        liveins, insert_block = sr
        ip, last_loc = self._get_insertion_point_and_loc()

        # create loop body block
        block = self.builder.create_block()
        self.builder.set_insertion_point_to_start(block)
        # dry visit loop body
        self.scf_stack.append(node)
        self.visit_compound_statement(node.body)
        self.scf_stack.pop()
        block.erase()

        # If a variable (name) is defined in both its parent & itself, then it's
        # a loop-carried variable. (They must be of the same type)
        init_args = []
        yields = []
        names = []
        for name in self.local_defs:
            if name in liveins:
                loop_val = self.local_defs[name]
                live_val = liveins[name]
                self._verify_loop_carried_variable(name, loop_val, live_val)

                names.append(name)
                init_args.append(live_val)
                yields.append(loop_val)

        # create ForOp
        self._set_insertion_point_and_loc(ip, last_loc)
        for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
        if num_stages is not None:
            for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages))
        if loop_unroll_factor is not None:
            for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor))

        self.scf_stack.append(node)
        self.builder.set_insertion_point_to_start(for_op.get_body(0))
        # reset local scope to not pick up local defs from the previous dry run.
        self.lscope = liveins.copy()
        self.local_defs = {}
        # body-block arg 0 is the induction variable; carried values follow
        for i, name in enumerate(names):
            self.set_value(name, language.core.tensor(for_op.get_body(0).arg(i + 1), yields[i].type))
        self.visit_compound_statement(node.body)
        self.scf_stack.pop()
        yields = []
        for name in self.local_defs:
            if name in liveins:
                yields.append(language.semantic.to_tensor(self.local_defs[name], self.builder))

        # create YieldOp
        if len(yields) > 0:
            self.builder.create_yield_op([y.handle for y in yields])
        for_op_region = for_op.get_body(0).get_parent()
        assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"

        # update induction variable with actual value, and replace all uses
        self.builder.set_insertion_point_to_start(for_op.get_body(0))
        iv = for_op.get_induction_var()
        if negative_step:
            # recover the original descending index: ub - iv + lb
            iv = self.builder.create_sub(ub, iv)
            iv = self.builder.create_add(iv, lb)
        self.lscope[node.target.id].handle.replace_all_uses_with(iv)
        self.set_value(node.target.id, language.core.tensor(iv, iv_type))

    # update lscope & local_defs (ForOp defines new values)
    for i, name in enumerate(names):
        self.set_value(name, language.core.tensor(for_op.get_result(i), yields[i].type))

    for stmt in node.orelse:
        assert False, "Don't know what to do with else after for"
        ast.NodeVisitor.generic_visit(self, stmt)
|
|
1032
|
+
|
|
1033
|
+
def visit_Slice(self, node):
    """Build a python `slice` from the visited lower/upper/step expressions."""
    bounds = (self.visit(node.lower), self.visit(node.upper), self.visit(node.step))
    return slice(*bounds)
|
|
1038
|
+
|
|
1039
|
+
def visit_Index(self, node):
    # Unwrap a legacy ast.Index wrapper (pre-3.9 subscript AST).
    return self.visit(node.value)
|
|
1041
|
+
|
|
1042
|
+
def visit_keyword(self, node) -> Tuple[str, Any]:
    # A keyword argument becomes a (name, evaluated value) pair.
    return node.arg, self.visit(node.value)
|
|
1044
|
+
|
|
1045
|
+
def visit_Assert(self, node) -> Any:
    """Lower a python `assert` to a device-side assertion."""
    condition = self.visit(node.test)
    message = "" if node.msg is None else self.visit(node.msg)
    return language.core.device_assert(condition, message, _builder=self.builder)
|
|
1049
|
+
|
|
1050
|
+
def call_JitFunction(self, fn: JITFunction, args, kwargs):
    """Emit a call to another @jit function, compiling it on first use.

    Constexpr arguments are baked into the mangled function name rather
    than passed as IR operands. Returns None for void callees, a single
    tensor for one result, or a tuple of tensors otherwise.
    """
    args = inspect.getcallargs(fn.fn, *args, **kwargs)
    args = [args[name] for name in fn.arg_names]
    args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args]
    # generate function def
    attributes = {}
    constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
    constants = {i: args[i] for i in constexprs}
    # generate call
    # constexpr positions are dropped from the runtime argument list
    args = [None if i in constexprs else arg for i, arg in enumerate(args)]
    arg_vals = [arg.handle for arg in args if arg is not None]
    arg_types = [arg.type for arg in args if arg is not None]
    fn_name = mangle_fn(fn.__name__, arg_types, constants)
    # generate function def if necessary
    if not self.module.has_function(fn_name):
        prototype = language.function_type([], arg_types)
        gscope = fn.__globals__
        # If the callee is not set, we use the same debug setting as the caller
        file_name, begin_line = get_jit_fn_file_line(fn)
        generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
                                  jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types,
                                  noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
                                  options=self.builder.options, codegen_fns=self.builder.codegen_fns,
                                  module_map=self.builder.module_map)
        try:
            generator.visit(fn.parse())
        except Exception as e:
            # Wrap the error in the callee with the location of the call.
            raise CompilationError(self.jit_fn.src, self.cur_node, None) from e

        callee_ret_type = generator.ret_type
        self.function_ret_types[fn_name] = callee_ret_type
    else:
        # already compiled: reuse the cached return type
        callee_ret_type = self.function_ret_types[fn_name]
    symbol = self.module.get_function(fn_name)
    call_op = self.builder.call(symbol, arg_vals)
    if call_op.get_num_results() == 0 or callee_ret_type is None:
        return None
    elif call_op.get_num_results() == 1:
        return tensor(call_op.get_result(0), callee_ret_type)
    else:
        # should return a tuple of tl.tensor
        results = []
        for i in range(call_op.get_num_results()):
            results.append(tensor(call_op.get_result(i), callee_ret_type[i]))
        return tuple(results)
|
|
1096
|
+
|
|
1097
|
+
def visit_Call(self, node):
    """Dispatch a call expression.

    Order of resolution: statically-implemented functions, @jit functions
    (compiled and called via call_JitFunction), triton builtins / tensor
    methods (which receive the builder), then plain python callables
    (with constexpr arguments unwrapped for known builtins).
    """
    fn = _unwrap_if_constexpr(self.visit(node.func))
    static_implementation = self.statically_implemented_functions.get(fn)
    if static_implementation is not None:
        return static_implementation(self, node)

    kws = dict(self.visit(keyword) for keyword in node.keywords)
    args = [self.visit(arg) for arg in node.args]
    if isinstance(fn, JITFunction):
        _check_fn_args(node, fn, args)
        return self.call_JitFunction(fn, args, kws)
    if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn):
        # builtins and tensor methods need access to the IR builder
        extra_kwargs = {"_builder": self.builder}
        sig = inspect.signature(fn)
        if '_generator' in sig.parameters:
            extra_kwargs['_generator'] = self
        try:
            return fn(*args, **extra_kwargs, **kws)
        except Exception as e:
            # Normally when we raise a CompilationError, we raise it as
            # `from None`, because the original fileline from the exception
            # is not relevant (and often points into code_generator.py
            # itself). But when calling a function, we raise as `from e` to
            # preserve the traceback of the original error, which may e.g.
            # be in core.py.
            raise CompilationError(self.jit_fn.src, node, None) from e

    if fn in self.builtin_namespace.values():
        args = map(_unwrap_if_constexpr, args)
    return fn(*args, **kws)
|
|
1127
|
+
|
|
1128
|
+
def visit_Constant(self, node):
    # Literals become compile-time constexpr values.
    return constexpr(node.value)
|
|
1130
|
+
|
|
1131
|
+
def visit_BoolOp(self, node: ast.BoolOp):
    """Lower `and` / `or` via `logical_and` / `logical_or`.

    Only two operands are supported; chains must be parenthesized.
    Note this evaluates both operands — no python short-circuiting.
    """
    if len(node.values) != 2:
        raise self._unsupported(
            node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
    lhs = self.visit(node.values[0])
    rhs = self.visit(node.values[1])
    method_name = self._method_name_for_bool_op.get(type(node.op))
    if method_name is None:
        # Fix: `node.op` is an AST node *instance* with no `__name__`; the
        # original raised AttributeError instead of the intended message.
        # (Latent: ast.And and ast.Or are both currently mapped.)
        raise self._unsupported(
            node, "AST boolean operator '{}' is not (currently) implemented.".format(type(node.op).__name__))
    return self._apply_binary_method(method_name, lhs, rhs)
|
|
1142
|
+
|
|
1143
|
+
# Maps `and`/`or` AST classes to tensor logical methods; used by visit_BoolOp.
_method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
|
|
1144
|
+
|
|
1145
|
+
# Pre-3.8 compatibility: older pythons emit NameConstant/Num/Str nodes
# instead of ast.Constant, so provide equivalent visitors for them.
if sys.version_info < (3, 8):

    def visit_NameConstant(self, node):
        # True/False/None literal -> constexpr
        return constexpr(node.value)

    def visit_Num(self, node):
        # numeric literal -> constexpr
        return constexpr(node.n)

    def visit_Str(self, node):
        # string literal -> constexpr (literal_eval yields the str value)
        return constexpr(ast.literal_eval(node))
|
|
1155
|
+
|
|
1156
|
+
def visit_Attribute(self, node):
    """Resolve attribute access; `.T` on a tensor becomes a 2-D transpose."""
    base = self.visit(node.value)
    if node.attr == "T" and _is_triton_tensor(base):
        return language.semantic.permute(base, (1, 0), builder=self.builder)
    return getattr(base, node.attr)
|
|
1161
|
+
|
|
1162
|
+
def visit_Expr(self, node):
    # Expression statement: evaluate children for side effects, drop value.
    ast.NodeVisitor.generic_visit(self, node)
|
|
1164
|
+
|
|
1165
|
+
def visit_NoneType(self, node):
    """A literal `None` in the source evaluates to python None."""
    return None
|
|
1167
|
+
|
|
1168
|
+
def visit_JoinedStr(self, node):
    """Evaluate an f-string at compile time.

    Every interpolated value must fold to a constexpr; the pieces are
    rendered with Python's own format machinery (honouring !s/!r/!a
    conversions) and concatenated into a plain str.
    """
    rendered = []
    for part in node.values:
        if isinstance(part, ast.Constant):
            # Literal text between the braces.
            rendered.append(str(part.value))
        elif isinstance(part, ast.FormattedValue):
            conv = part.conversion
            evaluated = self.visit(part.value)
            if not _is_constexpr(evaluated):
                raise self._unsupported(
                    node,
                    "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type "
                    + str(type(evaluated)))
            # conversion is -1 when absent, else the ord of 's'/'r'/'a'.
            template = "{}" if conv < 0 else "{!" + chr(conv) + "}"
            rendered.append(template.format(evaluated.value))
        else:
            raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(part)))
    return ''.join(rendered)
|
|
1185
|
+
|
|
1186
|
+
def visit(self, node):
    """Dispatch to the node-specific visitor with location tracking.

    Wraps ast.NodeVisitor.visit to (1) tolerate None slots, (2) keep
    self.cur_node and the builder's source location in sync with the node
    being lowered, and (3) wrap any non-CompilationError exception in a
    CompilationError carrying the @jit function's source.
    """
    if node is None:
        return
    with warnings.catch_warnings():
        # The ast library added visit_Constant and deprecated some other
        # methods but we can't move to that without breaking Python 3.6 and 3.7.
        warnings.simplefilter("ignore", DeprecationWarning)  # python 3.9
        warnings.simplefilter("ignore", PendingDeprecationWarning)  # python 3.8
        # Save the current node/location so they can be restored after the
        # child visit completes.
        last_node = self.cur_node
        last_loc = self.builder.get_loc()
        self.cur_node = node
        if hasattr(node, 'lineno') and hasattr(node, 'col_offset'):
            # Point the builder at this node's position in the user's file;
            # begin_line offsets the AST's function-relative line numbers.
            self.builder.set_loc(self.file_name, self.begin_line + node.lineno, node.col_offset)
            last_loc = self.builder.get_loc()
        try:
            ret = super().visit(node)
        except CompilationError:
            # Already wrapped (possibly by a nested visit) — propagate as-is.
            raise
        except Exception as e:
            # Wrap the error in a CompilationError which contains the source
            # of the @jit function.
            raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None

        # Reset the location to the last one before the visit
        if last_loc:
            self.cur_node = last_node
            self.builder.set_loc(last_loc)
        return ret
|
|
1214
|
+
|
|
1215
|
+
def generic_visit(self, node):
    """Fallback for AST node types without a dedicated visit_* handler: always an error."""
    type_name = type(node).__name__
    raise self._unsupported(node, "unsupported AST node type: {}".format(type_name))
|
|
1217
|
+
|
|
1218
|
+
def execute_static_assert(self, node: ast.Call) -> None:
    """Evaluate ``tl.static_assert`` eagerly at compile time.

    The condition (and the optional message) must fold to constexpr
    values; a failing assertion aborts compilation by raising
    CompileTimeAssertionFailure.
    """
    arg_count = len(node.args)
    if node.keywords or arg_count not in (1, 2):
        raise TypeError("`static_assert` requires one or two positional arguments only")

    passed = _unwrap_if_constexpr(self.visit(node.args[0]))
    if not isinstance(passed, bool):
        raise NotImplementedError(
            "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values"
        )
    if passed:
        return None

    # Assertion failed: render the optional message, best-effort.
    if arg_count == 1:
        message = ""
    else:
        try:
            message = self.visit(node.args[1])
        except Exception as e:
            message = "<failed to evaluate assertion message: " + repr(e) + ">"

    raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message))
|
|
1239
|
+
|
|
1240
|
+
def static_executor(python_fn):
    """Build a visit handler that folds a call to ``python_fn`` at compile time.

    The returned handler constexpr-unwraps every positional and keyword
    argument of the Call node, invokes the plain Python function, and wraps
    its result back into a constexpr.
    """

    def ret(self, node: ast.Call):
        kwargs = {}
        for keyword in node.keywords:
            name, value = self.visit(keyword)
            kwargs[name] = _unwrap_if_constexpr(value)
        positional = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args]
        return constexpr(python_fn(*positional, **kwargs))

    return ret
|
|
1251
|
+
|
|
1252
|
+
# Calls that the code generator evaluates eagerly instead of lowering to IR:
# the key is the callable as referenced inside the kernel, the value is the
# visitor method that performs the compile-time evaluation.
statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {
    language.core.static_assert: execute_static_assert,
    language.core.static_print: static_executor(print),
    int: static_executor(int),
    len: static_executor(len),
}
|
|
1258
|
+
|
|
1259
|
+
|
|
1260
|
+
def kernel_suffix(signature, specialization):
    """Build the specialization suffix appended to a kernel's name.

    For every argument position ``i`` the suffix contains ``i`` followed by
    ``'c'`` if the argument is specialized to the constant 1 and ``'d'`` if
    it is known to be divisible by 16 (both may apply).
    """
    parts = []
    for idx in range(len(signature)):
        token = str(idx)
        if idx in specialization.equal_to_1:
            token += 'c'
        if idx in specialization.divisibility_16:
            token += 'd'
        parts.append(token)
    return ''.join(parts)
|
|
1271
|
+
|
|
1272
|
+
|
|
1273
|
+
def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map):
    """Lower a @jit function to a Triton-IR module via CodeGenerator.

    Builds the kernel prototype from the specialization's signature and
    constants, runs the AST visitor over fn's parsed source, and returns the
    resulting MLIR module (which takes ownership of `context`).
    """
    attrs = specialization.attrs
    # create kernel prototype
    # Constants may be keyed by argument name or index; normalize to index.
    cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i
    constants = {cst_key(key): value for key, value in specialization.constants.items()}
    # visit kernel AST
    gscope = fn.__globals__.copy()
    function_name = fn.repr(specialization)
    tys = list(specialization.signature.values())
    new_constants = attrs.get_constants()
    for k in new_constants:
        # Normalize i1-typed constants equal to 1 into Python True.
        # NOTE(review): `k in tys` tests membership of the key among the type
        # *strings* in the list, not an index-bounds check; it looks like it
        # may have been meant as `k < len(tys)` — confirm against upstream.
        if k in tys and tys[k] == "i1" and new_constants[k] == 1:
            new_constants[k] = True

    new_attrs = attrs.filter_out_constants()
    fn_attrs = new_attrs.get_fn_attrs()
    all_constants = constants.copy()
    all_constants.update(new_constants)
    # Constant-specialized arguments are dropped from the IR signature.
    arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants]
    file_name, begin_line = get_jit_fn_file_line(fn)

    prototype = language.function_type([], arg_types)
    generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
                              jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name,
                              begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map)
    generator.visit(fn.parse())

    ret = generator.module
    # module takes ownership of the context
    ret.context = context
    return ret
|