triton-windows 3.3.1.post19__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
|
@@ -1,20 +1,22 @@
|
|
|
1
1
|
import ast
|
|
2
|
+
import builtins
|
|
3
|
+
import contextlib
|
|
4
|
+
import copy
|
|
2
5
|
import inspect
|
|
3
6
|
import re
|
|
4
7
|
import warnings
|
|
5
|
-
import os
|
|
6
8
|
import textwrap
|
|
7
9
|
import itertools
|
|
10
|
+
from dataclasses import dataclass
|
|
8
11
|
from types import ModuleType
|
|
9
12
|
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, Iterable, List
|
|
10
13
|
|
|
11
|
-
from .. import language
|
|
12
|
-
from .._C.libtriton import ir
|
|
13
|
-
from ..language import constexpr,
|
|
14
|
-
from ..language.core import _unwrap_if_constexpr,
|
|
15
|
-
from ..runtime.jit import get_jit_fn_file_line
|
|
14
|
+
from .. import knobs, language
|
|
15
|
+
from .._C.libtriton import ir, gluon_ir
|
|
16
|
+
from ..language import constexpr, str_to_ty, tensor, tuple as tl_tuple
|
|
17
|
+
from ..language.core import _unwrap_if_constexpr, base_value, base_type
|
|
16
18
|
# ideally we wouldn't need any runtime component
|
|
17
|
-
from ..runtime import JITFunction
|
|
19
|
+
from ..runtime.jit import get_jit_fn_file_line, get_full_name, JITCallable, BoundConstexprFunction, ConstexprFunction, JITFunction
|
|
18
20
|
from .._utils import find_paths_if, get_iterable_path, set_iterable_path
|
|
19
21
|
|
|
20
22
|
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
|
|
@@ -27,35 +29,17 @@ def check_identifier_legality(name, type):
|
|
|
27
29
|
return name
|
|
28
30
|
|
|
29
31
|
|
|
30
|
-
def
|
|
31
|
-
if ty.is_tuple():
|
|
32
|
-
return 'T' + '_'.join(map(mangle_ty, ty.types)) + 'T'
|
|
33
|
-
if ty.is_ptr():
|
|
34
|
-
return 'P' + mangle_ty(ty.element_ty)
|
|
35
|
-
if ty.is_int():
|
|
36
|
-
SIGNED = language.dtype.SIGNEDNESS.SIGNED
|
|
37
|
-
prefix = 'i' if ty.int_signedness == SIGNED else 'u'
|
|
38
|
-
return prefix + str(ty.int_bitwidth)
|
|
39
|
-
if ty.is_floating():
|
|
40
|
-
return str(ty)
|
|
41
|
-
if ty.is_block():
|
|
42
|
-
elt = mangle_ty(ty.scalar)
|
|
43
|
-
shape = '_'.join(map(str, ty.shape))
|
|
44
|
-
return f'{elt}S{shape}S'
|
|
45
|
-
if ty.is_void():
|
|
46
|
-
return 'V'
|
|
47
|
-
raise TypeError(f'Unsupported type {ty}')
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def mangle_fn(name, arg_tys, constants):
|
|
32
|
+
def mangle_fn(name, arg_tys, constants, caller_context):
|
|
51
33
|
# doesn't mangle ret type, which must be a function of arg tys
|
|
52
|
-
mangled_arg_names = '_'.join([
|
|
34
|
+
mangled_arg_names = '_'.join([ty.mangle() for ty in arg_tys])
|
|
53
35
|
mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)])
|
|
54
36
|
mangled_constants = mangled_constants.replace('.', '_d_')
|
|
55
37
|
mangled_constants = mangled_constants.replace("'", '_sq_')
|
|
56
38
|
# [ and ] are not allowed in LLVM identifiers
|
|
57
39
|
mangled_constants = mangled_constants.replace('[', '_').replace(']', '_')
|
|
58
40
|
ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
|
|
41
|
+
if caller_context is not None:
|
|
42
|
+
ret += caller_context.mangle()
|
|
59
43
|
return ret
|
|
60
44
|
|
|
61
45
|
|
|
@@ -68,11 +52,11 @@ def _is_triton_tensor(o: Any) -> bool:
|
|
|
68
52
|
|
|
69
53
|
|
|
70
54
|
def _is_constexpr(o: Any) -> bool:
|
|
71
|
-
return o is None or isinstance(o, (constexpr, language.core.dtype))
|
|
55
|
+
return o is None or isinstance(o, (constexpr, language.core.dtype, JITCallable))
|
|
72
56
|
|
|
73
57
|
|
|
74
|
-
def
|
|
75
|
-
return _is_triton_tensor(o) and (
|
|
58
|
+
def _is_non_scalar_tensor(o: Any) -> bool:
|
|
59
|
+
return _is_triton_tensor(o) and (o.type.is_block() and o.type.numel != 1)
|
|
76
60
|
|
|
77
61
|
|
|
78
62
|
def _is_list_like(o: Any) -> bool:
|
|
@@ -82,7 +66,7 @@ def _is_list_like(o: Any) -> bool:
|
|
|
82
66
|
def _check_fn_args(node, fn, args):
|
|
83
67
|
if fn.noinline:
|
|
84
68
|
for idx, arg in enumerate(args):
|
|
85
|
-
if not _is_constexpr(arg) and
|
|
69
|
+
if not _is_constexpr(arg) and _is_non_scalar_tensor(arg):
|
|
86
70
|
raise UnsupportedLanguageConstruct(
|
|
87
71
|
fn.src, node,
|
|
88
72
|
f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
|
|
@@ -102,6 +86,7 @@ def _apply_to_tuple_values(value, fn):
|
|
|
102
86
|
assert False, f"Unsupported type {type(value)}"
|
|
103
87
|
|
|
104
88
|
vals = [fn(v) for v in value]
|
|
89
|
+
vals = [constexpr(v) if v is None else v for v in vals]
|
|
105
90
|
types = [v.type for v in vals]
|
|
106
91
|
return language.tuple(vals, language.tuple_type(types, fields))
|
|
107
92
|
|
|
@@ -124,6 +109,17 @@ def unflatten_ir_values(handles: List[ir.value], types: List[base_type]):
|
|
|
124
109
|
_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels
|
|
125
110
|
|
|
126
111
|
|
|
112
|
+
def _clone_triton_value(val):
|
|
113
|
+
handles = []
|
|
114
|
+
val._flatten_ir(handles)
|
|
115
|
+
clone, _ = val.type._unflatten_ir(handles, 0)
|
|
116
|
+
return clone
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _clone_scope(scope):
|
|
120
|
+
return {name: _clone_triton_value(val) if _is_triton_value(val) else val for name, val in scope.items()}
|
|
121
|
+
|
|
122
|
+
|
|
127
123
|
class enter_sub_region:
|
|
128
124
|
|
|
129
125
|
def __init__(self, generator):
|
|
@@ -131,8 +127,8 @@ class enter_sub_region:
|
|
|
131
127
|
|
|
132
128
|
def __enter__(self):
|
|
133
129
|
# record lscope & local_defs in the parent scope
|
|
134
|
-
self.liveins = self.generator.lscope
|
|
135
|
-
self.prev_defs = self.generator.local_defs
|
|
130
|
+
self.liveins = _clone_scope(self.generator.lscope)
|
|
131
|
+
self.prev_defs = _clone_scope(self.generator.local_defs)
|
|
136
132
|
self.generator.local_defs = {}
|
|
137
133
|
self.insert_block = self.generator.builder.get_insertion_block()
|
|
138
134
|
self.insert_point = self.generator.builder.get_insertion_point()
|
|
@@ -154,10 +150,9 @@ class ContainsReturnChecker(ast.NodeVisitor):
|
|
|
154
150
|
return any(self.visit(s) for s in body)
|
|
155
151
|
|
|
156
152
|
def _visit_function(self, fn) -> bool:
|
|
157
|
-
#
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
return ContainsReturnChecker(self.gscope).visit(fn_node)
|
|
153
|
+
# No need to check within the function as it won't cause an early return.
|
|
154
|
+
# If the function itself has unstructured control flow we may not be able to inline it causing poor performance,
|
|
155
|
+
# we should check for this and emit a warning.
|
|
161
156
|
return False
|
|
162
157
|
|
|
163
158
|
def generic_visit(self, node) -> bool:
|
|
@@ -241,26 +236,26 @@ class ASTFunction:
|
|
|
241
236
|
self.constants = constants
|
|
242
237
|
self.attrs = attrs
|
|
243
238
|
|
|
244
|
-
def
|
|
245
|
-
|
|
246
|
-
for
|
|
247
|
-
if
|
|
239
|
+
def flatten_ir_types(self, builder: ir.builder, types: List[base_type]) -> List[ir.type]:
|
|
240
|
+
ir_types = []
|
|
241
|
+
for ty in types:
|
|
242
|
+
if ty is None:
|
|
248
243
|
continue
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
return ret_types
|
|
244
|
+
ty._flatten_ir_types(builder, ir_types)
|
|
245
|
+
return ir_types
|
|
246
|
+
|
|
247
|
+
def return_types_ir(self, builder: ir.builder) -> List[ir.type]:
|
|
248
|
+
return self.flatten_ir_types(builder, self.ret_types)
|
|
255
249
|
|
|
256
250
|
def serialize(self, builder: ir.builder):
|
|
257
251
|
# fill up IR values in template
|
|
258
252
|
# > build function
|
|
259
253
|
is_val = lambda path, _: path not in self.constants and _ is not None
|
|
260
254
|
val_paths = list(find_paths_if(self.arg_types, is_val))
|
|
261
|
-
arg_types = [get_iterable_path(self.arg_types, path)
|
|
262
|
-
|
|
263
|
-
|
|
255
|
+
arg_types = [get_iterable_path(self.arg_types, path) for path in val_paths]
|
|
256
|
+
arg_types_ir = self.flatten_ir_types(builder, arg_types)
|
|
257
|
+
ret_types_ir = self.return_types_ir(builder)
|
|
258
|
+
return builder.get_function_ty(arg_types_ir, ret_types_ir)
|
|
264
259
|
|
|
265
260
|
def deserialize(self, fn):
|
|
266
261
|
# create "template"
|
|
@@ -272,19 +267,18 @@ class ASTFunction:
|
|
|
272
267
|
vals = make_template(self.arg_types)
|
|
273
268
|
is_val = lambda path, _: path not in self.constants and _ is not None
|
|
274
269
|
val_paths = list(find_paths_if(self.arg_types, is_val))
|
|
275
|
-
# > set attributes
|
|
276
|
-
for attr_path, attr_specs in self.attrs.items():
|
|
277
|
-
for attr_name, attr_val in attr_specs:
|
|
278
|
-
if attr_path in val_paths:
|
|
279
|
-
fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val)
|
|
280
|
-
for i, path in enumerate(val_paths):
|
|
281
|
-
ty = get_iterable_path(self.arg_types, path)
|
|
282
|
-
if isinstance(ty, nv_tma_desc_type):
|
|
283
|
-
fn.set_arg_attr(i, "tt.nv_tma_desc", 1)
|
|
284
270
|
# > add IR values to the template
|
|
285
|
-
|
|
271
|
+
cursor = 0
|
|
272
|
+
handles = [fn.args(i) for i in range(fn.get_num_args())]
|
|
273
|
+
for path in val_paths:
|
|
286
274
|
ty = get_iterable_path(self.arg_types, path)
|
|
287
|
-
|
|
275
|
+
# > set attributes
|
|
276
|
+
attr_specs = self.attrs.get(path, [])
|
|
277
|
+
for attr_name, attr_val in attr_specs:
|
|
278
|
+
fn.set_arg_attr(cursor, attr_name, attr_val)
|
|
279
|
+
# > build frontend value
|
|
280
|
+
val, cursor = ty._unflatten_ir(handles, cursor)
|
|
281
|
+
set_iterable_path(vals, path, val)
|
|
288
282
|
# > add constexpr values to the template
|
|
289
283
|
constants = self.constants
|
|
290
284
|
for path, val in constants.items():
|
|
@@ -292,13 +286,29 @@ class ASTFunction:
|
|
|
292
286
|
return vals
|
|
293
287
|
|
|
294
288
|
|
|
289
|
+
@dataclass(frozen=True)
|
|
290
|
+
class BoundJITMethod:
|
|
291
|
+
__self__: base_value
|
|
292
|
+
__func__: JITFunction
|
|
293
|
+
|
|
294
|
+
|
|
295
295
|
class CodeGenerator(ast.NodeVisitor):
|
|
296
296
|
|
|
297
|
-
def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns,
|
|
298
|
-
module=None, is_kernel=False, function_types: Optional[Dict] = None,
|
|
299
|
-
file_name: Optional[str] = None, begin_line=0):
|
|
297
|
+
def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, *, options, codegen_fns,
|
|
298
|
+
module_map, is_gluon, module=None, is_kernel=False, function_types: Optional[Dict] = None,
|
|
299
|
+
noinline=False, caller_context=None, file_name: Optional[str] = None, begin_line=0):
|
|
300
300
|
self.context = context
|
|
301
|
-
self.
|
|
301
|
+
self.is_gluon = is_gluon
|
|
302
|
+
if is_gluon:
|
|
303
|
+
from triton.experimental.gluon.language._semantic import GluonSemantic
|
|
304
|
+
self.builder = gluon_ir.GluonOpBuilder(context)
|
|
305
|
+
self.semantic = GluonSemantic(self.builder)
|
|
306
|
+
else:
|
|
307
|
+
from triton.language.semantic import TritonSemantic
|
|
308
|
+
self.builder = ir.builder(context)
|
|
309
|
+
self.semantic = TritonSemantic(self.builder)
|
|
310
|
+
|
|
311
|
+
self.name_loc_as_prefix = None
|
|
302
312
|
self.file_name = file_name
|
|
303
313
|
# node.lineno starts from 1, so we need to subtract 1
|
|
304
314
|
self.begin_line = begin_line - 1
|
|
@@ -306,7 +316,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
306
316
|
self.builder.options = options
|
|
307
317
|
# dict of functions provided by the backend. Below are the list of possible functions:
|
|
308
318
|
# Convert custom types not natively supported on HW.
|
|
309
|
-
# convert_custom_types(
|
|
319
|
+
# convert_custom_types(input_tensor, dtype, fp_downcast_rounding=None, _builder=None)
|
|
310
320
|
self.builder.codegen_fns = codegen_fns
|
|
311
321
|
self.builder.module_map = {} if module_map is None else module_map
|
|
312
322
|
self.module = self.builder.create_module() if module is None else module
|
|
@@ -329,11 +339,13 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
329
339
|
self.jit_fn = jit_fn
|
|
330
340
|
# TODO: we currently generate illegal names for non-kernel functions involving constexprs!
|
|
331
341
|
if is_kernel:
|
|
342
|
+
function_name = function_name[function_name.rfind('.') + 1:]
|
|
332
343
|
function_name = check_identifier_legality(function_name, "function")
|
|
333
344
|
self.function_name = function_name
|
|
334
345
|
self.is_kernel = is_kernel
|
|
335
346
|
self.cur_node = None
|
|
336
347
|
self.noinline = noinline
|
|
348
|
+
self.caller_context = caller_context
|
|
337
349
|
self.scf_stack = []
|
|
338
350
|
self.ret_type = None
|
|
339
351
|
# SSA-construction
|
|
@@ -345,7 +357,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
345
357
|
# special handling.
|
|
346
358
|
self.visiting_arg_default_value = False
|
|
347
359
|
|
|
348
|
-
builtin_namespace: Dict[str, Any] = {
|
|
360
|
+
builtin_namespace: Dict[str, Any] = {
|
|
361
|
+
_.__name__: _
|
|
362
|
+
for _ in (len, list, range, float, int, isinstance, getattr, hasattr)
|
|
363
|
+
}
|
|
349
364
|
builtin_namespace.update((
|
|
350
365
|
('print', language.core.device_print),
|
|
351
366
|
('min', language.minimum),
|
|
@@ -378,11 +393,14 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
378
393
|
# But actually a bunch of other things, such as module imports, are
|
|
379
394
|
# technically Python globals. We have to allow these too!
|
|
380
395
|
if any([
|
|
381
|
-
val is absent,
|
|
396
|
+
val is absent,
|
|
397
|
+
name in self.builtin_namespace, #
|
|
382
398
|
type(val) is ModuleType, #
|
|
383
|
-
isinstance(val,
|
|
399
|
+
isinstance(val, JITCallable), #
|
|
384
400
|
getattr(val, "__triton_builtin__", False), #
|
|
401
|
+
getattr(val, "__triton_aggregate__", False), #
|
|
385
402
|
getattr(val, "__module__", "").startswith("triton.language"), #
|
|
403
|
+
getattr(val, "__module__", "").startswith("triton.experimental.gluon.language"), #
|
|
386
404
|
isinstance(val, language.dtype), #
|
|
387
405
|
_is_namedtuple(val),
|
|
388
406
|
self._is_constexpr_global(name), #
|
|
@@ -390,7 +408,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
390
408
|
# because you should be able to do
|
|
391
409
|
# @triton.jit def fn(x: tl.constexpr = GLOBAL): ...
|
|
392
410
|
self.visiting_arg_default_value, #
|
|
393
|
-
|
|
411
|
+
knobs.compilation.allow_non_constexpr_globals,
|
|
394
412
|
]):
|
|
395
413
|
return val
|
|
396
414
|
raise NameError(
|
|
@@ -414,6 +432,21 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
414
432
|
|
|
415
433
|
return name_lookup
|
|
416
434
|
|
|
435
|
+
@contextlib.contextmanager
|
|
436
|
+
def _name_loc_prefix(self, prefix):
|
|
437
|
+
self.name_loc_as_prefix = prefix
|
|
438
|
+
yield
|
|
439
|
+
self.name_loc_as_prefix = None
|
|
440
|
+
|
|
441
|
+
def _maybe_set_loc_to_name(self, val, name):
|
|
442
|
+
if isinstance(val, (ir.value, ir.block_argument)):
|
|
443
|
+
val.set_loc(self.builder.create_name_loc(name, val.get_loc()))
|
|
444
|
+
elif _is_triton_value(val):
|
|
445
|
+
handles = []
|
|
446
|
+
val._flatten_ir(handles)
|
|
447
|
+
for handle in handles:
|
|
448
|
+
handle.set_loc(self.builder.create_name_loc(name, handle.get_loc()))
|
|
449
|
+
|
|
417
450
|
def set_value(self, name: str, value: Union[base_value, constexpr]) -> None:
|
|
418
451
|
''' This function:
|
|
419
452
|
called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
|
|
@@ -435,6 +468,43 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
435
468
|
self.builder.restore_insertion_point(ip)
|
|
436
469
|
self.builder.set_loc(loc)
|
|
437
470
|
|
|
471
|
+
def _find_carries(self, node, liveins):
|
|
472
|
+
# create loop body block
|
|
473
|
+
block = self.builder.create_block()
|
|
474
|
+
self.builder.set_insertion_point_to_start(block)
|
|
475
|
+
# dry visit loop body
|
|
476
|
+
self.scf_stack.append(node)
|
|
477
|
+
self.visit_compound_statement(node.body)
|
|
478
|
+
self.scf_stack.pop()
|
|
479
|
+
block.erase()
|
|
480
|
+
|
|
481
|
+
# If a variable (name) has changed value within the loop, then it's
|
|
482
|
+
# a loop-carried variable. (The new and old value must be of the
|
|
483
|
+
# same type)
|
|
484
|
+
init_tys = []
|
|
485
|
+
init_handles = []
|
|
486
|
+
names = []
|
|
487
|
+
|
|
488
|
+
for name, live_val in liveins.items():
|
|
489
|
+
if _is_triton_value(live_val):
|
|
490
|
+
loop_val = self.lscope[name]
|
|
491
|
+
self._verify_loop_carried_variable(name, loop_val, live_val)
|
|
492
|
+
|
|
493
|
+
live_handles = flatten_values_to_ir([live_val])
|
|
494
|
+
loop_handles = flatten_values_to_ir([loop_val])
|
|
495
|
+
if live_handles != loop_handles:
|
|
496
|
+
names.append(name)
|
|
497
|
+
init_tys.append(live_val.type)
|
|
498
|
+
init_handles.extend(live_handles)
|
|
499
|
+
else:
|
|
500
|
+
assert name not in self.local_defs, f'Loop carried variable {name} is not a triton value'
|
|
501
|
+
|
|
502
|
+
# reset local scope to not pick up local defs from the dry run.
|
|
503
|
+
self.lscope = liveins.copy()
|
|
504
|
+
self.local_defs = {}
|
|
505
|
+
|
|
506
|
+
return names, init_handles, init_tys
|
|
507
|
+
|
|
438
508
|
#
|
|
439
509
|
# AST visitor
|
|
440
510
|
#
|
|
@@ -458,6 +528,21 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
458
528
|
elts = language.tuple([self.visit(elt) for elt in node.elts])
|
|
459
529
|
return elts
|
|
460
530
|
|
|
531
|
+
def visit_ListComp(self, node: ast.ListComp):
|
|
532
|
+
if len(node.generators) != 1:
|
|
533
|
+
raise ValueError("nested comprehensions are not supported")
|
|
534
|
+
|
|
535
|
+
comp = node.generators[0]
|
|
536
|
+
iter = self.visit(comp.iter)
|
|
537
|
+
if not isinstance(iter, tl_tuple):
|
|
538
|
+
raise NotImplementedError("only tuple comprehensions are supported")
|
|
539
|
+
|
|
540
|
+
results = []
|
|
541
|
+
for item in iter:
|
|
542
|
+
self.set_value(comp.target.id, item)
|
|
543
|
+
results.append(self.visit(node.elt))
|
|
544
|
+
return tl_tuple(results)
|
|
545
|
+
|
|
461
546
|
# By design, only non-kernel functions can return
|
|
462
547
|
def visit_Return(self, node):
|
|
463
548
|
ret_value = self.visit(node.value)
|
|
@@ -467,7 +552,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
467
552
|
if isinstance(value, language.tuple):
|
|
468
553
|
return _apply_to_tuple_values(value, decay)
|
|
469
554
|
elif isinstance(value, (language.constexpr, int, float)):
|
|
470
|
-
return semantic.to_tensor(value
|
|
555
|
+
return self.semantic.to_tensor(value)
|
|
471
556
|
return value
|
|
472
557
|
|
|
473
558
|
ret_value = decay(ret_value)
|
|
@@ -522,8 +607,11 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
522
607
|
self.module.push_back(self.fn)
|
|
523
608
|
entry = self.fn.add_entry_block()
|
|
524
609
|
arg_values = self.prototype.deserialize(self.fn)
|
|
610
|
+
if self.caller_context is not None:
|
|
611
|
+
self.caller_context.initialize_callee(self.fn, self.builder)
|
|
525
612
|
# bind arguments to symbols
|
|
526
613
|
for arg_name, arg_value in zip(arg_names, arg_values):
|
|
614
|
+
self._maybe_set_loc_to_name(arg_value, arg_name)
|
|
527
615
|
self.set_value(arg_name, arg_value)
|
|
528
616
|
insert_pt = self.builder.get_insertion_block()
|
|
529
617
|
self.builder.set_insertion_point_to_start(entry)
|
|
@@ -575,14 +663,15 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
575
663
|
return self.visit_Assign(node)
|
|
576
664
|
|
|
577
665
|
def assignTarget(self, target, value):
|
|
666
|
+
assert isinstance(target.ctx, ast.Store)
|
|
578
667
|
if isinstance(target, ast.Subscript):
|
|
579
|
-
assert target.ctx.__class__.__name__ == "Store"
|
|
580
668
|
return self.visit_Subscript_Store(target, value)
|
|
581
669
|
if isinstance(target, ast.Tuple):
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
self.set_value(self.visit(name), value.values[i])
|
|
670
|
+
for i, target in enumerate(target.elts):
|
|
671
|
+
self.assignTarget(target, value.values[i])
|
|
585
672
|
return
|
|
673
|
+
if isinstance(target, ast.Attribute):
|
|
674
|
+
raise NotImplementedError("Attribute assignment is not supported in triton")
|
|
586
675
|
assert isinstance(target, ast.Name)
|
|
587
676
|
self.set_value(self.visit(target), value)
|
|
588
677
|
|
|
@@ -596,21 +685,26 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
596
685
|
if value is not None and \
|
|
597
686
|
not _is_triton_value(value) and \
|
|
598
687
|
not isinstance(value, native_nontensor_types):
|
|
599
|
-
value = semantic.to_tensor(value
|
|
688
|
+
value = self.semantic.to_tensor(value)
|
|
600
689
|
return value
|
|
601
690
|
|
|
602
|
-
values = _sanitize_value(self.visit(node.value))
|
|
603
691
|
targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets
|
|
604
692
|
assert len(targets) == 1
|
|
605
|
-
|
|
693
|
+
target = targets[0]
|
|
694
|
+
if isinstance(target, ast.Name):
|
|
695
|
+
with self._name_loc_prefix(target.id):
|
|
696
|
+
values = _sanitize_value(self.visit(node.value))
|
|
697
|
+
else:
|
|
698
|
+
values = _sanitize_value(self.visit(node.value))
|
|
699
|
+
self.assignTarget(target, values)
|
|
606
700
|
|
|
607
701
|
def visit_AugAssign(self, node):
|
|
608
|
-
|
|
609
|
-
lhs = ast.
|
|
702
|
+
lhs = copy.deepcopy(node.target)
|
|
703
|
+
lhs.ctx = ast.Load()
|
|
610
704
|
rhs = ast.BinOp(lhs, node.op, node.value)
|
|
611
705
|
assign = ast.Assign(targets=[node.target], value=rhs)
|
|
612
706
|
self.visit(assign)
|
|
613
|
-
return self.
|
|
707
|
+
return self.visit(lhs)
|
|
614
708
|
|
|
615
709
|
def visit_Name(self, node):
|
|
616
710
|
if type(node.ctx) is ast.Store:
|
|
@@ -630,10 +724,12 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
630
724
|
def _apply_binary_method(self, method_name, lhs, rhs):
|
|
631
725
|
# TODO: raise something meaningful if getattr fails below, esp for reverse method
|
|
632
726
|
if _is_triton_tensor(lhs):
|
|
633
|
-
return getattr(lhs, method_name)(rhs,
|
|
727
|
+
return getattr(lhs, method_name)(rhs, _semantic=self.semantic)
|
|
634
728
|
if _is_triton_tensor(rhs):
|
|
635
729
|
reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name)
|
|
636
|
-
return getattr(rhs, reverse_method_name)(lhs,
|
|
730
|
+
return getattr(rhs, reverse_method_name)(lhs, _semantic=self.semantic)
|
|
731
|
+
if not isinstance(lhs, (constexpr, language.tuple)) and isinstance(rhs, constexpr):
|
|
732
|
+
lhs = constexpr(lhs)
|
|
637
733
|
return getattr(lhs, method_name)(rhs)
|
|
638
734
|
|
|
639
735
|
def visit_BinOp(self, node):
|
|
@@ -666,8 +762,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
666
762
|
self.visit_compound_statement(node.body)
|
|
667
763
|
then_block = self.builder.get_insertion_block()
|
|
668
764
|
then_defs = self.local_defs.copy()
|
|
765
|
+
then_vals = self.lscope.copy()
|
|
669
766
|
# else block
|
|
670
767
|
else_defs = {}
|
|
768
|
+
else_vals = liveins.copy()
|
|
671
769
|
if node.orelse:
|
|
672
770
|
self.builder.set_insertion_point_to_start(else_block)
|
|
673
771
|
self.lscope = liveins.copy()
|
|
@@ -675,26 +773,29 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
675
773
|
self.visit_compound_statement(node.orelse)
|
|
676
774
|
else_defs = self.local_defs.copy()
|
|
677
775
|
else_block = self.builder.get_insertion_block()
|
|
776
|
+
else_vals = self.lscope.copy()
|
|
678
777
|
|
|
679
778
|
# update block arguments
|
|
680
779
|
names = []
|
|
681
780
|
# variables in livein whose value is updated in `if`
|
|
682
|
-
for name in liveins:
|
|
781
|
+
for name, value in liveins.items():
|
|
782
|
+
# livein variable changed value in either then or else
|
|
783
|
+
if not _is_triton_value(value):
|
|
784
|
+
continue
|
|
785
|
+
then_handles = flatten_values_to_ir([then_vals[name]])
|
|
786
|
+
else_handles = flatten_values_to_ir([else_vals[name]])
|
|
787
|
+
if then_handles == else_handles:
|
|
788
|
+
continue
|
|
789
|
+
names.append(name)
|
|
790
|
+
then_defs[name] = then_vals[name]
|
|
791
|
+
else_defs[name] = else_vals[name]
|
|
683
792
|
# check type
|
|
684
793
|
for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]:
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
if name in then_defs or name in else_defs:
|
|
691
|
-
names.append(name)
|
|
692
|
-
# variable defined in then but not in else
|
|
693
|
-
if name in then_defs and name not in else_defs:
|
|
694
|
-
else_defs[name] = liveins[name]
|
|
695
|
-
# variable defined in else but not in then
|
|
696
|
-
if name in else_defs and name not in then_defs:
|
|
697
|
-
then_defs[name] = liveins[name]
|
|
794
|
+
type_equal = type(defs[name]) == type(value) # noqa: E721
|
|
795
|
+
assert type_equal and defs[name].type == value.type, \
|
|
796
|
+
f'initial value for `{name}` is of type {value}, '\
|
|
797
|
+
f'but the {block_name} block redefines it as {defs[name]}'
|
|
798
|
+
|
|
698
799
|
# variables that are both in then and else but not in liveins
|
|
699
800
|
# TODO: could probably be cleaned up
|
|
700
801
|
for name in sorted(then_defs.keys() & else_defs.keys()):
|
|
@@ -761,6 +862,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
761
862
|
self.visit_then_else_blocks(node, liveins, then_block, else_block)
|
|
762
863
|
# create if op
|
|
763
864
|
then_handles = flatten_values_to_ir(then_defs[name] for name in names)
|
|
865
|
+
for name, val in zip(names, then_handles):
|
|
866
|
+
self._maybe_set_loc_to_name(val, name)
|
|
764
867
|
self._set_insertion_point_and_loc(ip, last_loc)
|
|
765
868
|
if_op = self.builder.create_if_op([h.get_type() for h in then_handles], cond.handle, True)
|
|
766
869
|
then_block.merge_block_before(if_op.get_then_block())
|
|
@@ -774,6 +877,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
774
877
|
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
|
775
878
|
if len(names) > 0:
|
|
776
879
|
else_handles = flatten_values_to_ir(else_defs[name] for name in names)
|
|
880
|
+
for name, val in zip(names, else_handles):
|
|
881
|
+
self._maybe_set_loc_to_name(val, name)
|
|
777
882
|
self.builder.create_yield_op(else_handles)
|
|
778
883
|
# update values
|
|
779
884
|
res_handles = [if_op.get_result(i) for i in range(len(then_handles))]
|
|
@@ -786,14 +891,18 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
786
891
|
cond = self.visit(node.test)
|
|
787
892
|
|
|
788
893
|
if _is_triton_tensor(cond):
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
if
|
|
894
|
+
if _is_non_scalar_tensor(cond):
|
|
895
|
+
raise self._unsupported(node, "Boolean value of Tensor with more than one value is ambiguous")
|
|
896
|
+
if cond.type.is_block():
|
|
897
|
+
warnings.warn(
|
|
898
|
+
"If conditional called with multidimensional Tensor instead of scalar; please use \"if (%s).item()\" instead"
|
|
899
|
+
% ast.unparse(node.test))
|
|
900
|
+
cond = language.core._unsplat(cond, _semantic=self.semantic, _generator=self)
|
|
901
|
+
cond = cond.to(language.int1, _semantic=self.semantic)
|
|
902
|
+
if ContainsReturnChecker(self.gscope).visit(node):
|
|
792
903
|
if self.scf_stack:
|
|
793
904
|
raise self._unsupported(
|
|
794
|
-
node, "Cannot have `return` statements inside `while` or `for` statements in triton
|
|
795
|
-
"(note that this also applies to `return` statements that are inside functions "
|
|
796
|
-
"transitively called from within `while`/`for` statements)")
|
|
905
|
+
node, "Cannot have `return` statements inside `while` or `for` statements in triton.")
|
|
797
906
|
self.visit_if_top_level(cond, node)
|
|
798
907
|
else:
|
|
799
908
|
self.visit_if_scf(cond, node)
|
|
@@ -812,21 +921,21 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
812
921
|
def visit_IfExp(self, node):
|
|
813
922
|
cond = self.visit(node.test)
|
|
814
923
|
if _is_triton_tensor(cond):
|
|
815
|
-
cond = cond.to(language.int1,
|
|
924
|
+
cond = cond.to(language.int1, _semantic=self.semantic)
|
|
816
925
|
# TODO: Deal w/ more complicated return types (e.g tuple)
|
|
817
926
|
with enter_sub_region(self):
|
|
818
927
|
ip, last_loc = self._get_insertion_point_and_loc()
|
|
819
928
|
|
|
820
929
|
then_block = self.builder.create_block()
|
|
821
930
|
self.builder.set_insertion_point_to_start(then_block)
|
|
822
|
-
then_val = semantic.to_tensor(self.visit(node.body)
|
|
931
|
+
then_val = self.semantic.to_tensor(self.visit(node.body))
|
|
823
932
|
then_block = self.builder.get_insertion_block()
|
|
824
933
|
|
|
825
934
|
else_block = self.builder.create_block()
|
|
826
935
|
self.builder.set_insertion_point_to_start(else_block)
|
|
827
936
|
# do not need to reset lscope since
|
|
828
937
|
# ternary expressions cannot define new variables
|
|
829
|
-
else_val = semantic.to_tensor(self.visit(node.orelse)
|
|
938
|
+
else_val = self.semantic.to_tensor(self.visit(node.orelse))
|
|
830
939
|
else_block = self.builder.get_insertion_block()
|
|
831
940
|
|
|
832
941
|
self._set_insertion_point_and_loc(ip, last_loc)
|
|
@@ -862,6 +971,37 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
862
971
|
else:
|
|
863
972
|
return self.visit(node.orelse)
|
|
864
973
|
|
|
974
|
+
def visit_With(self, node):
|
|
975
|
+
# Lower `with` statements by constructing context managers and calling their enter/exit hooks
|
|
976
|
+
# Instantiate each context manager with builder injection
|
|
977
|
+
if len(node.items) == 1: # Handle async_task
|
|
978
|
+
context = node.items[0].context_expr
|
|
979
|
+
withitemClass = self.visit(context.func)
|
|
980
|
+
if withitemClass == language.async_task:
|
|
981
|
+
args = [self.visit(arg) for arg in context.args]
|
|
982
|
+
with withitemClass(*args, _builder=self.builder):
|
|
983
|
+
self.visit_compound_statement(node.body)
|
|
984
|
+
return
|
|
985
|
+
|
|
986
|
+
cm_list = []
|
|
987
|
+
for item in node.items:
|
|
988
|
+
call = item.context_expr
|
|
989
|
+
fn = self.visit(call.func)
|
|
990
|
+
args = [self.visit(arg) for arg in call.args]
|
|
991
|
+
kws = dict(self.visit(kw) for kw in call.keywords)
|
|
992
|
+
cm = fn(*args, _semantic=self.semantic, **kws)
|
|
993
|
+
cm_list.append(cm)
|
|
994
|
+
for cm, item in zip(cm_list, node.items):
|
|
995
|
+
res = cm.__enter__()
|
|
996
|
+
if item.optional_vars is not None:
|
|
997
|
+
var_name = self.visit(item.optional_vars)
|
|
998
|
+
self.set_value(var_name, res)
|
|
999
|
+
if ContainsReturnChecker(self.gscope).visit(node):
|
|
1000
|
+
raise self._unsupported(node, "Cannot have `return` statements inside `with` statements in triton ")
|
|
1001
|
+
self.visit_compound_statement(node.body)
|
|
1002
|
+
for cm in reversed(cm_list):
|
|
1003
|
+
cm.__exit__(None, None, None)
|
|
1004
|
+
|
|
865
1005
|
def visit_Pass(self, node):
|
|
866
1006
|
pass
|
|
867
1007
|
|
|
@@ -892,10 +1032,12 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
892
1032
|
if fn is None:
|
|
893
1033
|
raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.")
|
|
894
1034
|
if _is_triton_tensor(operand):
|
|
895
|
-
return getattr(operand, fn)(
|
|
1035
|
+
return getattr(operand, fn)(_semantic=self.semantic)
|
|
896
1036
|
try:
|
|
897
1037
|
return getattr(operand, fn)()
|
|
898
1038
|
except AttributeError:
|
|
1039
|
+
if fn == "__not__":
|
|
1040
|
+
return constexpr(not operand)
|
|
899
1041
|
raise self._unsupported(
|
|
900
1042
|
node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}")
|
|
901
1043
|
|
|
@@ -904,46 +1046,26 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
904
1046
|
}
|
|
905
1047
|
|
|
906
1048
|
def _verify_loop_carried_variable(self, name, loop_val, live_val):
|
|
907
|
-
assert _is_triton_value(loop_val), f'cannot reassign
|
|
908
|
-
assert _is_triton_value(live_val), f'cannot
|
|
909
|
-
assert type(loop_val) is type(live_val),
|
|
1049
|
+
assert _is_triton_value(loop_val), f'cannot reassign constexpr {name} in the loop'
|
|
1050
|
+
assert _is_triton_value(live_val), f'cannot reassign constexpr {name} in the loop'
|
|
1051
|
+
assert type(loop_val) is type(live_val), (
|
|
1052
|
+
f'Loop carried variable {name} changed type, was {type(loop_val)} but is now {type(live_val)}')
|
|
910
1053
|
assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \
|
|
911
1054
|
f'Loop-carried variable {name} has initial type {live_val.type} '\
|
|
912
1055
|
f'but is re-assigned to {loop_val.type} in loop! '\
|
|
913
1056
|
f'Please make sure that the type stays consistent.'
|
|
914
1057
|
|
|
1058
|
+
def visit_withitem(self, node):
|
|
1059
|
+
return self.visit(node.context_expr)
|
|
1060
|
+
|
|
915
1061
|
def visit_While(self, node):
|
|
916
1062
|
with enter_sub_region(self) as sr:
|
|
917
1063
|
liveins, insert_block = sr
|
|
918
1064
|
ip, last_loc = self._get_insertion_point_and_loc()
|
|
919
1065
|
|
|
920
|
-
|
|
921
|
-
# loop_block = self.builder.create_block()
|
|
922
|
-
dummy = self.builder.create_block()
|
|
923
|
-
self.builder.set_insertion_point_to_start(dummy)
|
|
924
|
-
self.scf_stack.append(node)
|
|
925
|
-
self.visit_compound_statement(node.body)
|
|
926
|
-
self.scf_stack.pop()
|
|
927
|
-
loop_defs = self.local_defs
|
|
928
|
-
dummy.erase()
|
|
929
|
-
|
|
930
|
-
# collect loop-carried values
|
|
931
|
-
names = []
|
|
932
|
-
init_args = []
|
|
933
|
-
for name in loop_defs:
|
|
934
|
-
if name in liveins:
|
|
935
|
-
# We should not def new constexpr
|
|
936
|
-
loop_val = loop_defs[name]
|
|
937
|
-
live_val = liveins[name]
|
|
938
|
-
self._verify_loop_carried_variable(name, loop_val, live_val)
|
|
939
|
-
|
|
940
|
-
# these are loop-carried values
|
|
941
|
-
names.append(name)
|
|
942
|
-
init_args.append(live_val)
|
|
1066
|
+
names, init_handles, init_fe_tys = self._find_carries(node, liveins)
|
|
943
1067
|
|
|
944
|
-
init_handles = flatten_values_to_ir(init_args)
|
|
945
1068
|
init_tys = [h.get_type() for h in init_handles]
|
|
946
|
-
init_fe_tys = [a.type for a in init_args]
|
|
947
1069
|
self._set_insertion_point_and_loc(ip, last_loc)
|
|
948
1070
|
while_op = self.builder.create_while_op(init_tys, init_handles)
|
|
949
1071
|
# merge the condition region
|
|
@@ -954,7 +1076,12 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
954
1076
|
for name, val in zip(names, condition_args):
|
|
955
1077
|
self.lscope[name] = val
|
|
956
1078
|
self.local_defs[name] = val
|
|
1079
|
+
self._maybe_set_loc_to_name(val, name)
|
|
957
1080
|
cond = self.visit(node.test)
|
|
1081
|
+
if isinstance(cond, language.condition):
|
|
1082
|
+
if cond.disable_licm:
|
|
1083
|
+
while_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr())
|
|
1084
|
+
cond = cond.condition
|
|
958
1085
|
self.builder.set_insertion_point_to_end(before_block)
|
|
959
1086
|
# create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
|
|
960
1087
|
self.builder.create_condition_op(cond.handle, block_args)
|
|
@@ -968,16 +1095,13 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
968
1095
|
for name, val in zip(names, body_args):
|
|
969
1096
|
self.lscope[name] = val
|
|
970
1097
|
self.local_defs[name] = val
|
|
1098
|
+
self._maybe_set_loc_to_name(val, name)
|
|
971
1099
|
self.scf_stack.append(node)
|
|
972
1100
|
self.visit_compound_statement(node.body)
|
|
973
1101
|
self.scf_stack.pop()
|
|
974
|
-
loop_defs = self.local_defs
|
|
975
|
-
yields = []
|
|
976
|
-
for name in loop_defs:
|
|
977
|
-
if name in liveins:
|
|
978
|
-
loop_defs[name]._flatten_ir(yields)
|
|
979
1102
|
|
|
980
|
-
self.
|
|
1103
|
+
yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)
|
|
1104
|
+
self.builder.create_yield_op(yield_handles)
|
|
981
1105
|
|
|
982
1106
|
# WhileOp defines new values, update the symbol table (lscope, local_defs)
|
|
983
1107
|
result_handles = [while_op.get_result(i) for i in range(len(init_handles))]
|
|
@@ -985,25 +1109,22 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
985
1109
|
for name, new_def in zip(names, result_vals):
|
|
986
1110
|
self.lscope[name] = new_def
|
|
987
1111
|
self.local_defs[name] = new_def
|
|
1112
|
+
self._maybe_set_loc_to_name(new_def, name)
|
|
988
1113
|
|
|
989
1114
|
for stmt in node.orelse:
|
|
990
1115
|
assert False, "Not implemented"
|
|
991
1116
|
ast.NodeVisitor.generic_visit(self, stmt)
|
|
992
1117
|
|
|
993
1118
|
def visit_Subscript_Load(self, node):
|
|
994
|
-
assert node.ctx.
|
|
1119
|
+
assert isinstance(node.ctx, ast.Load)
|
|
995
1120
|
lhs = self.visit(node.value)
|
|
996
1121
|
slices = self.visit(node.slice)
|
|
997
|
-
if
|
|
998
|
-
return lhs.__getitem__
|
|
1122
|
+
if _is_triton_value(lhs):
|
|
1123
|
+
return self.call_Method(node, lhs.__getitem__, lhs, [slices], {})
|
|
999
1124
|
return lhs[slices]
|
|
1000
1125
|
|
|
1001
1126
|
def visit_Subscript_Store(self, node, value):
|
|
1002
|
-
|
|
1003
|
-
lhs = self.visit(node.value)
|
|
1004
|
-
slices = self.visit(node.slice)
|
|
1005
|
-
assert isinstance(lhs, language.tuple)
|
|
1006
|
-
lhs.__setitem__(slices, value)
|
|
1127
|
+
raise NotImplementedError("__setitem__ is not supported in triton")
|
|
1007
1128
|
|
|
1008
1129
|
def visit_Subscript(self, node):
|
|
1009
1130
|
return self.visit_Subscript_Load(node)
|
|
@@ -1028,6 +1149,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1028
1149
|
loop_unroll_factor = None
|
|
1029
1150
|
disallow_acc_multi_buffer = False
|
|
1030
1151
|
flatten = False
|
|
1152
|
+
warp_specialize = False
|
|
1153
|
+
disable_licm = False
|
|
1031
1154
|
if IteratorClass is language.range:
|
|
1032
1155
|
iterator = IteratorClass(*iter_args, **iter_kwargs)
|
|
1033
1156
|
# visit iterator arguments
|
|
@@ -1040,6 +1163,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1040
1163
|
loop_unroll_factor = iterator.loop_unroll_factor
|
|
1041
1164
|
disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer
|
|
1042
1165
|
flatten = iterator.flatten
|
|
1166
|
+
warp_specialize = iterator.warp_specialize
|
|
1167
|
+
disable_licm = iterator.disable_licm
|
|
1043
1168
|
elif IteratorClass is range:
|
|
1044
1169
|
# visit iterator arguments
|
|
1045
1170
|
# note: only `range` iterator is supported now
|
|
@@ -1055,14 +1180,14 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1055
1180
|
step = constexpr(-step.value)
|
|
1056
1181
|
negative_step = True
|
|
1057
1182
|
lb, ub = ub, lb
|
|
1058
|
-
lb = semantic.to_tensor(lb
|
|
1059
|
-
ub = semantic.to_tensor(ub
|
|
1060
|
-
step = semantic.to_tensor(step
|
|
1183
|
+
lb = self.semantic.to_tensor(lb)
|
|
1184
|
+
ub = self.semantic.to_tensor(ub)
|
|
1185
|
+
step = self.semantic.to_tensor(step)
|
|
1061
1186
|
# induction variable type
|
|
1062
1187
|
if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int():
|
|
1063
1188
|
raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})")
|
|
1064
|
-
iv_type = semantic.integer_promote_impl(lb.dtype, ub.dtype)
|
|
1065
|
-
iv_type = semantic.integer_promote_impl(iv_type, step.dtype)
|
|
1189
|
+
iv_type = self.semantic.integer_promote_impl(lb.dtype, ub.dtype)
|
|
1190
|
+
iv_type = self.semantic.integer_promote_impl(iv_type, step.dtype)
|
|
1066
1191
|
iv_ir_type = iv_type.to_ir(self.builder)
|
|
1067
1192
|
iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED
|
|
1068
1193
|
# lb/ub/step might be constexpr, we need to cast them to tensor
|
|
@@ -1081,34 +1206,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1081
1206
|
liveins, insert_block = sr
|
|
1082
1207
|
ip, last_loc = self._get_insertion_point_and_loc()
|
|
1083
1208
|
|
|
1084
|
-
|
|
1085
|
-
block = self.builder.create_block()
|
|
1086
|
-
self.builder.set_insertion_point_to_start(block)
|
|
1087
|
-
# dry visit loop body
|
|
1088
|
-
self.scf_stack.append(node)
|
|
1089
|
-
self.visit_compound_statement(node.body)
|
|
1090
|
-
self.scf_stack.pop()
|
|
1091
|
-
block.erase()
|
|
1092
|
-
|
|
1093
|
-
# If a variable (name) is defined in both its parent & itself, then it's
|
|
1094
|
-
# a loop-carried variable. (They must be of the same type)
|
|
1095
|
-
init_args = []
|
|
1096
|
-
yields = []
|
|
1097
|
-
names = []
|
|
1098
|
-
for name in self.local_defs:
|
|
1099
|
-
if name in liveins:
|
|
1100
|
-
loop_val = self.local_defs[name]
|
|
1101
|
-
live_val = liveins[name]
|
|
1102
|
-
self._verify_loop_carried_variable(name, loop_val, live_val)
|
|
1103
|
-
|
|
1104
|
-
names.append(name)
|
|
1105
|
-
init_args.append(live_val)
|
|
1106
|
-
yields.append(loop_val)
|
|
1209
|
+
names, init_handles, init_tys = self._find_carries(node, liveins)
|
|
1107
1210
|
|
|
1108
1211
|
# create ForOp
|
|
1109
1212
|
self._set_insertion_point_and_loc(ip, last_loc)
|
|
1110
|
-
init_handles = flatten_values_to_ir(init_args)
|
|
1111
|
-
init_tys = [v.type for v in init_args]
|
|
1112
1213
|
for_op = self.builder.create_for_op(lb, ub, step, init_handles)
|
|
1113
1214
|
if _unwrap_if_constexpr(num_stages) is not None:
|
|
1114
1215
|
for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages))
|
|
@@ -1118,30 +1219,25 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1118
1219
|
for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr())
|
|
1119
1220
|
if flatten:
|
|
1120
1221
|
for_op.set_attr("tt.flatten", self.builder.get_unit_attr())
|
|
1222
|
+
if warp_specialize:
|
|
1223
|
+
for_op.set_attr("tt.warp_specialize", self.builder.get_unit_attr())
|
|
1224
|
+
if disable_licm:
|
|
1225
|
+
for_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr())
|
|
1121
1226
|
|
|
1122
1227
|
self.scf_stack.append(node)
|
|
1123
1228
|
for_op_body = for_op.get_body(0)
|
|
1124
1229
|
self.builder.set_insertion_point_to_start(for_op_body)
|
|
1125
|
-
# reset local scope to not pick up local defs from the previous dry run.
|
|
1126
|
-
self.lscope = liveins.copy()
|
|
1127
|
-
self.local_defs = {}
|
|
1128
1230
|
block_handles = [for_op_body.arg(i + 1) for i in range(len(init_handles))]
|
|
1129
1231
|
block_args = unflatten_ir_values(block_handles, init_tys)
|
|
1130
1232
|
for name, val in zip(names, block_args):
|
|
1233
|
+
self._maybe_set_loc_to_name(val, name)
|
|
1131
1234
|
self.set_value(name, val)
|
|
1132
1235
|
self.visit_compound_statement(node.body)
|
|
1133
1236
|
self.scf_stack.pop()
|
|
1134
|
-
|
|
1135
|
-
for name in self.local_defs:
|
|
1136
|
-
if name in liveins:
|
|
1137
|
-
local = self.local_defs[name]
|
|
1138
|
-
if isinstance(local, constexpr):
|
|
1139
|
-
local = semantic.to_tensor(local, self.builder)
|
|
1140
|
-
yields.append(local)
|
|
1237
|
+
yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)
|
|
1141
1238
|
|
|
1142
1239
|
# create YieldOp
|
|
1143
|
-
if len(
|
|
1144
|
-
yield_handles = flatten_values_to_ir(yields)
|
|
1240
|
+
if len(yield_handles) > 0:
|
|
1145
1241
|
self.builder.create_yield_op(yield_handles)
|
|
1146
1242
|
for_op_region = for_op_body.get_parent()
|
|
1147
1243
|
assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
|
|
@@ -1154,12 +1250,14 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1154
1250
|
iv = self.builder.create_add(iv, lb)
|
|
1155
1251
|
self.lscope[node.target.id].handle.replace_all_uses_with(iv)
|
|
1156
1252
|
self.set_value(node.target.id, language.core.tensor(iv, iv_type))
|
|
1253
|
+
self._maybe_set_loc_to_name(iv, node.target.id)
|
|
1157
1254
|
|
|
1158
1255
|
# update lscope & local_defs (ForOp defines new values)
|
|
1159
1256
|
result_handles = [for_op.get_result(i) for i in range(len(init_handles))]
|
|
1160
1257
|
result_values = unflatten_ir_values(result_handles, init_tys)
|
|
1161
1258
|
for name, val in zip(names, result_values):
|
|
1162
1259
|
self.set_value(name, val)
|
|
1260
|
+
self._maybe_set_loc_to_name(val, name)
|
|
1163
1261
|
|
|
1164
1262
|
for stmt in node.orelse:
|
|
1165
1263
|
assert False, "Don't know what to do with else after for"
|
|
@@ -1180,9 +1278,9 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1180
1278
|
def visit_Assert(self, node) -> Any:
|
|
1181
1279
|
test = self.visit(node.test)
|
|
1182
1280
|
msg = self.visit(node.msg) if node.msg is not None else ""
|
|
1183
|
-
return language.core.device_assert(test, msg,
|
|
1281
|
+
return language.core.device_assert(test, msg, _semantic=self.semantic)
|
|
1184
1282
|
|
|
1185
|
-
def call_JitFunction(self, fn: JITFunction, args, kwargs):
|
|
1283
|
+
def call_JitFunction(self, fn: JITFunction, args, kwargs, caller_context=None):
|
|
1186
1284
|
args = inspect.getcallargs(fn.fn, *args, **kwargs)
|
|
1187
1285
|
args = [args[name] for name in fn.arg_names]
|
|
1188
1286
|
for i, arg in enumerate(args):
|
|
@@ -1193,10 +1291,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1193
1291
|
args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
|
|
1194
1292
|
args_val = [get_iterable_path(args, path) for path in args_path]
|
|
1195
1293
|
# mangle
|
|
1196
|
-
|
|
1294
|
+
caller_context = caller_context or self.caller_context
|
|
1295
|
+
fn_name = mangle_fn(get_full_name(fn), [arg.type for arg in args_val], args_cst, caller_context)
|
|
1197
1296
|
# generate function def if necessary
|
|
1198
1297
|
if not self.module.has_function(fn_name):
|
|
1199
|
-
gscope = fn.__globals__
|
|
1200
1298
|
# If the callee is not set, we use the same debug setting as the caller
|
|
1201
1299
|
file_name, begin_line = get_jit_fn_file_line(fn)
|
|
1202
1300
|
arg_types = [
|
|
@@ -1205,15 +1303,18 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1205
1303
|
for arg in args
|
|
1206
1304
|
]
|
|
1207
1305
|
prototype = ASTFunction([], arg_types, args_cst, dict())
|
|
1208
|
-
generator = CodeGenerator(self.context, prototype,
|
|
1306
|
+
generator = CodeGenerator(self.context, prototype, fn.get_capture_scope(), module=self.module, jit_fn=fn,
|
|
1209
1307
|
function_name=fn_name, function_types=self.function_ret_types,
|
|
1210
1308
|
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
|
|
1211
1309
|
options=self.builder.options, codegen_fns=self.builder.codegen_fns,
|
|
1212
|
-
module_map=self.builder.module_map
|
|
1310
|
+
module_map=self.builder.module_map, caller_context=caller_context,
|
|
1311
|
+
is_gluon=self.is_gluon)
|
|
1213
1312
|
try:
|
|
1214
1313
|
generator.visit(fn.parse())
|
|
1215
1314
|
except Exception as e:
|
|
1216
1315
|
# Wrap the error in the callee with the location of the call.
|
|
1316
|
+
if knobs.compilation.front_end_debugging:
|
|
1317
|
+
raise
|
|
1217
1318
|
raise CompilationError(self.jit_fn.src, self.cur_node, None) from e
|
|
1218
1319
|
|
|
1219
1320
|
callee_ret_type = generator.ret_type
|
|
@@ -1221,28 +1322,30 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1221
1322
|
else:
|
|
1222
1323
|
callee_ret_type = self.function_ret_types[fn_name]
|
|
1223
1324
|
symbol = self.module.get_function(fn_name)
|
|
1224
|
-
args_val =
|
|
1325
|
+
args_val = flatten_values_to_ir(args_val)
|
|
1225
1326
|
call_op = self.builder.call(symbol, args_val)
|
|
1226
1327
|
if callee_ret_type == language.void:
|
|
1227
1328
|
return None
|
|
1228
1329
|
handles = [call_op.get_result(i) for i in range(call_op.get_num_results())]
|
|
1229
1330
|
return next(unflatten_ir_values(handles, [callee_ret_type]))
|
|
1230
1331
|
|
|
1231
|
-
def
|
|
1232
|
-
fn
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
return static_implementation(self, node)
|
|
1236
|
-
|
|
1237
|
-
kws = dict(self.visit(keyword) for keyword in node.keywords)
|
|
1238
|
-
args = [self.visit(arg) for arg in node.args]
|
|
1239
|
-
args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
|
|
1332
|
+
def call_Function(self, node, fn, args, kws):
|
|
1333
|
+
if isinstance(fn, (BoundJITMethod, BoundConstexprFunction)):
|
|
1334
|
+
args.insert(0, fn.__self__)
|
|
1335
|
+
fn = fn.__func__
|
|
1240
1336
|
if isinstance(fn, JITFunction):
|
|
1241
1337
|
_check_fn_args(node, fn, args)
|
|
1242
1338
|
return self.call_JitFunction(fn, args, kws)
|
|
1243
|
-
if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn)
|
|
1244
|
-
|
|
1245
|
-
|
|
1339
|
+
if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn) or isinstance(
|
|
1340
|
+
fn, ConstexprFunction):
|
|
1341
|
+
extra_kwargs = dict()
|
|
1342
|
+
|
|
1343
|
+
if isinstance(fn, ConstexprFunction):
|
|
1344
|
+
sig = inspect.signature(fn.__call__)
|
|
1345
|
+
else:
|
|
1346
|
+
sig = inspect.signature(fn)
|
|
1347
|
+
if '_semantic' in sig.parameters:
|
|
1348
|
+
extra_kwargs["_semantic"] = self.semantic
|
|
1246
1349
|
if '_generator' in sig.parameters:
|
|
1247
1350
|
extra_kwargs['_generator'] = self
|
|
1248
1351
|
try:
|
|
@@ -1252,43 +1355,125 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1252
1355
|
ret = language.tuple(ret)
|
|
1253
1356
|
return ret
|
|
1254
1357
|
except Exception as e:
|
|
1358
|
+
if knobs.compilation.front_end_debugging:
|
|
1359
|
+
raise
|
|
1255
1360
|
# Normally when we raise a CompilationError, we raise it as
|
|
1256
1361
|
# `from None`, because the original fileline from the exception
|
|
1257
1362
|
# is not relevant (and often points into code_generator.py
|
|
1258
1363
|
# itself). But when calling a function, we raise as `from e` to
|
|
1259
1364
|
# preserve the traceback of the original error, which may e.g.
|
|
1260
1365
|
# be in core.py.
|
|
1261
|
-
raise CompilationError(self.jit_fn.src, node,
|
|
1366
|
+
raise CompilationError(self.jit_fn.src, node, str(e)) from e
|
|
1262
1367
|
|
|
1263
1368
|
if fn in self.builtin_namespace.values():
|
|
1264
1369
|
args = map(_unwrap_if_constexpr, args)
|
|
1265
1370
|
ret = fn(*args, **kws)
|
|
1266
|
-
|
|
1371
|
+
|
|
1372
|
+
def wrap_constexpr(x):
|
|
1373
|
+
if _is_triton_value(x):
|
|
1374
|
+
return x
|
|
1375
|
+
return constexpr(x)
|
|
1376
|
+
|
|
1377
|
+
if isinstance(ret, (builtins.tuple, language.tuple)):
|
|
1378
|
+
return _apply_to_tuple_values(ret, wrap_constexpr)
|
|
1379
|
+
return wrap_constexpr(ret)
|
|
1380
|
+
|
|
1381
|
+
def call_Method(self, node, fn, fn_self, args, kws):
|
|
1382
|
+
if isinstance(fn, JITFunction):
|
|
1383
|
+
args.insert(0, fn_self)
|
|
1384
|
+
return self.call_Function(node, fn, args, kws)
|
|
1385
|
+
|
|
1386
|
+
def visit_Call(self, node):
|
|
1387
|
+
fn = _unwrap_if_constexpr(self.visit(node.func))
|
|
1388
|
+
if not isinstance(fn, BoundJITMethod):
|
|
1389
|
+
static_implementation = self.statically_implemented_functions.get(fn)
|
|
1390
|
+
if static_implementation is not None:
|
|
1391
|
+
return static_implementation(self, node)
|
|
1392
|
+
|
|
1393
|
+
mur = getattr(fn, '_must_use_result', False)
|
|
1394
|
+
if mur and getattr(node, '_is_unused', False):
|
|
1395
|
+
error_message = ["The result of %s is not being used." % ast.unparse(node.func)]
|
|
1396
|
+
if isinstance(mur, str):
|
|
1397
|
+
error_message.append(mur)
|
|
1398
|
+
raise CompilationError(self.jit_fn.src, node, " ".join(error_message))
|
|
1399
|
+
|
|
1400
|
+
kws = dict(self.visit(keyword) for keyword in node.keywords)
|
|
1401
|
+
args = [self.visit(arg) for arg in node.args]
|
|
1402
|
+
args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
|
|
1403
|
+
|
|
1404
|
+
return self.call_Function(node, fn, args, kws)
|
|
1267
1405
|
|
|
1268
1406
|
def visit_Constant(self, node):
|
|
1269
1407
|
return constexpr(node.value)
|
|
1270
1408
|
|
|
1271
1409
|
def visit_BoolOp(self, node: ast.BoolOp):
|
|
1272
|
-
if len(node.values) != 2:
|
|
1273
|
-
raise self._unsupported(
|
|
1274
|
-
node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
|
|
1275
|
-
lhs = self.visit(node.values[0])
|
|
1276
|
-
rhs = self.visit(node.values[1])
|
|
1277
1410
|
method_name = self._method_name_for_bool_op.get(type(node.op))
|
|
1278
1411
|
if method_name is None:
|
|
1279
1412
|
raise self._unsupported(
|
|
1280
1413
|
node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
|
1281
|
-
|
|
1414
|
+
|
|
1415
|
+
nontrivial_values = []
|
|
1416
|
+
|
|
1417
|
+
for subnode in node.values:
|
|
1418
|
+
# we visit the values in order, executing their side-effects
|
|
1419
|
+
# and possibly early-exiting:
|
|
1420
|
+
value = self.visit(subnode)
|
|
1421
|
+
if not _is_triton_tensor(value):
|
|
1422
|
+
# this is a constexpr, so we might be able to short-circuit:
|
|
1423
|
+
bv = bool(value)
|
|
1424
|
+
if (bv is False) and (method_name == "logical_and"):
|
|
1425
|
+
# value is falsey so return that:
|
|
1426
|
+
return value
|
|
1427
|
+
if (bv is True) and (method_name == "logical_or"):
|
|
1428
|
+
# value is truthy so return that:
|
|
1429
|
+
return value
|
|
1430
|
+
# otherwise, our constexpr has no effect on the output of the
|
|
1431
|
+
# expression so we do not append it to nontrivial_values.
|
|
1432
|
+
else:
|
|
1433
|
+
if value.type.is_block():
|
|
1434
|
+
lineno = getattr(node, "lineno", None)
|
|
1435
|
+
if lineno is not None:
|
|
1436
|
+
lineno += self.begin_line
|
|
1437
|
+
warnings.warn_explicit(
|
|
1438
|
+
"Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead",
|
|
1439
|
+
category=UserWarning,
|
|
1440
|
+
filename=self.file_name,
|
|
1441
|
+
lineno=lineno,
|
|
1442
|
+
source=ast.unparse(node),
|
|
1443
|
+
)
|
|
1444
|
+
# not a constexpr so we must append it:
|
|
1445
|
+
nontrivial_values.append(value)
|
|
1446
|
+
|
|
1447
|
+
if len(nontrivial_values) == 0:
|
|
1448
|
+
# the semantics of a disjunction of falsey values or conjunction
|
|
1449
|
+
# of truthy values is to return the final value:
|
|
1450
|
+
nontrivial_values.append(value)
|
|
1451
|
+
|
|
1452
|
+
while len(nontrivial_values) >= 2:
|
|
1453
|
+
rhs = nontrivial_values.pop()
|
|
1454
|
+
lhs = nontrivial_values.pop()
|
|
1455
|
+
res = self._apply_binary_method(method_name, lhs, rhs)
|
|
1456
|
+
nontrivial_values.append(res)
|
|
1457
|
+
|
|
1458
|
+
assert len(nontrivial_values) == 1
|
|
1459
|
+
return nontrivial_values[0]
|
|
1282
1460
|
|
|
1283
1461
|
_method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
|
|
1284
1462
|
|
|
1285
1463
|
def visit_Attribute(self, node):
|
|
1286
1464
|
lhs = self.visit(node.value)
|
|
1287
1465
|
if _is_triton_tensor(lhs) and node.attr == "T":
|
|
1288
|
-
return semantic.permute(lhs, (1, 0)
|
|
1289
|
-
|
|
1466
|
+
return self.semantic.permute(lhs, (1, 0))
|
|
1467
|
+
# NOTE: special case ".value" for BC
|
|
1468
|
+
if isinstance(lhs, constexpr) and node.attr not in ("value", "type"):
|
|
1469
|
+
lhs = lhs.value
|
|
1470
|
+
attr = getattr(lhs, node.attr)
|
|
1471
|
+
if _is_triton_value(lhs) and isinstance(attr, JITFunction):
|
|
1472
|
+
return BoundJITMethod(lhs, attr)
|
|
1473
|
+
return attr
|
|
1290
1474
|
|
|
1291
1475
|
def visit_Expr(self, node):
|
|
1476
|
+
node.value._is_unused = True
|
|
1292
1477
|
ast.NodeVisitor.generic_visit(self, node)
|
|
1293
1478
|
|
|
1294
1479
|
def visit_NoneType(self, node):
|
|
@@ -1324,13 +1509,19 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1324
1509
|
last_loc = self.builder.get_loc()
|
|
1325
1510
|
self.cur_node = node
|
|
1326
1511
|
if hasattr(node, 'lineno') and hasattr(node, 'col_offset'):
|
|
1327
|
-
self.builder.
|
|
1512
|
+
here_loc = self.builder.create_loc(self.file_name, self.begin_line + node.lineno, node.col_offset)
|
|
1513
|
+
if self.name_loc_as_prefix is not None:
|
|
1514
|
+
self.builder.set_loc(self.builder.create_name_loc(self.name_loc_as_prefix, here_loc))
|
|
1515
|
+
else:
|
|
1516
|
+
self.builder.set_loc(here_loc)
|
|
1328
1517
|
last_loc = self.builder.get_loc()
|
|
1329
1518
|
try:
|
|
1330
1519
|
ret = super().visit(node)
|
|
1331
1520
|
except CompilationError:
|
|
1332
1521
|
raise
|
|
1333
1522
|
except Exception as e:
|
|
1523
|
+
if knobs.compilation.front_end_debugging:
|
|
1524
|
+
raise
|
|
1334
1525
|
# Wrap the error in a CompilationError which contains the source
|
|
1335
1526
|
# of the @jit function.
|
|
1336
1527
|
raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None
|
|
@@ -1378,16 +1569,29 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1378
1569
|
|
|
1379
1570
|
return ret
|
|
1380
1571
|
|
|
1572
|
+
from ..experimental.gluon import language as ttgl
|
|
1381
1573
|
statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {
|
|
1382
1574
|
language.core.static_assert: execute_static_assert,
|
|
1383
1575
|
language.core.static_print: static_executor(print),
|
|
1576
|
+
ttgl.static_assert: execute_static_assert,
|
|
1577
|
+
ttgl.static_print: static_executor(print),
|
|
1384
1578
|
int: static_executor(int),
|
|
1385
1579
|
len: static_executor(len),
|
|
1386
1580
|
}
|
|
1387
1581
|
|
|
1388
1582
|
|
|
1389
|
-
def ast_to_ttir(fn, src, context, options, codegen_fns, module_map):
|
|
1390
|
-
arg_types =
|
|
1583
|
+
def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None):
|
|
1584
|
+
arg_types = [None] * len(fn.arg_names)
|
|
1585
|
+
const_iter = iter(src.constants.items())
|
|
1586
|
+
kc, vc = next(const_iter, (None, None))
|
|
1587
|
+
|
|
1588
|
+
for i, (ks, v) in enumerate(src.signature.items()):
|
|
1589
|
+
idx = fn.arg_names.index(ks)
|
|
1590
|
+
cexpr = None
|
|
1591
|
+
if kc is not None and kc[0] == i:
|
|
1592
|
+
cexpr = vc
|
|
1593
|
+
kc, vc = next(const_iter, (None, None))
|
|
1594
|
+
arg_types[idx] = str_to_ty(v, cexpr)
|
|
1391
1595
|
prototype = ASTFunction([], arg_types, src.constants, src.attrs)
|
|
1392
1596
|
file_name, begin_line = get_jit_fn_file_line(fn)
|
|
1393
1597
|
# query function representation
|
|
@@ -1396,11 +1600,15 @@ def ast_to_ttir(fn, src, context, options, codegen_fns, module_map):
|
|
|
1396
1600
|
constants = {fn.arg_names[i[0]]: src.constants[i] for i in leaves}
|
|
1397
1601
|
signature = src.signature
|
|
1398
1602
|
proxy = namedtuple("SpecializationProxy", ["constants", "signature"])(constants, signature)
|
|
1399
|
-
generator = CodeGenerator(context, prototype, gscope=fn.
|
|
1400
|
-
is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
|
|
1401
|
-
codegen_fns=codegen_fns, module_map=module_map)
|
|
1603
|
+
generator = CodeGenerator(context, prototype, gscope=fn.get_capture_scope(), function_name=fn.repr(proxy),
|
|
1604
|
+
jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
|
|
1605
|
+
codegen_fns=codegen_fns, module_map=module_map, module=module, is_gluon=fn.is_gluon())
|
|
1402
1606
|
generator.visit(fn.parse())
|
|
1403
|
-
|
|
1607
|
+
module = generator.module
|
|
1404
1608
|
# module takes ownership of the context
|
|
1405
|
-
|
|
1406
|
-
|
|
1609
|
+
module.context = context
|
|
1610
|
+
if not module.verify_with_diagnostics():
|
|
1611
|
+
if not fn.is_gluon():
|
|
1612
|
+
print(module)
|
|
1613
|
+
raise RuntimeError("error encountered during parsing")
|
|
1614
|
+
return module
|