triton-windows 3.3.0.post19__cp39-cp39-win_amd64.whl → 3.4.0.post20__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +4 -1
- triton/_filecheck.py +87 -0
- triton/_internal_testing.py +26 -15
- triton/_utils.py +110 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +112 -78
- triton/backends/amd/driver.c +5 -2
- triton/backends/amd/driver.py +149 -47
- triton/backends/compiler.py +7 -21
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +92 -93
- triton/backends/nvidia/driver.c +90 -98
- triton/backends/nvidia/driver.py +303 -128
- triton/compiler/code_generator.py +212 -111
- triton/compiler/compiler.py +110 -25
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +4 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +99 -0
- triton/experimental/gluon/language/__init__.py +18 -0
- triton/experimental/gluon/language/_core.py +312 -0
- triton/experimental/gluon/language/_layouts.py +230 -0
- triton/experimental/gluon/language/_math.py +12 -0
- triton/experimental/gluon/language/_semantic.py +287 -0
- triton/experimental/gluon/language/_standard.py +47 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +40 -0
- triton/knobs.py +481 -0
- triton/language/__init__.py +39 -14
- triton/language/core.py +794 -537
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/libdevice.py +113 -104
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1706 -1770
- triton/language/standard.py +116 -51
- triton/runtime/autotuner.py +117 -59
- triton/runtime/build.py +76 -12
- triton/runtime/cache.py +18 -47
- triton/runtime/driver.py +32 -29
- triton/runtime/interpreter.py +72 -35
- triton/runtime/jit.py +146 -110
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/testing.py +16 -12
- triton/tools/disasm.py +3 -4
- triton/tools/tensor_descriptor.py +36 -0
- triton/windows_utils.py +14 -6
- {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
- triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
- {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +1 -1
- triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
- triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.0.post19.dist-info/RECORD +0 -253
- triton_windows-3.3.0.post19.dist-info/top_level.txt +0 -14
|
@@ -1,18 +1,19 @@
|
|
|
1
1
|
import ast
|
|
2
|
+
import copy
|
|
2
3
|
import inspect
|
|
3
4
|
import re
|
|
4
5
|
import warnings
|
|
5
|
-
import os
|
|
6
6
|
import textwrap
|
|
7
7
|
import itertools
|
|
8
|
+
from dataclasses import dataclass
|
|
8
9
|
from types import ModuleType
|
|
9
10
|
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, Iterable, List
|
|
10
11
|
|
|
11
|
-
from .. import language
|
|
12
|
-
from .._C.libtriton import ir
|
|
13
|
-
from ..language import constexpr,
|
|
14
|
-
from ..language.core import _unwrap_if_constexpr,
|
|
15
|
-
from ..runtime.jit import get_jit_fn_file_line
|
|
12
|
+
from .. import knobs, language
|
|
13
|
+
from .._C.libtriton import ir, gluon_ir
|
|
14
|
+
from ..language import constexpr, str_to_ty, tensor
|
|
15
|
+
from ..language.core import _unwrap_if_constexpr, base_value, base_type
|
|
16
|
+
from ..runtime.jit import get_jit_fn_file_line, get_full_name
|
|
16
17
|
# ideally we wouldn't need any runtime component
|
|
17
18
|
from ..runtime import JITFunction
|
|
18
19
|
from .._utils import find_paths_if, get_iterable_path, set_iterable_path
|
|
@@ -27,29 +28,9 @@ def check_identifier_legality(name, type):
|
|
|
27
28
|
return name
|
|
28
29
|
|
|
29
30
|
|
|
30
|
-
def mangle_ty(ty):
|
|
31
|
-
if ty.is_tuple():
|
|
32
|
-
return 'T' + '_'.join(map(mangle_ty, ty.types)) + 'T'
|
|
33
|
-
if ty.is_ptr():
|
|
34
|
-
return 'P' + mangle_ty(ty.element_ty)
|
|
35
|
-
if ty.is_int():
|
|
36
|
-
SIGNED = language.dtype.SIGNEDNESS.SIGNED
|
|
37
|
-
prefix = 'i' if ty.int_signedness == SIGNED else 'u'
|
|
38
|
-
return prefix + str(ty.int_bitwidth)
|
|
39
|
-
if ty.is_floating():
|
|
40
|
-
return str(ty)
|
|
41
|
-
if ty.is_block():
|
|
42
|
-
elt = mangle_ty(ty.scalar)
|
|
43
|
-
shape = '_'.join(map(str, ty.shape))
|
|
44
|
-
return f'{elt}S{shape}S'
|
|
45
|
-
if ty.is_void():
|
|
46
|
-
return 'V'
|
|
47
|
-
raise TypeError(f'Unsupported type {ty}')
|
|
48
|
-
|
|
49
|
-
|
|
50
31
|
def mangle_fn(name, arg_tys, constants):
|
|
51
32
|
# doesn't mangle ret type, which must be a function of arg tys
|
|
52
|
-
mangled_arg_names = '_'.join([
|
|
33
|
+
mangled_arg_names = '_'.join([ty.mangle() for ty in arg_tys])
|
|
53
34
|
mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)])
|
|
54
35
|
mangled_constants = mangled_constants.replace('.', '_d_')
|
|
55
36
|
mangled_constants = mangled_constants.replace("'", '_sq_')
|
|
@@ -68,11 +49,11 @@ def _is_triton_tensor(o: Any) -> bool:
|
|
|
68
49
|
|
|
69
50
|
|
|
70
51
|
def _is_constexpr(o: Any) -> bool:
|
|
71
|
-
return o is None or isinstance(o, (constexpr, language.core.dtype))
|
|
52
|
+
return o is None or isinstance(o, (constexpr, language.core.dtype, JITFunction))
|
|
72
53
|
|
|
73
54
|
|
|
74
|
-
def
|
|
75
|
-
return _is_triton_tensor(o) and (
|
|
55
|
+
def _is_non_scalar_tensor(o: Any) -> bool:
|
|
56
|
+
return _is_triton_tensor(o) and (o.type.is_block() and o.type.numel != 1)
|
|
76
57
|
|
|
77
58
|
|
|
78
59
|
def _is_list_like(o: Any) -> bool:
|
|
@@ -82,7 +63,7 @@ def _is_list_like(o: Any) -> bool:
|
|
|
82
63
|
def _check_fn_args(node, fn, args):
|
|
83
64
|
if fn.noinline:
|
|
84
65
|
for idx, arg in enumerate(args):
|
|
85
|
-
if not _is_constexpr(arg) and
|
|
66
|
+
if not _is_constexpr(arg) and _is_non_scalar_tensor(arg):
|
|
86
67
|
raise UnsupportedLanguageConstruct(
|
|
87
68
|
fn.src, node,
|
|
88
69
|
f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
|
|
@@ -102,6 +83,7 @@ def _apply_to_tuple_values(value, fn):
|
|
|
102
83
|
assert False, f"Unsupported type {type(value)}"
|
|
103
84
|
|
|
104
85
|
vals = [fn(v) for v in value]
|
|
86
|
+
vals = [constexpr(v) if v is None else v for v in vals]
|
|
105
87
|
types = [v.type for v in vals]
|
|
106
88
|
return language.tuple(vals, language.tuple_type(types, fields))
|
|
107
89
|
|
|
@@ -154,10 +136,9 @@ class ContainsReturnChecker(ast.NodeVisitor):
|
|
|
154
136
|
return any(self.visit(s) for s in body)
|
|
155
137
|
|
|
156
138
|
def _visit_function(self, fn) -> bool:
|
|
157
|
-
#
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
return ContainsReturnChecker(self.gscope).visit(fn_node)
|
|
139
|
+
# no need to check within the function as it won't cause an early return.
|
|
140
|
+
# If the function itself has unstructured control flow we may not be able to inline it causing poor performance.
|
|
141
|
+
# We should check for this and fail or emit a warning.
|
|
161
142
|
return False
|
|
162
143
|
|
|
163
144
|
def generic_visit(self, node) -> bool:
|
|
@@ -241,26 +222,26 @@ class ASTFunction:
|
|
|
241
222
|
self.constants = constants
|
|
242
223
|
self.attrs = attrs
|
|
243
224
|
|
|
244
|
-
def
|
|
245
|
-
|
|
246
|
-
for
|
|
247
|
-
if
|
|
225
|
+
def flatten_ir_types(self, builder: ir.builder, types: List[base_type]) -> List[ir.type]:
|
|
226
|
+
ir_types = []
|
|
227
|
+
for ty in types:
|
|
228
|
+
if ty is None:
|
|
248
229
|
continue
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
return ret_types
|
|
230
|
+
ty._flatten_ir_types(builder, ir_types)
|
|
231
|
+
return ir_types
|
|
232
|
+
|
|
233
|
+
def return_types_ir(self, builder: ir.builder) -> List[ir.type]:
|
|
234
|
+
return self.flatten_ir_types(builder, self.ret_types)
|
|
255
235
|
|
|
256
236
|
def serialize(self, builder: ir.builder):
|
|
257
237
|
# fill up IR values in template
|
|
258
238
|
# > build function
|
|
259
239
|
is_val = lambda path, _: path not in self.constants and _ is not None
|
|
260
240
|
val_paths = list(find_paths_if(self.arg_types, is_val))
|
|
261
|
-
arg_types = [get_iterable_path(self.arg_types, path)
|
|
262
|
-
|
|
263
|
-
|
|
241
|
+
arg_types = [get_iterable_path(self.arg_types, path) for path in val_paths]
|
|
242
|
+
arg_types_ir = self.flatten_ir_types(builder, arg_types)
|
|
243
|
+
ret_types_ir = self.return_types_ir(builder)
|
|
244
|
+
return builder.get_function_ty(arg_types_ir, ret_types_ir)
|
|
264
245
|
|
|
265
246
|
def deserialize(self, fn):
|
|
266
247
|
# create "template"
|
|
@@ -272,19 +253,18 @@ class ASTFunction:
|
|
|
272
253
|
vals = make_template(self.arg_types)
|
|
273
254
|
is_val = lambda path, _: path not in self.constants and _ is not None
|
|
274
255
|
val_paths = list(find_paths_if(self.arg_types, is_val))
|
|
275
|
-
# > set attributes
|
|
276
|
-
for attr_path, attr_specs in self.attrs.items():
|
|
277
|
-
for attr_name, attr_val in attr_specs:
|
|
278
|
-
if attr_path in val_paths:
|
|
279
|
-
fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val)
|
|
280
|
-
for i, path in enumerate(val_paths):
|
|
281
|
-
ty = get_iterable_path(self.arg_types, path)
|
|
282
|
-
if isinstance(ty, nv_tma_desc_type):
|
|
283
|
-
fn.set_arg_attr(i, "tt.nv_tma_desc", 1)
|
|
284
256
|
# > add IR values to the template
|
|
285
|
-
|
|
257
|
+
cursor = 0
|
|
258
|
+
handles = [fn.args(i) for i in range(fn.get_num_args())]
|
|
259
|
+
for path in val_paths:
|
|
286
260
|
ty = get_iterable_path(self.arg_types, path)
|
|
287
|
-
|
|
261
|
+
# > set attributes
|
|
262
|
+
attr_specs = self.attrs.get(path, [])
|
|
263
|
+
for attr_name, attr_val in attr_specs:
|
|
264
|
+
fn.set_arg_attr(cursor, attr_name, attr_val)
|
|
265
|
+
# > build frontend value
|
|
266
|
+
val, cursor = ty._unflatten_ir(handles, cursor)
|
|
267
|
+
set_iterable_path(vals, path, val)
|
|
288
268
|
# > add constexpr values to the template
|
|
289
269
|
constants = self.constants
|
|
290
270
|
for path, val in constants.items():
|
|
@@ -292,13 +272,26 @@ class ASTFunction:
|
|
|
292
272
|
return vals
|
|
293
273
|
|
|
294
274
|
|
|
275
|
+
@dataclass(frozen=True)
|
|
276
|
+
class BoundJITMethod:
|
|
277
|
+
__self__: base_value
|
|
278
|
+
__func__: JITFunction
|
|
279
|
+
|
|
280
|
+
|
|
295
281
|
class CodeGenerator(ast.NodeVisitor):
|
|
296
282
|
|
|
297
283
|
def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns, module_map,
|
|
298
284
|
module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False,
|
|
299
285
|
file_name: Optional[str] = None, begin_line=0):
|
|
300
286
|
self.context = context
|
|
301
|
-
|
|
287
|
+
if jit_fn.is_gluon():
|
|
288
|
+
from triton.experimental.gluon.language._semantic import GluonSemantic
|
|
289
|
+
self.builder = gluon_ir.GluonOpBuilder(context)
|
|
290
|
+
self.semantic = GluonSemantic(self.builder)
|
|
291
|
+
else:
|
|
292
|
+
from triton.language.semantic import TritonSemantic
|
|
293
|
+
self.builder = ir.builder(context)
|
|
294
|
+
self.semantic = TritonSemantic(self.builder)
|
|
302
295
|
self.file_name = file_name
|
|
303
296
|
# node.lineno starts from 1, so we need to subtract 1
|
|
304
297
|
self.begin_line = begin_line - 1
|
|
@@ -306,7 +299,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
306
299
|
self.builder.options = options
|
|
307
300
|
# dict of functions provided by the backend. Below are the list of possible functions:
|
|
308
301
|
# Convert custom types not natively supported on HW.
|
|
309
|
-
# convert_custom_types(
|
|
302
|
+
# convert_custom_types(input_tensor, dtype, fp_downcast_rounding=None, _builder=None)
|
|
310
303
|
self.builder.codegen_fns = codegen_fns
|
|
311
304
|
self.builder.module_map = {} if module_map is None else module_map
|
|
312
305
|
self.module = self.builder.create_module() if module is None else module
|
|
@@ -329,6 +322,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
329
322
|
self.jit_fn = jit_fn
|
|
330
323
|
# TODO: we currently generate illegal names for non-kernel functions involving constexprs!
|
|
331
324
|
if is_kernel:
|
|
325
|
+
function_name = function_name[function_name.rfind('.') + 1:]
|
|
332
326
|
function_name = check_identifier_legality(function_name, "function")
|
|
333
327
|
self.function_name = function_name
|
|
334
328
|
self.is_kernel = is_kernel
|
|
@@ -345,7 +339,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
345
339
|
# special handling.
|
|
346
340
|
self.visiting_arg_default_value = False
|
|
347
341
|
|
|
348
|
-
builtin_namespace: Dict[str, Any] = {
|
|
342
|
+
builtin_namespace: Dict[str, Any] = {
|
|
343
|
+
_.__name__: _
|
|
344
|
+
for _ in (len, list, range, float, int, isinstance, getattr, hasattr)
|
|
345
|
+
}
|
|
349
346
|
builtin_namespace.update((
|
|
350
347
|
('print', language.core.device_print),
|
|
351
348
|
('min', language.minimum),
|
|
@@ -378,11 +375,14 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
378
375
|
# But actually a bunch of other things, such as module imports, are
|
|
379
376
|
# technically Python globals. We have to allow these too!
|
|
380
377
|
if any([
|
|
381
|
-
val is absent,
|
|
378
|
+
val is absent,
|
|
379
|
+
name in self.builtin_namespace, #
|
|
382
380
|
type(val) is ModuleType, #
|
|
383
381
|
isinstance(val, JITFunction), #
|
|
384
382
|
getattr(val, "__triton_builtin__", False), #
|
|
383
|
+
getattr(val, "__triton_aggregate__", False), #
|
|
385
384
|
getattr(val, "__module__", "").startswith("triton.language"), #
|
|
385
|
+
getattr(val, "__module__", "").startswith("triton.experimental.gluon.language"), #
|
|
386
386
|
isinstance(val, language.dtype), #
|
|
387
387
|
_is_namedtuple(val),
|
|
388
388
|
self._is_constexpr_global(name), #
|
|
@@ -390,7 +390,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
390
390
|
# because you should be able to do
|
|
391
391
|
# @triton.jit def fn(x: tl.constexpr = GLOBAL): ...
|
|
392
392
|
self.visiting_arg_default_value, #
|
|
393
|
-
|
|
393
|
+
knobs.compilation.allow_non_constexpr_globals,
|
|
394
394
|
]):
|
|
395
395
|
return val
|
|
396
396
|
raise NameError(
|
|
@@ -467,7 +467,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
467
467
|
if isinstance(value, language.tuple):
|
|
468
468
|
return _apply_to_tuple_values(value, decay)
|
|
469
469
|
elif isinstance(value, (language.constexpr, int, float)):
|
|
470
|
-
return semantic.to_tensor(value
|
|
470
|
+
return self.semantic.to_tensor(value)
|
|
471
471
|
return value
|
|
472
472
|
|
|
473
473
|
ret_value = decay(ret_value)
|
|
@@ -575,13 +575,16 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
575
575
|
return self.visit_Assign(node)
|
|
576
576
|
|
|
577
577
|
def assignTarget(self, target, value):
|
|
578
|
+
assert isinstance(target.ctx, ast.Store)
|
|
578
579
|
if isinstance(target, ast.Subscript):
|
|
579
|
-
assert target.ctx.__class__.__name__ == "Store"
|
|
580
580
|
return self.visit_Subscript_Store(target, value)
|
|
581
581
|
if isinstance(target, ast.Tuple):
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
582
|
+
for i, target in enumerate(target.elts):
|
|
583
|
+
self.assignTarget(target, value.values[i])
|
|
584
|
+
return
|
|
585
|
+
if isinstance(target, ast.Attribute):
|
|
586
|
+
base = self.visit(target.value)
|
|
587
|
+
setattr(base, target.attr, value)
|
|
585
588
|
return
|
|
586
589
|
assert isinstance(target, ast.Name)
|
|
587
590
|
self.set_value(self.visit(target), value)
|
|
@@ -596,7 +599,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
596
599
|
if value is not None and \
|
|
597
600
|
not _is_triton_value(value) and \
|
|
598
601
|
not isinstance(value, native_nontensor_types):
|
|
599
|
-
value = semantic.to_tensor(value
|
|
602
|
+
value = self.semantic.to_tensor(value)
|
|
600
603
|
return value
|
|
601
604
|
|
|
602
605
|
values = _sanitize_value(self.visit(node.value))
|
|
@@ -605,12 +608,12 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
605
608
|
self.assignTarget(targets[0], values)
|
|
606
609
|
|
|
607
610
|
def visit_AugAssign(self, node):
|
|
608
|
-
|
|
609
|
-
lhs = ast.
|
|
611
|
+
lhs = copy.deepcopy(node.target)
|
|
612
|
+
lhs.ctx = ast.Load()
|
|
610
613
|
rhs = ast.BinOp(lhs, node.op, node.value)
|
|
611
614
|
assign = ast.Assign(targets=[node.target], value=rhs)
|
|
612
615
|
self.visit(assign)
|
|
613
|
-
return self.
|
|
616
|
+
return self.visit(lhs)
|
|
614
617
|
|
|
615
618
|
def visit_Name(self, node):
|
|
616
619
|
if type(node.ctx) is ast.Store:
|
|
@@ -630,10 +633,12 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
630
633
|
def _apply_binary_method(self, method_name, lhs, rhs):
|
|
631
634
|
# TODO: raise something meaningful if getattr fails below, esp for reverse method
|
|
632
635
|
if _is_triton_tensor(lhs):
|
|
633
|
-
return getattr(lhs, method_name)(rhs,
|
|
636
|
+
return getattr(lhs, method_name)(rhs, _semantic=self.semantic)
|
|
634
637
|
if _is_triton_tensor(rhs):
|
|
635
638
|
reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name)
|
|
636
|
-
return getattr(rhs, reverse_method_name)(lhs,
|
|
639
|
+
return getattr(rhs, reverse_method_name)(lhs, _semantic=self.semantic)
|
|
640
|
+
if not isinstance(lhs, (constexpr, language.tuple)) and isinstance(rhs, constexpr):
|
|
641
|
+
lhs = constexpr(lhs)
|
|
637
642
|
return getattr(lhs, method_name)(rhs)
|
|
638
643
|
|
|
639
644
|
def visit_BinOp(self, node):
|
|
@@ -786,7 +791,14 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
786
791
|
cond = self.visit(node.test)
|
|
787
792
|
|
|
788
793
|
if _is_triton_tensor(cond):
|
|
789
|
-
|
|
794
|
+
if _is_non_scalar_tensor(cond):
|
|
795
|
+
raise self._unsupported(node, "Boolean value of Tensor with more than one value is ambiguous")
|
|
796
|
+
if cond.type.is_block():
|
|
797
|
+
warnings.warn(
|
|
798
|
+
"If conditional called with multidimensional Tensor instead of scalar; please use \"if (%s).item()\" instead"
|
|
799
|
+
% ast.unparse(node.test))
|
|
800
|
+
cond = language.core._unsplat(cond, _semantic=self.semantic, _generator=self)
|
|
801
|
+
cond = cond.to(language.int1, _semantic=self.semantic)
|
|
790
802
|
contains_return = ContainsReturnChecker(self.gscope).visit(node)
|
|
791
803
|
if contains_return:
|
|
792
804
|
if self.scf_stack:
|
|
@@ -812,21 +824,21 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
812
824
|
def visit_IfExp(self, node):
|
|
813
825
|
cond = self.visit(node.test)
|
|
814
826
|
if _is_triton_tensor(cond):
|
|
815
|
-
cond = cond.to(language.int1,
|
|
827
|
+
cond = cond.to(language.int1, _semantic=self.semantic)
|
|
816
828
|
# TODO: Deal w/ more complicated return types (e.g tuple)
|
|
817
829
|
with enter_sub_region(self):
|
|
818
830
|
ip, last_loc = self._get_insertion_point_and_loc()
|
|
819
831
|
|
|
820
832
|
then_block = self.builder.create_block()
|
|
821
833
|
self.builder.set_insertion_point_to_start(then_block)
|
|
822
|
-
then_val = semantic.to_tensor(self.visit(node.body)
|
|
834
|
+
then_val = self.semantic.to_tensor(self.visit(node.body))
|
|
823
835
|
then_block = self.builder.get_insertion_block()
|
|
824
836
|
|
|
825
837
|
else_block = self.builder.create_block()
|
|
826
838
|
self.builder.set_insertion_point_to_start(else_block)
|
|
827
839
|
# do not need to reset lscope since
|
|
828
840
|
# ternary expressions cannot define new variables
|
|
829
|
-
else_val = semantic.to_tensor(self.visit(node.orelse)
|
|
841
|
+
else_val = self.semantic.to_tensor(self.visit(node.orelse))
|
|
830
842
|
else_block = self.builder.get_insertion_block()
|
|
831
843
|
|
|
832
844
|
self._set_insertion_point_and_loc(ip, last_loc)
|
|
@@ -892,10 +904,12 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
892
904
|
if fn is None:
|
|
893
905
|
raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.")
|
|
894
906
|
if _is_triton_tensor(operand):
|
|
895
|
-
return getattr(operand, fn)(
|
|
907
|
+
return getattr(operand, fn)(_semantic=self.semantic)
|
|
896
908
|
try:
|
|
897
909
|
return getattr(operand, fn)()
|
|
898
910
|
except AttributeError:
|
|
911
|
+
if fn == "__not__":
|
|
912
|
+
return constexpr(not operand)
|
|
899
913
|
raise self._unsupported(
|
|
900
914
|
node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}")
|
|
901
915
|
|
|
@@ -912,6 +926,20 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
912
926
|
f'but is re-assigned to {loop_val.type} in loop! '\
|
|
913
927
|
f'Please make sure that the type stays consistent.'
|
|
914
928
|
|
|
929
|
+
def visit_withitem(self, node):
|
|
930
|
+
return self.visit(node.context_expr)
|
|
931
|
+
|
|
932
|
+
def visit_With(self, node):
|
|
933
|
+
assert len(node.items) == 1
|
|
934
|
+
context = node.items[0].context_expr
|
|
935
|
+
withitemClass = self.visit(context.func)
|
|
936
|
+
if withitemClass == language.async_task:
|
|
937
|
+
args = [self.visit(arg) for arg in context.args]
|
|
938
|
+
with withitemClass(*args, _builder=self.builder):
|
|
939
|
+
self.visit_compound_statement(node.body)
|
|
940
|
+
else:
|
|
941
|
+
self.visit_compound_statement(node.body)
|
|
942
|
+
|
|
915
943
|
def visit_While(self, node):
|
|
916
944
|
with enter_sub_region(self) as sr:
|
|
917
945
|
liveins, insert_block = sr
|
|
@@ -991,15 +1019,15 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
991
1019
|
ast.NodeVisitor.generic_visit(self, stmt)
|
|
992
1020
|
|
|
993
1021
|
def visit_Subscript_Load(self, node):
|
|
994
|
-
assert node.ctx.
|
|
1022
|
+
assert isinstance(node.ctx, ast.Load)
|
|
995
1023
|
lhs = self.visit(node.value)
|
|
996
1024
|
slices = self.visit(node.slice)
|
|
997
1025
|
if _is_triton_tensor(lhs):
|
|
998
|
-
return lhs.__getitem__(slices,
|
|
1026
|
+
return lhs.__getitem__(slices, _semantic=self.semantic)
|
|
999
1027
|
return lhs[slices]
|
|
1000
1028
|
|
|
1001
1029
|
def visit_Subscript_Store(self, node, value):
|
|
1002
|
-
assert node.ctx.
|
|
1030
|
+
assert isinstance(node.ctx, ast.Store)
|
|
1003
1031
|
lhs = self.visit(node.value)
|
|
1004
1032
|
slices = self.visit(node.slice)
|
|
1005
1033
|
assert isinstance(lhs, language.tuple)
|
|
@@ -1028,6 +1056,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1028
1056
|
loop_unroll_factor = None
|
|
1029
1057
|
disallow_acc_multi_buffer = False
|
|
1030
1058
|
flatten = False
|
|
1059
|
+
warp_specialize = False
|
|
1031
1060
|
if IteratorClass is language.range:
|
|
1032
1061
|
iterator = IteratorClass(*iter_args, **iter_kwargs)
|
|
1033
1062
|
# visit iterator arguments
|
|
@@ -1040,6 +1069,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1040
1069
|
loop_unroll_factor = iterator.loop_unroll_factor
|
|
1041
1070
|
disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer
|
|
1042
1071
|
flatten = iterator.flatten
|
|
1072
|
+
warp_specialize = iterator.warp_specialize
|
|
1043
1073
|
elif IteratorClass is range:
|
|
1044
1074
|
# visit iterator arguments
|
|
1045
1075
|
# note: only `range` iterator is supported now
|
|
@@ -1055,14 +1085,14 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1055
1085
|
step = constexpr(-step.value)
|
|
1056
1086
|
negative_step = True
|
|
1057
1087
|
lb, ub = ub, lb
|
|
1058
|
-
lb = semantic.to_tensor(lb
|
|
1059
|
-
ub = semantic.to_tensor(ub
|
|
1060
|
-
step = semantic.to_tensor(step
|
|
1088
|
+
lb = self.semantic.to_tensor(lb)
|
|
1089
|
+
ub = self.semantic.to_tensor(ub)
|
|
1090
|
+
step = self.semantic.to_tensor(step)
|
|
1061
1091
|
# induction variable type
|
|
1062
1092
|
if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int():
|
|
1063
1093
|
raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})")
|
|
1064
|
-
iv_type = semantic.integer_promote_impl(lb.dtype, ub.dtype)
|
|
1065
|
-
iv_type = semantic.integer_promote_impl(iv_type, step.dtype)
|
|
1094
|
+
iv_type = self.semantic.integer_promote_impl(lb.dtype, ub.dtype)
|
|
1095
|
+
iv_type = self.semantic.integer_promote_impl(iv_type, step.dtype)
|
|
1066
1096
|
iv_ir_type = iv_type.to_ir(self.builder)
|
|
1067
1097
|
iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED
|
|
1068
1098
|
# lb/ub/step might be constexpr, we need to cast them to tensor
|
|
@@ -1118,6 +1148,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1118
1148
|
for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr())
|
|
1119
1149
|
if flatten:
|
|
1120
1150
|
for_op.set_attr("tt.flatten", self.builder.get_unit_attr())
|
|
1151
|
+
if warp_specialize:
|
|
1152
|
+
for_op.set_attr("tt.warp_specialize", self.builder.get_unit_attr())
|
|
1121
1153
|
|
|
1122
1154
|
self.scf_stack.append(node)
|
|
1123
1155
|
for_op_body = for_op.get_body(0)
|
|
@@ -1136,7 +1168,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1136
1168
|
if name in liveins:
|
|
1137
1169
|
local = self.local_defs[name]
|
|
1138
1170
|
if isinstance(local, constexpr):
|
|
1139
|
-
local = semantic.to_tensor(local
|
|
1171
|
+
local = self.semantic.to_tensor(local)
|
|
1140
1172
|
yields.append(local)
|
|
1141
1173
|
|
|
1142
1174
|
# create YieldOp
|
|
@@ -1180,7 +1212,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1180
1212
|
def visit_Assert(self, node) -> Any:
|
|
1181
1213
|
test = self.visit(node.test)
|
|
1182
1214
|
msg = self.visit(node.msg) if node.msg is not None else ""
|
|
1183
|
-
return language.core.device_assert(test, msg,
|
|
1215
|
+
return language.core.device_assert(test, msg, _semantic=self.semantic)
|
|
1184
1216
|
|
|
1185
1217
|
def call_JitFunction(self, fn: JITFunction, args, kwargs):
|
|
1186
1218
|
args = inspect.getcallargs(fn.fn, *args, **kwargs)
|
|
@@ -1193,10 +1225,9 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1193
1225
|
args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
|
|
1194
1226
|
args_val = [get_iterable_path(args, path) for path in args_path]
|
|
1195
1227
|
# mangle
|
|
1196
|
-
fn_name = mangle_fn(fn
|
|
1228
|
+
fn_name = mangle_fn(get_full_name(fn), [arg.type for arg in args_val], args_cst)
|
|
1197
1229
|
# generate function def if necessary
|
|
1198
1230
|
if not self.module.has_function(fn_name):
|
|
1199
|
-
gscope = fn.__globals__
|
|
1200
1231
|
# If the callee is not set, we use the same debug setting as the caller
|
|
1201
1232
|
file_name, begin_line = get_jit_fn_file_line(fn)
|
|
1202
1233
|
arg_types = [
|
|
@@ -1205,7 +1236,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1205
1236
|
for arg in args
|
|
1206
1237
|
]
|
|
1207
1238
|
prototype = ASTFunction([], arg_types, args_cst, dict())
|
|
1208
|
-
generator = CodeGenerator(self.context, prototype,
|
|
1239
|
+
generator = CodeGenerator(self.context, prototype, fn.get_capture_scope(), module=self.module, jit_fn=fn,
|
|
1209
1240
|
function_name=fn_name, function_types=self.function_ret_types,
|
|
1210
1241
|
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
|
|
1211
1242
|
options=self.builder.options, codegen_fns=self.builder.codegen_fns,
|
|
@@ -1214,6 +1245,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1214
1245
|
generator.visit(fn.parse())
|
|
1215
1246
|
except Exception as e:
|
|
1216
1247
|
# Wrap the error in the callee with the location of the call.
|
|
1248
|
+
if knobs.compilation.front_end_debugging:
|
|
1249
|
+
raise
|
|
1217
1250
|
raise CompilationError(self.jit_fn.src, self.cur_node, None) from e
|
|
1218
1251
|
|
|
1219
1252
|
callee_ret_type = generator.ret_type
|
|
@@ -1221,7 +1254,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1221
1254
|
else:
|
|
1222
1255
|
callee_ret_type = self.function_ret_types[fn_name]
|
|
1223
1256
|
symbol = self.module.get_function(fn_name)
|
|
1224
|
-
args_val =
|
|
1257
|
+
args_val = flatten_values_to_ir(args_val)
|
|
1225
1258
|
call_op = self.builder.call(symbol, args_val)
|
|
1226
1259
|
if callee_ret_type == language.void:
|
|
1227
1260
|
return None
|
|
@@ -1230,18 +1263,29 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1230
1263
|
|
|
1231
1264
|
def visit_Call(self, node):
|
|
1232
1265
|
fn = _unwrap_if_constexpr(self.visit(node.func))
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
|
|
1266
|
+
if not isinstance(fn, BoundJITMethod):
|
|
1267
|
+
static_implementation = self.statically_implemented_functions.get(fn)
|
|
1268
|
+
if static_implementation is not None:
|
|
1269
|
+
return static_implementation(self, node)
|
|
1270
|
+
|
|
1271
|
+
mur = getattr(fn, '_must_use_result', False)
|
|
1272
|
+
if mur and getattr(node, '_is_unused', False):
|
|
1273
|
+
error_message = ["The result of %s is not being used." % ast.unparse(node.func)]
|
|
1274
|
+
if isinstance(mur, str):
|
|
1275
|
+
error_message.append(mur)
|
|
1276
|
+
raise CompilationError(self.jit_fn.src, node, " ".join(error_message))
|
|
1236
1277
|
|
|
1237
1278
|
kws = dict(self.visit(keyword) for keyword in node.keywords)
|
|
1238
1279
|
args = [self.visit(arg) for arg in node.args]
|
|
1239
1280
|
args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
|
|
1281
|
+
if isinstance(fn, BoundJITMethod):
|
|
1282
|
+
args.insert(0, fn.__self__)
|
|
1283
|
+
fn = fn.__func__
|
|
1240
1284
|
if isinstance(fn, JITFunction):
|
|
1241
1285
|
_check_fn_args(node, fn, args)
|
|
1242
1286
|
return self.call_JitFunction(fn, args, kws)
|
|
1243
1287
|
if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn):
|
|
1244
|
-
extra_kwargs = {"
|
|
1288
|
+
extra_kwargs = {"_semantic": self.semantic}
|
|
1245
1289
|
sig = inspect.signature(fn)
|
|
1246
1290
|
if '_generator' in sig.parameters:
|
|
1247
1291
|
extra_kwargs['_generator'] = self
|
|
@@ -1252,6 +1296,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1252
1296
|
ret = language.tuple(ret)
|
|
1253
1297
|
return ret
|
|
1254
1298
|
except Exception as e:
|
|
1299
|
+
if knobs.compilation.front_end_debugging:
|
|
1300
|
+
raise
|
|
1255
1301
|
# Normally when we raise a CompilationError, we raise it as
|
|
1256
1302
|
# `from None`, because the original fileline from the exception
|
|
1257
1303
|
# is not relevant (and often points into code_generator.py
|
|
@@ -1269,26 +1315,73 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1269
1315
|
return constexpr(node.value)
|
|
1270
1316
|
|
|
1271
1317
|
def visit_BoolOp(self, node: ast.BoolOp):
|
|
1272
|
-
if len(node.values) != 2:
|
|
1273
|
-
raise self._unsupported(
|
|
1274
|
-
node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
|
|
1275
|
-
lhs = self.visit(node.values[0])
|
|
1276
|
-
rhs = self.visit(node.values[1])
|
|
1277
1318
|
method_name = self._method_name_for_bool_op.get(type(node.op))
|
|
1278
1319
|
if method_name is None:
|
|
1279
1320
|
raise self._unsupported(
|
|
1280
1321
|
node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
|
|
1281
|
-
|
|
1322
|
+
|
|
1323
|
+
nontrivial_values = []
|
|
1324
|
+
|
|
1325
|
+
for subnode in node.values:
|
|
1326
|
+
# we visit the values in order, executing their side-effects
|
|
1327
|
+
# and possibly early-exiting:
|
|
1328
|
+
value = self.visit(subnode)
|
|
1329
|
+
if not _is_triton_tensor(value):
|
|
1330
|
+
# this is a constexpr, so we might be able to short-circuit:
|
|
1331
|
+
bv = bool(value)
|
|
1332
|
+
if (bv is False) and (method_name == "logical_and"):
|
|
1333
|
+
# value is falsey so return that:
|
|
1334
|
+
return value
|
|
1335
|
+
if (bv is True) and (method_name == "logical_or"):
|
|
1336
|
+
# value is truthy so return that:
|
|
1337
|
+
return value
|
|
1338
|
+
# otherwise, our constexpr has no effect on the output of the
|
|
1339
|
+
# expression so we do not append it to nontrivial_values.
|
|
1340
|
+
else:
|
|
1341
|
+
if value.type.is_block():
|
|
1342
|
+
lineno = getattr(node, "lineno", None)
|
|
1343
|
+
if lineno is not None:
|
|
1344
|
+
lineno += self.begin_line
|
|
1345
|
+
warnings.warn_explicit(
|
|
1346
|
+
"Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead",
|
|
1347
|
+
category=UserWarning,
|
|
1348
|
+
filename=self.file_name,
|
|
1349
|
+
lineno=lineno,
|
|
1350
|
+
source=ast.unparse(node),
|
|
1351
|
+
)
|
|
1352
|
+
# not a constexpr so we must append it:
|
|
1353
|
+
nontrivial_values.append(value)
|
|
1354
|
+
|
|
1355
|
+
if len(nontrivial_values) == 0:
|
|
1356
|
+
# the semantics of a disjunction of falsey values or conjunction
|
|
1357
|
+
# of truthy values is to return the final value:
|
|
1358
|
+
nontrivial_values.append(value)
|
|
1359
|
+
|
|
1360
|
+
while len(nontrivial_values) >= 2:
|
|
1361
|
+
rhs = nontrivial_values.pop()
|
|
1362
|
+
lhs = nontrivial_values.pop()
|
|
1363
|
+
res = self._apply_binary_method(method_name, lhs, rhs)
|
|
1364
|
+
nontrivial_values.append(res)
|
|
1365
|
+
|
|
1366
|
+
assert len(nontrivial_values) == 1
|
|
1367
|
+
return nontrivial_values[0]
|
|
1282
1368
|
|
|
1283
1369
|
_method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
|
|
1284
1370
|
|
|
1285
1371
|
def visit_Attribute(self, node):
|
|
1286
1372
|
lhs = self.visit(node.value)
|
|
1287
1373
|
if _is_triton_tensor(lhs) and node.attr == "T":
|
|
1288
|
-
return semantic.permute(lhs, (1, 0)
|
|
1289
|
-
|
|
1374
|
+
return self.semantic.permute(lhs, (1, 0))
|
|
1375
|
+
# NOTE: special case ".value" for BC
|
|
1376
|
+
if isinstance(lhs, constexpr) and node.attr != "value":
|
|
1377
|
+
lhs = lhs.value
|
|
1378
|
+
attr = getattr(lhs, node.attr)
|
|
1379
|
+
if _is_triton_value(lhs) and isinstance(attr, JITFunction):
|
|
1380
|
+
return BoundJITMethod(lhs, attr)
|
|
1381
|
+
return attr
|
|
1290
1382
|
|
|
1291
1383
|
def visit_Expr(self, node):
|
|
1384
|
+
node.value._is_unused = True
|
|
1292
1385
|
ast.NodeVisitor.generic_visit(self, node)
|
|
1293
1386
|
|
|
1294
1387
|
def visit_NoneType(self, node):
|
|
@@ -1331,6 +1424,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1331
1424
|
except CompilationError:
|
|
1332
1425
|
raise
|
|
1333
1426
|
except Exception as e:
|
|
1427
|
+
if knobs.compilation.front_end_debugging:
|
|
1428
|
+
raise
|
|
1334
1429
|
# Wrap the error in a CompilationError which contains the source
|
|
1335
1430
|
# of the @jit function.
|
|
1336
1431
|
raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None
|
|
@@ -1378,16 +1473,22 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1378
1473
|
|
|
1379
1474
|
return ret
|
|
1380
1475
|
|
|
1476
|
+
from ..experimental.gluon import language as ttgl
|
|
1381
1477
|
statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {
|
|
1382
1478
|
language.core.static_assert: execute_static_assert,
|
|
1383
1479
|
language.core.static_print: static_executor(print),
|
|
1480
|
+
ttgl.static_assert: execute_static_assert,
|
|
1481
|
+
ttgl.static_print: static_executor(print),
|
|
1384
1482
|
int: static_executor(int),
|
|
1385
1483
|
len: static_executor(len),
|
|
1386
1484
|
}
|
|
1387
1485
|
|
|
1388
1486
|
|
|
1389
|
-
def ast_to_ttir(fn, src, context, options, codegen_fns, module_map):
|
|
1390
|
-
arg_types =
|
|
1487
|
+
def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None):
|
|
1488
|
+
arg_types = [None] * len(fn.arg_names)
|
|
1489
|
+
for k, v in src.signature.items():
|
|
1490
|
+
idx = fn.arg_names.index(k)
|
|
1491
|
+
arg_types[idx] = str_to_ty(v)
|
|
1391
1492
|
prototype = ASTFunction([], arg_types, src.constants, src.attrs)
|
|
1392
1493
|
file_name, begin_line = get_jit_fn_file_line(fn)
|
|
1393
1494
|
# query function representation
|
|
@@ -1396,9 +1497,9 @@ def ast_to_ttir(fn, src, context, options, codegen_fns, module_map):
|
|
|
1396
1497
|
constants = {fn.arg_names[i[0]]: src.constants[i] for i in leaves}
|
|
1397
1498
|
signature = src.signature
|
|
1398
1499
|
proxy = namedtuple("SpecializationProxy", ["constants", "signature"])(constants, signature)
|
|
1399
|
-
generator = CodeGenerator(context, prototype, gscope=fn.
|
|
1400
|
-
is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
|
|
1401
|
-
codegen_fns=codegen_fns, module_map=module_map)
|
|
1500
|
+
generator = CodeGenerator(context, prototype, gscope=fn.get_capture_scope(), function_name=fn.repr(proxy),
|
|
1501
|
+
jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
|
|
1502
|
+
codegen_fns=codegen_fns, module_map=module_map, module=module)
|
|
1402
1503
|
generator.visit(fn.parse())
|
|
1403
1504
|
ret = generator.module
|
|
1404
1505
|
# module takes ownership of the context
|