triton-windows 3.4.0.post20__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +8 -2
- triton/_filecheck.py +24 -14
- triton/_internal_testing.py +70 -4
- triton/_utils.py +3 -1
- triton/backends/amd/compiler.py +68 -60
- triton/backends/amd/driver.c +113 -44
- triton/backends/amd/driver.py +133 -57
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/compiler.py +80 -22
- triton/backends/nvidia/driver.c +88 -15
- triton/backends/nvidia/driver.py +130 -123
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +270 -163
- triton/compiler/compiler.py +45 -62
- triton/experimental/gluon/__init__.py +3 -2
- triton/experimental/gluon/_runtime.py +9 -6
- triton/experimental/gluon/language/__init__.py +117 -16
- triton/experimental/gluon/language/_core.py +246 -68
- triton/experimental/gluon/language/_layouts.py +398 -45
- triton/experimental/gluon/language/_math.py +17 -9
- triton/experimental/gluon/language/_semantic.py +130 -37
- triton/experimental/gluon/language/_standard.py +55 -22
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
- triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
- triton/experimental/gluon/nvidia/hopper.py +6 -1
- triton/knobs.py +132 -67
- triton/language/__init__.py +16 -10
- triton/language/core.py +163 -83
- triton/language/extra/cuda/gdc.py +6 -6
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +7 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/semantic.py +76 -23
- triton/language/standard.py +14 -14
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +4 -5
- triton/runtime/build.py +11 -9
- triton/runtime/cache.py +44 -1
- triton/runtime/driver.py +16 -41
- triton/runtime/interpreter.py +31 -23
- triton/runtime/jit.py +318 -157
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/tools/compile.py +62 -14
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +7 -9
- triton/windows_utils.py +42 -79
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
import ast
|
|
2
|
+
import builtins
|
|
3
|
+
import contextlib
|
|
2
4
|
import copy
|
|
3
5
|
import inspect
|
|
4
6
|
import re
|
|
@@ -11,11 +13,10 @@ from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, Iterable,
|
|
|
11
13
|
|
|
12
14
|
from .. import knobs, language
|
|
13
15
|
from .._C.libtriton import ir, gluon_ir
|
|
14
|
-
from ..language import constexpr, str_to_ty, tensor
|
|
16
|
+
from ..language import constexpr, str_to_ty, tensor, tuple as tl_tuple
|
|
15
17
|
from ..language.core import _unwrap_if_constexpr, base_value, base_type
|
|
16
|
-
from ..runtime.jit import get_jit_fn_file_line, get_full_name
|
|
17
18
|
# ideally we wouldn't need any runtime component
|
|
18
|
-
from ..runtime import JITFunction
|
|
19
|
+
from ..runtime.jit import get_jit_fn_file_line, get_full_name, JITCallable, BoundConstexprFunction, ConstexprFunction, JITFunction
|
|
19
20
|
from .._utils import find_paths_if, get_iterable_path, set_iterable_path
|
|
20
21
|
|
|
21
22
|
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
|
|
@@ -28,7 +29,7 @@ def check_identifier_legality(name, type):
|
|
|
28
29
|
return name
|
|
29
30
|
|
|
30
31
|
|
|
31
|
-
def mangle_fn(name, arg_tys, constants):
|
|
32
|
+
def mangle_fn(name, arg_tys, constants, caller_context):
|
|
32
33
|
# doesn't mangle ret type, which must be a function of arg tys
|
|
33
34
|
mangled_arg_names = '_'.join([ty.mangle() for ty in arg_tys])
|
|
34
35
|
mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)])
|
|
@@ -37,6 +38,8 @@ def mangle_fn(name, arg_tys, constants):
|
|
|
37
38
|
# [ and ] are not allowed in LLVM identifiers
|
|
38
39
|
mangled_constants = mangled_constants.replace('[', '_').replace(']', '_')
|
|
39
40
|
ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
|
|
41
|
+
if caller_context is not None:
|
|
42
|
+
ret += caller_context.mangle()
|
|
40
43
|
return ret
|
|
41
44
|
|
|
42
45
|
|
|
@@ -49,7 +52,7 @@ def _is_triton_tensor(o: Any) -> bool:
|
|
|
49
52
|
|
|
50
53
|
|
|
51
54
|
def _is_constexpr(o: Any) -> bool:
|
|
52
|
-
return o is None or isinstance(o, (constexpr, language.core.dtype,
|
|
55
|
+
return o is None or isinstance(o, (constexpr, language.core.dtype, JITCallable))
|
|
53
56
|
|
|
54
57
|
|
|
55
58
|
def _is_non_scalar_tensor(o: Any) -> bool:
|
|
@@ -106,6 +109,17 @@ def unflatten_ir_values(handles: List[ir.value], types: List[base_type]):
|
|
|
106
109
|
_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels
|
|
107
110
|
|
|
108
111
|
|
|
112
|
+
def _clone_triton_value(val):
|
|
113
|
+
handles = []
|
|
114
|
+
val._flatten_ir(handles)
|
|
115
|
+
clone, _ = val.type._unflatten_ir(handles, 0)
|
|
116
|
+
return clone
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _clone_scope(scope):
|
|
120
|
+
return {name: _clone_triton_value(val) if _is_triton_value(val) else val for name, val in scope.items()}
|
|
121
|
+
|
|
122
|
+
|
|
109
123
|
class enter_sub_region:
|
|
110
124
|
|
|
111
125
|
def __init__(self, generator):
|
|
@@ -113,8 +127,8 @@ class enter_sub_region:
|
|
|
113
127
|
|
|
114
128
|
def __enter__(self):
|
|
115
129
|
# record lscope & local_defs in the parent scope
|
|
116
|
-
self.liveins = self.generator.lscope
|
|
117
|
-
self.prev_defs = self.generator.local_defs
|
|
130
|
+
self.liveins = _clone_scope(self.generator.lscope)
|
|
131
|
+
self.prev_defs = _clone_scope(self.generator.local_defs)
|
|
118
132
|
self.generator.local_defs = {}
|
|
119
133
|
self.insert_block = self.generator.builder.get_insertion_block()
|
|
120
134
|
self.insert_point = self.generator.builder.get_insertion_point()
|
|
@@ -136,9 +150,9 @@ class ContainsReturnChecker(ast.NodeVisitor):
|
|
|
136
150
|
return any(self.visit(s) for s in body)
|
|
137
151
|
|
|
138
152
|
def _visit_function(self, fn) -> bool:
|
|
139
|
-
#
|
|
140
|
-
# If the function itself has unstructured control flow we may not be able to inline it causing poor performance
|
|
141
|
-
#
|
|
153
|
+
# No need to check within the function as it won't cause an early return.
|
|
154
|
+
# If the function itself has unstructured control flow we may not be able to inline it causing poor performance,
|
|
155
|
+
# we should check for this and emit a warning.
|
|
142
156
|
return False
|
|
143
157
|
|
|
144
158
|
def generic_visit(self, node) -> bool:
|
|
@@ -280,11 +294,12 @@ class BoundJITMethod:
|
|
|
280
294
|
|
|
281
295
|
class CodeGenerator(ast.NodeVisitor):
|
|
282
296
|
|
|
283
|
-
def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns,
|
|
284
|
-
module=None, is_kernel=False, function_types: Optional[Dict] = None,
|
|
285
|
-
file_name: Optional[str] = None, begin_line=0):
|
|
297
|
+
def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, *, options, codegen_fns,
|
|
298
|
+
module_map, is_gluon, module=None, is_kernel=False, function_types: Optional[Dict] = None,
|
|
299
|
+
noinline=False, caller_context=None, file_name: Optional[str] = None, begin_line=0):
|
|
286
300
|
self.context = context
|
|
287
|
-
|
|
301
|
+
self.is_gluon = is_gluon
|
|
302
|
+
if is_gluon:
|
|
288
303
|
from triton.experimental.gluon.language._semantic import GluonSemantic
|
|
289
304
|
self.builder = gluon_ir.GluonOpBuilder(context)
|
|
290
305
|
self.semantic = GluonSemantic(self.builder)
|
|
@@ -292,6 +307,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
292
307
|
from triton.language.semantic import TritonSemantic
|
|
293
308
|
self.builder = ir.builder(context)
|
|
294
309
|
self.semantic = TritonSemantic(self.builder)
|
|
310
|
+
|
|
311
|
+
self.name_loc_as_prefix = None
|
|
295
312
|
self.file_name = file_name
|
|
296
313
|
# node.lineno starts from 1, so we need to subtract 1
|
|
297
314
|
self.begin_line = begin_line - 1
|
|
@@ -328,6 +345,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
328
345
|
self.is_kernel = is_kernel
|
|
329
346
|
self.cur_node = None
|
|
330
347
|
self.noinline = noinline
|
|
348
|
+
self.caller_context = caller_context
|
|
331
349
|
self.scf_stack = []
|
|
332
350
|
self.ret_type = None
|
|
333
351
|
# SSA-construction
|
|
@@ -378,7 +396,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
378
396
|
val is absent,
|
|
379
397
|
name in self.builtin_namespace, #
|
|
380
398
|
type(val) is ModuleType, #
|
|
381
|
-
isinstance(val,
|
|
399
|
+
isinstance(val, JITCallable), #
|
|
382
400
|
getattr(val, "__triton_builtin__", False), #
|
|
383
401
|
getattr(val, "__triton_aggregate__", False), #
|
|
384
402
|
getattr(val, "__module__", "").startswith("triton.language"), #
|
|
@@ -414,6 +432,21 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
414
432
|
|
|
415
433
|
return name_lookup
|
|
416
434
|
|
|
435
|
+
@contextlib.contextmanager
|
|
436
|
+
def _name_loc_prefix(self, prefix):
|
|
437
|
+
self.name_loc_as_prefix = prefix
|
|
438
|
+
yield
|
|
439
|
+
self.name_loc_as_prefix = None
|
|
440
|
+
|
|
441
|
+
def _maybe_set_loc_to_name(self, val, name):
|
|
442
|
+
if isinstance(val, (ir.value, ir.block_argument)):
|
|
443
|
+
val.set_loc(self.builder.create_name_loc(name, val.get_loc()))
|
|
444
|
+
elif _is_triton_value(val):
|
|
445
|
+
handles = []
|
|
446
|
+
val._flatten_ir(handles)
|
|
447
|
+
for handle in handles:
|
|
448
|
+
handle.set_loc(self.builder.create_name_loc(name, handle.get_loc()))
|
|
449
|
+
|
|
417
450
|
def set_value(self, name: str, value: Union[base_value, constexpr]) -> None:
|
|
418
451
|
''' This function:
|
|
419
452
|
called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
|
|
@@ -435,6 +468,43 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
435
468
|
self.builder.restore_insertion_point(ip)
|
|
436
469
|
self.builder.set_loc(loc)
|
|
437
470
|
|
|
471
|
+
def _find_carries(self, node, liveins):
|
|
472
|
+
# create loop body block
|
|
473
|
+
block = self.builder.create_block()
|
|
474
|
+
self.builder.set_insertion_point_to_start(block)
|
|
475
|
+
# dry visit loop body
|
|
476
|
+
self.scf_stack.append(node)
|
|
477
|
+
self.visit_compound_statement(node.body)
|
|
478
|
+
self.scf_stack.pop()
|
|
479
|
+
block.erase()
|
|
480
|
+
|
|
481
|
+
# If a variable (name) has changed value within the loop, then it's
|
|
482
|
+
# a loop-carried variable. (The new and old value must be of the
|
|
483
|
+
# same type)
|
|
484
|
+
init_tys = []
|
|
485
|
+
init_handles = []
|
|
486
|
+
names = []
|
|
487
|
+
|
|
488
|
+
for name, live_val in liveins.items():
|
|
489
|
+
if _is_triton_value(live_val):
|
|
490
|
+
loop_val = self.lscope[name]
|
|
491
|
+
self._verify_loop_carried_variable(name, loop_val, live_val)
|
|
492
|
+
|
|
493
|
+
live_handles = flatten_values_to_ir([live_val])
|
|
494
|
+
loop_handles = flatten_values_to_ir([loop_val])
|
|
495
|
+
if live_handles != loop_handles:
|
|
496
|
+
names.append(name)
|
|
497
|
+
init_tys.append(live_val.type)
|
|
498
|
+
init_handles.extend(live_handles)
|
|
499
|
+
else:
|
|
500
|
+
assert name not in self.local_defs, f'Loop carried variable {name} is not a triton value'
|
|
501
|
+
|
|
502
|
+
# reset local scope to not pick up local defs from the dry run.
|
|
503
|
+
self.lscope = liveins.copy()
|
|
504
|
+
self.local_defs = {}
|
|
505
|
+
|
|
506
|
+
return names, init_handles, init_tys
|
|
507
|
+
|
|
438
508
|
#
|
|
439
509
|
# AST visitor
|
|
440
510
|
#
|
|
@@ -458,6 +528,21 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
458
528
|
elts = language.tuple([self.visit(elt) for elt in node.elts])
|
|
459
529
|
return elts
|
|
460
530
|
|
|
531
|
+
def visit_ListComp(self, node: ast.ListComp):
|
|
532
|
+
if len(node.generators) != 1:
|
|
533
|
+
raise ValueError("nested comprehensions are not supported")
|
|
534
|
+
|
|
535
|
+
comp = node.generators[0]
|
|
536
|
+
iter = self.visit(comp.iter)
|
|
537
|
+
if not isinstance(iter, tl_tuple):
|
|
538
|
+
raise NotImplementedError("only tuple comprehensions are supported")
|
|
539
|
+
|
|
540
|
+
results = []
|
|
541
|
+
for item in iter:
|
|
542
|
+
self.set_value(comp.target.id, item)
|
|
543
|
+
results.append(self.visit(node.elt))
|
|
544
|
+
return tl_tuple(results)
|
|
545
|
+
|
|
461
546
|
# By design, only non-kernel functions can return
|
|
462
547
|
def visit_Return(self, node):
|
|
463
548
|
ret_value = self.visit(node.value)
|
|
@@ -522,8 +607,11 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
522
607
|
self.module.push_back(self.fn)
|
|
523
608
|
entry = self.fn.add_entry_block()
|
|
524
609
|
arg_values = self.prototype.deserialize(self.fn)
|
|
610
|
+
if self.caller_context is not None:
|
|
611
|
+
self.caller_context.initialize_callee(self.fn, self.builder)
|
|
525
612
|
# bind arguments to symbols
|
|
526
613
|
for arg_name, arg_value in zip(arg_names, arg_values):
|
|
614
|
+
self._maybe_set_loc_to_name(arg_value, arg_name)
|
|
527
615
|
self.set_value(arg_name, arg_value)
|
|
528
616
|
insert_pt = self.builder.get_insertion_block()
|
|
529
617
|
self.builder.set_insertion_point_to_start(entry)
|
|
@@ -583,9 +671,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
583
671
|
self.assignTarget(target, value.values[i])
|
|
584
672
|
return
|
|
585
673
|
if isinstance(target, ast.Attribute):
|
|
586
|
-
|
|
587
|
-
setattr(base, target.attr, value)
|
|
588
|
-
return
|
|
674
|
+
raise NotImplementedError("Attribute assignment is not supported in triton")
|
|
589
675
|
assert isinstance(target, ast.Name)
|
|
590
676
|
self.set_value(self.visit(target), value)
|
|
591
677
|
|
|
@@ -602,10 +688,15 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
602
688
|
value = self.semantic.to_tensor(value)
|
|
603
689
|
return value
|
|
604
690
|
|
|
605
|
-
values = _sanitize_value(self.visit(node.value))
|
|
606
691
|
targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets
|
|
607
692
|
assert len(targets) == 1
|
|
608
|
-
|
|
693
|
+
target = targets[0]
|
|
694
|
+
if isinstance(target, ast.Name):
|
|
695
|
+
with self._name_loc_prefix(target.id):
|
|
696
|
+
values = _sanitize_value(self.visit(node.value))
|
|
697
|
+
else:
|
|
698
|
+
values = _sanitize_value(self.visit(node.value))
|
|
699
|
+
self.assignTarget(target, values)
|
|
609
700
|
|
|
610
701
|
def visit_AugAssign(self, node):
|
|
611
702
|
lhs = copy.deepcopy(node.target)
|
|
@@ -671,8 +762,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
671
762
|
self.visit_compound_statement(node.body)
|
|
672
763
|
then_block = self.builder.get_insertion_block()
|
|
673
764
|
then_defs = self.local_defs.copy()
|
|
765
|
+
then_vals = self.lscope.copy()
|
|
674
766
|
# else block
|
|
675
767
|
else_defs = {}
|
|
768
|
+
else_vals = liveins.copy()
|
|
676
769
|
if node.orelse:
|
|
677
770
|
self.builder.set_insertion_point_to_start(else_block)
|
|
678
771
|
self.lscope = liveins.copy()
|
|
@@ -680,26 +773,29 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
680
773
|
self.visit_compound_statement(node.orelse)
|
|
681
774
|
else_defs = self.local_defs.copy()
|
|
682
775
|
else_block = self.builder.get_insertion_block()
|
|
776
|
+
else_vals = self.lscope.copy()
|
|
683
777
|
|
|
684
778
|
# update block arguments
|
|
685
779
|
names = []
|
|
686
780
|
# variables in livein whose value is updated in `if`
|
|
687
|
-
for name in liveins:
|
|
781
|
+
for name, value in liveins.items():
|
|
782
|
+
# livein variable changed value in either then or else
|
|
783
|
+
if not _is_triton_value(value):
|
|
784
|
+
continue
|
|
785
|
+
then_handles = flatten_values_to_ir([then_vals[name]])
|
|
786
|
+
else_handles = flatten_values_to_ir([else_vals[name]])
|
|
787
|
+
if then_handles == else_handles:
|
|
788
|
+
continue
|
|
789
|
+
names.append(name)
|
|
790
|
+
then_defs[name] = then_vals[name]
|
|
791
|
+
else_defs[name] = else_vals[name]
|
|
688
792
|
# check type
|
|
689
793
|
for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]:
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
if name in then_defs or name in else_defs:
|
|
696
|
-
names.append(name)
|
|
697
|
-
# variable defined in then but not in else
|
|
698
|
-
if name in then_defs and name not in else_defs:
|
|
699
|
-
else_defs[name] = liveins[name]
|
|
700
|
-
# variable defined in else but not in then
|
|
701
|
-
if name in else_defs and name not in then_defs:
|
|
702
|
-
then_defs[name] = liveins[name]
|
|
794
|
+
type_equal = type(defs[name]) == type(value) # noqa: E721
|
|
795
|
+
assert type_equal and defs[name].type == value.type, \
|
|
796
|
+
f'initial value for `{name}` is of type {value}, '\
|
|
797
|
+
f'but the {block_name} block redefines it as {defs[name]}'
|
|
798
|
+
|
|
703
799
|
# variables that are both in then and else but not in liveins
|
|
704
800
|
# TODO: could probably be cleaned up
|
|
705
801
|
for name in sorted(then_defs.keys() & else_defs.keys()):
|
|
@@ -766,6 +862,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
766
862
|
self.visit_then_else_blocks(node, liveins, then_block, else_block)
|
|
767
863
|
# create if op
|
|
768
864
|
then_handles = flatten_values_to_ir(then_defs[name] for name in names)
|
|
865
|
+
for name, val in zip(names, then_handles):
|
|
866
|
+
self._maybe_set_loc_to_name(val, name)
|
|
769
867
|
self._set_insertion_point_and_loc(ip, last_loc)
|
|
770
868
|
if_op = self.builder.create_if_op([h.get_type() for h in then_handles], cond.handle, True)
|
|
771
869
|
then_block.merge_block_before(if_op.get_then_block())
|
|
@@ -779,6 +877,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
779
877
|
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
|
780
878
|
if len(names) > 0:
|
|
781
879
|
else_handles = flatten_values_to_ir(else_defs[name] for name in names)
|
|
880
|
+
for name, val in zip(names, else_handles):
|
|
881
|
+
self._maybe_set_loc_to_name(val, name)
|
|
782
882
|
self.builder.create_yield_op(else_handles)
|
|
783
883
|
# update values
|
|
784
884
|
res_handles = [if_op.get_result(i) for i in range(len(then_handles))]
|
|
@@ -799,13 +899,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
799
899
|
% ast.unparse(node.test))
|
|
800
900
|
cond = language.core._unsplat(cond, _semantic=self.semantic, _generator=self)
|
|
801
901
|
cond = cond.to(language.int1, _semantic=self.semantic)
|
|
802
|
-
|
|
803
|
-
if contains_return:
|
|
902
|
+
if ContainsReturnChecker(self.gscope).visit(node):
|
|
804
903
|
if self.scf_stack:
|
|
805
904
|
raise self._unsupported(
|
|
806
|
-
node, "Cannot have `return` statements inside `while` or `for` statements in triton
|
|
807
|
-
"(note that this also applies to `return` statements that are inside functions "
|
|
808
|
-
"transitively called from within `while`/`for` statements)")
|
|
905
|
+
node, "Cannot have `return` statements inside `while` or `for` statements in triton.")
|
|
809
906
|
self.visit_if_top_level(cond, node)
|
|
810
907
|
else:
|
|
811
908
|
self.visit_if_scf(cond, node)
|
|
@@ -874,6 +971,37 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
874
971
|
else:
|
|
875
972
|
return self.visit(node.orelse)
|
|
876
973
|
|
|
974
|
+
def visit_With(self, node):
|
|
975
|
+
# Lower `with` statements by constructing context managers and calling their enter/exit hooks
|
|
976
|
+
# Instantiate each context manager with builder injection
|
|
977
|
+
if len(node.items) == 1: # Handle async_task
|
|
978
|
+
context = node.items[0].context_expr
|
|
979
|
+
withitemClass = self.visit(context.func)
|
|
980
|
+
if withitemClass == language.async_task:
|
|
981
|
+
args = [self.visit(arg) for arg in context.args]
|
|
982
|
+
with withitemClass(*args, _builder=self.builder):
|
|
983
|
+
self.visit_compound_statement(node.body)
|
|
984
|
+
return
|
|
985
|
+
|
|
986
|
+
cm_list = []
|
|
987
|
+
for item in node.items:
|
|
988
|
+
call = item.context_expr
|
|
989
|
+
fn = self.visit(call.func)
|
|
990
|
+
args = [self.visit(arg) for arg in call.args]
|
|
991
|
+
kws = dict(self.visit(kw) for kw in call.keywords)
|
|
992
|
+
cm = fn(*args, _semantic=self.semantic, **kws)
|
|
993
|
+
cm_list.append(cm)
|
|
994
|
+
for cm, item in zip(cm_list, node.items):
|
|
995
|
+
res = cm.__enter__()
|
|
996
|
+
if item.optional_vars is not None:
|
|
997
|
+
var_name = self.visit(item.optional_vars)
|
|
998
|
+
self.set_value(var_name, res)
|
|
999
|
+
if ContainsReturnChecker(self.gscope).visit(node):
|
|
1000
|
+
raise self._unsupported(node, "Cannot have `return` statements inside `with` statements in triton ")
|
|
1001
|
+
self.visit_compound_statement(node.body)
|
|
1002
|
+
for cm in reversed(cm_list):
|
|
1003
|
+
cm.__exit__(None, None, None)
|
|
1004
|
+
|
|
877
1005
|
def visit_Pass(self, node):
|
|
878
1006
|
pass
|
|
879
1007
|
|
|
@@ -918,9 +1046,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
918
1046
|
}
|
|
919
1047
|
|
|
920
1048
|
def _verify_loop_carried_variable(self, name, loop_val, live_val):
|
|
921
|
-
assert _is_triton_value(loop_val), f'cannot reassign
|
|
922
|
-
assert _is_triton_value(live_val), f'cannot
|
|
923
|
-
assert type(loop_val) is type(live_val),
|
|
1049
|
+
assert _is_triton_value(loop_val), f'cannot reassign constexpr {name} in the loop'
|
|
1050
|
+
assert _is_triton_value(live_val), f'cannot reassign constexpr {name} in the loop'
|
|
1051
|
+
assert type(loop_val) is type(live_val), (
|
|
1052
|
+
f'Loop carried variable {name} changed type, was {type(loop_val)} but is now {type(live_val)}')
|
|
924
1053
|
assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \
|
|
925
1054
|
f'Loop-carried variable {name} has initial type {live_val.type} '\
|
|
926
1055
|
f'but is re-assigned to {loop_val.type} in loop! '\
|
|
@@ -929,49 +1058,14 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
929
1058
|
def visit_withitem(self, node):
|
|
930
1059
|
return self.visit(node.context_expr)
|
|
931
1060
|
|
|
932
|
-
def visit_With(self, node):
|
|
933
|
-
assert len(node.items) == 1
|
|
934
|
-
context = node.items[0].context_expr
|
|
935
|
-
withitemClass = self.visit(context.func)
|
|
936
|
-
if withitemClass == language.async_task:
|
|
937
|
-
args = [self.visit(arg) for arg in context.args]
|
|
938
|
-
with withitemClass(*args, _builder=self.builder):
|
|
939
|
-
self.visit_compound_statement(node.body)
|
|
940
|
-
else:
|
|
941
|
-
self.visit_compound_statement(node.body)
|
|
942
|
-
|
|
943
1061
|
def visit_While(self, node):
|
|
944
1062
|
with enter_sub_region(self) as sr:
|
|
945
1063
|
liveins, insert_block = sr
|
|
946
1064
|
ip, last_loc = self._get_insertion_point_and_loc()
|
|
947
1065
|
|
|
948
|
-
|
|
949
|
-
# loop_block = self.builder.create_block()
|
|
950
|
-
dummy = self.builder.create_block()
|
|
951
|
-
self.builder.set_insertion_point_to_start(dummy)
|
|
952
|
-
self.scf_stack.append(node)
|
|
953
|
-
self.visit_compound_statement(node.body)
|
|
954
|
-
self.scf_stack.pop()
|
|
955
|
-
loop_defs = self.local_defs
|
|
956
|
-
dummy.erase()
|
|
957
|
-
|
|
958
|
-
# collect loop-carried values
|
|
959
|
-
names = []
|
|
960
|
-
init_args = []
|
|
961
|
-
for name in loop_defs:
|
|
962
|
-
if name in liveins:
|
|
963
|
-
# We should not def new constexpr
|
|
964
|
-
loop_val = loop_defs[name]
|
|
965
|
-
live_val = liveins[name]
|
|
966
|
-
self._verify_loop_carried_variable(name, loop_val, live_val)
|
|
967
|
-
|
|
968
|
-
# these are loop-carried values
|
|
969
|
-
names.append(name)
|
|
970
|
-
init_args.append(live_val)
|
|
1066
|
+
names, init_handles, init_fe_tys = self._find_carries(node, liveins)
|
|
971
1067
|
|
|
972
|
-
init_handles = flatten_values_to_ir(init_args)
|
|
973
1068
|
init_tys = [h.get_type() for h in init_handles]
|
|
974
|
-
init_fe_tys = [a.type for a in init_args]
|
|
975
1069
|
self._set_insertion_point_and_loc(ip, last_loc)
|
|
976
1070
|
while_op = self.builder.create_while_op(init_tys, init_handles)
|
|
977
1071
|
# merge the condition region
|
|
@@ -982,7 +1076,12 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
982
1076
|
for name, val in zip(names, condition_args):
|
|
983
1077
|
self.lscope[name] = val
|
|
984
1078
|
self.local_defs[name] = val
|
|
1079
|
+
self._maybe_set_loc_to_name(val, name)
|
|
985
1080
|
cond = self.visit(node.test)
|
|
1081
|
+
if isinstance(cond, language.condition):
|
|
1082
|
+
if cond.disable_licm:
|
|
1083
|
+
while_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr())
|
|
1084
|
+
cond = cond.condition
|
|
986
1085
|
self.builder.set_insertion_point_to_end(before_block)
|
|
987
1086
|
# create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
|
|
988
1087
|
self.builder.create_condition_op(cond.handle, block_args)
|
|
@@ -996,16 +1095,13 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
996
1095
|
for name, val in zip(names, body_args):
|
|
997
1096
|
self.lscope[name] = val
|
|
998
1097
|
self.local_defs[name] = val
|
|
1098
|
+
self._maybe_set_loc_to_name(val, name)
|
|
999
1099
|
self.scf_stack.append(node)
|
|
1000
1100
|
self.visit_compound_statement(node.body)
|
|
1001
1101
|
self.scf_stack.pop()
|
|
1002
|
-
loop_defs = self.local_defs
|
|
1003
|
-
yields = []
|
|
1004
|
-
for name in loop_defs:
|
|
1005
|
-
if name in liveins:
|
|
1006
|
-
loop_defs[name]._flatten_ir(yields)
|
|
1007
1102
|
|
|
1008
|
-
self.
|
|
1103
|
+
yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)
|
|
1104
|
+
self.builder.create_yield_op(yield_handles)
|
|
1009
1105
|
|
|
1010
1106
|
# WhileOp defines new values, update the symbol table (lscope, local_defs)
|
|
1011
1107
|
result_handles = [while_op.get_result(i) for i in range(len(init_handles))]
|
|
@@ -1013,6 +1109,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1013
1109
|
for name, new_def in zip(names, result_vals):
|
|
1014
1110
|
self.lscope[name] = new_def
|
|
1015
1111
|
self.local_defs[name] = new_def
|
|
1112
|
+
self._maybe_set_loc_to_name(new_def, name)
|
|
1016
1113
|
|
|
1017
1114
|
for stmt in node.orelse:
|
|
1018
1115
|
assert False, "Not implemented"
|
|
@@ -1022,16 +1119,12 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1022
1119
|
assert isinstance(node.ctx, ast.Load)
|
|
1023
1120
|
lhs = self.visit(node.value)
|
|
1024
1121
|
slices = self.visit(node.slice)
|
|
1025
|
-
if
|
|
1026
|
-
return lhs.__getitem__
|
|
1122
|
+
if _is_triton_value(lhs):
|
|
1123
|
+
return self.call_Method(node, lhs.__getitem__, lhs, [slices], {})
|
|
1027
1124
|
return lhs[slices]
|
|
1028
1125
|
|
|
1029
1126
|
def visit_Subscript_Store(self, node, value):
|
|
1030
|
-
|
|
1031
|
-
lhs = self.visit(node.value)
|
|
1032
|
-
slices = self.visit(node.slice)
|
|
1033
|
-
assert isinstance(lhs, language.tuple)
|
|
1034
|
-
lhs.__setitem__(slices, value)
|
|
1127
|
+
raise NotImplementedError("__setitem__ is not supported in triton")
|
|
1035
1128
|
|
|
1036
1129
|
def visit_Subscript(self, node):
|
|
1037
1130
|
return self.visit_Subscript_Load(node)
|
|
@@ -1057,6 +1150,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1057
1150
|
disallow_acc_multi_buffer = False
|
|
1058
1151
|
flatten = False
|
|
1059
1152
|
warp_specialize = False
|
|
1153
|
+
disable_licm = False
|
|
1060
1154
|
if IteratorClass is language.range:
|
|
1061
1155
|
iterator = IteratorClass(*iter_args, **iter_kwargs)
|
|
1062
1156
|
# visit iterator arguments
|
|
@@ -1070,6 +1164,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1070
1164
|
disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer
|
|
1071
1165
|
flatten = iterator.flatten
|
|
1072
1166
|
warp_specialize = iterator.warp_specialize
|
|
1167
|
+
disable_licm = iterator.disable_licm
|
|
1073
1168
|
elif IteratorClass is range:
|
|
1074
1169
|
# visit iterator arguments
|
|
1075
1170
|
# note: only `range` iterator is supported now
|
|
@@ -1111,34 +1206,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1111
1206
|
liveins, insert_block = sr
|
|
1112
1207
|
ip, last_loc = self._get_insertion_point_and_loc()
|
|
1113
1208
|
|
|
1114
|
-
|
|
1115
|
-
block = self.builder.create_block()
|
|
1116
|
-
self.builder.set_insertion_point_to_start(block)
|
|
1117
|
-
# dry visit loop body
|
|
1118
|
-
self.scf_stack.append(node)
|
|
1119
|
-
self.visit_compound_statement(node.body)
|
|
1120
|
-
self.scf_stack.pop()
|
|
1121
|
-
block.erase()
|
|
1122
|
-
|
|
1123
|
-
# If a variable (name) is defined in both its parent & itself, then it's
|
|
1124
|
-
# a loop-carried variable. (They must be of the same type)
|
|
1125
|
-
init_args = []
|
|
1126
|
-
yields = []
|
|
1127
|
-
names = []
|
|
1128
|
-
for name in self.local_defs:
|
|
1129
|
-
if name in liveins:
|
|
1130
|
-
loop_val = self.local_defs[name]
|
|
1131
|
-
live_val = liveins[name]
|
|
1132
|
-
self._verify_loop_carried_variable(name, loop_val, live_val)
|
|
1133
|
-
|
|
1134
|
-
names.append(name)
|
|
1135
|
-
init_args.append(live_val)
|
|
1136
|
-
yields.append(loop_val)
|
|
1209
|
+
names, init_handles, init_tys = self._find_carries(node, liveins)
|
|
1137
1210
|
|
|
1138
1211
|
# create ForOp
|
|
1139
1212
|
self._set_insertion_point_and_loc(ip, last_loc)
|
|
1140
|
-
init_handles = flatten_values_to_ir(init_args)
|
|
1141
|
-
init_tys = [v.type for v in init_args]
|
|
1142
1213
|
for_op = self.builder.create_for_op(lb, ub, step, init_handles)
|
|
1143
1214
|
if _unwrap_if_constexpr(num_stages) is not None:
|
|
1144
1215
|
for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages))
|
|
@@ -1150,30 +1221,23 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1150
1221
|
for_op.set_attr("tt.flatten", self.builder.get_unit_attr())
|
|
1151
1222
|
if warp_specialize:
|
|
1152
1223
|
for_op.set_attr("tt.warp_specialize", self.builder.get_unit_attr())
|
|
1224
|
+
if disable_licm:
|
|
1225
|
+
for_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr())
|
|
1153
1226
|
|
|
1154
1227
|
self.scf_stack.append(node)
|
|
1155
1228
|
for_op_body = for_op.get_body(0)
|
|
1156
1229
|
self.builder.set_insertion_point_to_start(for_op_body)
|
|
1157
|
-
# reset local scope to not pick up local defs from the previous dry run.
|
|
1158
|
-
self.lscope = liveins.copy()
|
|
1159
|
-
self.local_defs = {}
|
|
1160
1230
|
block_handles = [for_op_body.arg(i + 1) for i in range(len(init_handles))]
|
|
1161
1231
|
block_args = unflatten_ir_values(block_handles, init_tys)
|
|
1162
1232
|
for name, val in zip(names, block_args):
|
|
1233
|
+
self._maybe_set_loc_to_name(val, name)
|
|
1163
1234
|
self.set_value(name, val)
|
|
1164
1235
|
self.visit_compound_statement(node.body)
|
|
1165
1236
|
self.scf_stack.pop()
|
|
1166
|
-
|
|
1167
|
-
for name in self.local_defs:
|
|
1168
|
-
if name in liveins:
|
|
1169
|
-
local = self.local_defs[name]
|
|
1170
|
-
if isinstance(local, constexpr):
|
|
1171
|
-
local = self.semantic.to_tensor(local)
|
|
1172
|
-
yields.append(local)
|
|
1237
|
+
yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)
|
|
1173
1238
|
|
|
1174
1239
|
# create YieldOp
|
|
1175
|
-
if len(
|
|
1176
|
-
yield_handles = flatten_values_to_ir(yields)
|
|
1240
|
+
if len(yield_handles) > 0:
|
|
1177
1241
|
self.builder.create_yield_op(yield_handles)
|
|
1178
1242
|
for_op_region = for_op_body.get_parent()
|
|
1179
1243
|
assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
|
|
@@ -1186,12 +1250,14 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1186
1250
|
iv = self.builder.create_add(iv, lb)
|
|
1187
1251
|
self.lscope[node.target.id].handle.replace_all_uses_with(iv)
|
|
1188
1252
|
self.set_value(node.target.id, language.core.tensor(iv, iv_type))
|
|
1253
|
+
self._maybe_set_loc_to_name(iv, node.target.id)
|
|
1189
1254
|
|
|
1190
1255
|
# update lscope & local_defs (ForOp defines new values)
|
|
1191
1256
|
result_handles = [for_op.get_result(i) for i in range(len(init_handles))]
|
|
1192
1257
|
result_values = unflatten_ir_values(result_handles, init_tys)
|
|
1193
1258
|
for name, val in zip(names, result_values):
|
|
1194
1259
|
self.set_value(name, val)
|
|
1260
|
+
self._maybe_set_loc_to_name(val, name)
|
|
1195
1261
|
|
|
1196
1262
|
for stmt in node.orelse:
|
|
1197
1263
|
assert False, "Don't know what to do with else after for"
|
|
@@ -1214,7 +1280,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1214
1280
|
msg = self.visit(node.msg) if node.msg is not None else ""
|
|
1215
1281
|
return language.core.device_assert(test, msg, _semantic=self.semantic)
|
|
1216
1282
|
|
|
1217
|
-
def call_JitFunction(self, fn: JITFunction, args, kwargs):
|
|
1283
|
+
def call_JitFunction(self, fn: JITFunction, args, kwargs, caller_context=None):
|
|
1218
1284
|
args = inspect.getcallargs(fn.fn, *args, **kwargs)
|
|
1219
1285
|
args = [args[name] for name in fn.arg_names]
|
|
1220
1286
|
for i, arg in enumerate(args):
|
|
@@ -1225,7 +1291,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1225
1291
|
args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
|
|
1226
1292
|
args_val = [get_iterable_path(args, path) for path in args_path]
|
|
1227
1293
|
# mangle
|
|
1228
|
-
|
|
1294
|
+
caller_context = caller_context or self.caller_context
|
|
1295
|
+
fn_name = mangle_fn(get_full_name(fn), [arg.type for arg in args_val], args_cst, caller_context)
|
|
1229
1296
|
# generate function def if necessary
|
|
1230
1297
|
if not self.module.has_function(fn_name):
|
|
1231
1298
|
# If the callee is not set, we use the same debug setting as the caller
|
|
@@ -1240,7 +1307,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1240
1307
|
function_name=fn_name, function_types=self.function_ret_types,
|
|
1241
1308
|
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
|
|
1242
1309
|
options=self.builder.options, codegen_fns=self.builder.codegen_fns,
|
|
1243
|
-
module_map=self.builder.module_map
|
|
1310
|
+
module_map=self.builder.module_map, caller_context=caller_context,
|
|
1311
|
+
is_gluon=self.is_gluon)
|
|
1244
1312
|
try:
|
|
1245
1313
|
generator.visit(fn.parse())
|
|
1246
1314
|
except Exception as e:
|
|
@@ -1261,32 +1329,23 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1261
1329
|
handles = [call_op.get_result(i) for i in range(call_op.get_num_results())]
|
|
1262
1330
|
return next(unflatten_ir_values(handles, [callee_ret_type]))
|
|
1263
1331
|
|
|
1264
|
-
def
|
|
1265
|
-
fn
|
|
1266
|
-
if not isinstance(fn, BoundJITMethod):
|
|
1267
|
-
static_implementation = self.statically_implemented_functions.get(fn)
|
|
1268
|
-
if static_implementation is not None:
|
|
1269
|
-
return static_implementation(self, node)
|
|
1270
|
-
|
|
1271
|
-
mur = getattr(fn, '_must_use_result', False)
|
|
1272
|
-
if mur and getattr(node, '_is_unused', False):
|
|
1273
|
-
error_message = ["The result of %s is not being used." % ast.unparse(node.func)]
|
|
1274
|
-
if isinstance(mur, str):
|
|
1275
|
-
error_message.append(mur)
|
|
1276
|
-
raise CompilationError(self.jit_fn.src, node, " ".join(error_message))
|
|
1277
|
-
|
|
1278
|
-
kws = dict(self.visit(keyword) for keyword in node.keywords)
|
|
1279
|
-
args = [self.visit(arg) for arg in node.args]
|
|
1280
|
-
args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
|
|
1281
|
-
if isinstance(fn, BoundJITMethod):
|
|
1332
|
+
def call_Function(self, node, fn, args, kws):
|
|
1333
|
+
if isinstance(fn, (BoundJITMethod, BoundConstexprFunction)):
|
|
1282
1334
|
args.insert(0, fn.__self__)
|
|
1283
1335
|
fn = fn.__func__
|
|
1284
1336
|
if isinstance(fn, JITFunction):
|
|
1285
1337
|
_check_fn_args(node, fn, args)
|
|
1286
1338
|
return self.call_JitFunction(fn, args, kws)
|
|
1287
|
-
if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn)
|
|
1288
|
-
|
|
1289
|
-
|
|
1339
|
+
if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn) or isinstance(
|
|
1340
|
+
fn, ConstexprFunction):
|
|
1341
|
+
extra_kwargs = dict()
|
|
1342
|
+
|
|
1343
|
+
if isinstance(fn, ConstexprFunction):
|
|
1344
|
+
sig = inspect.signature(fn.__call__)
|
|
1345
|
+
else:
|
|
1346
|
+
sig = inspect.signature(fn)
|
|
1347
|
+
if '_semantic' in sig.parameters:
|
|
1348
|
+
extra_kwargs["_semantic"] = self.semantic
|
|
1290
1349
|
if '_generator' in sig.parameters:
|
|
1291
1350
|
extra_kwargs['_generator'] = self
|
|
1292
1351
|
try:
|
|
@@ -1304,12 +1363,45 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1304
1363
|
# itself). But when calling a function, we raise as `from e` to
|
|
1305
1364
|
# preserve the traceback of the original error, which may e.g.
|
|
1306
1365
|
# be in core.py.
|
|
1307
|
-
raise CompilationError(self.jit_fn.src, node,
|
|
1366
|
+
raise CompilationError(self.jit_fn.src, node, str(e)) from e
|
|
1308
1367
|
|
|
1309
1368
|
if fn in self.builtin_namespace.values():
|
|
1310
1369
|
args = map(_unwrap_if_constexpr, args)
|
|
1311
1370
|
ret = fn(*args, **kws)
|
|
1312
|
-
|
|
1371
|
+
|
|
1372
|
+
def wrap_constexpr(x):
|
|
1373
|
+
if _is_triton_value(x):
|
|
1374
|
+
return x
|
|
1375
|
+
return constexpr(x)
|
|
1376
|
+
|
|
1377
|
+
if isinstance(ret, (builtins.tuple, language.tuple)):
|
|
1378
|
+
return _apply_to_tuple_values(ret, wrap_constexpr)
|
|
1379
|
+
return wrap_constexpr(ret)
|
|
1380
|
+
|
|
1381
|
+
def call_Method(self, node, fn, fn_self, args, kws):
|
|
1382
|
+
if isinstance(fn, JITFunction):
|
|
1383
|
+
args.insert(0, fn_self)
|
|
1384
|
+
return self.call_Function(node, fn, args, kws)
|
|
1385
|
+
|
|
1386
|
+
def visit_Call(self, node):
|
|
1387
|
+
fn = _unwrap_if_constexpr(self.visit(node.func))
|
|
1388
|
+
if not isinstance(fn, BoundJITMethod):
|
|
1389
|
+
static_implementation = self.statically_implemented_functions.get(fn)
|
|
1390
|
+
if static_implementation is not None:
|
|
1391
|
+
return static_implementation(self, node)
|
|
1392
|
+
|
|
1393
|
+
mur = getattr(fn, '_must_use_result', False)
|
|
1394
|
+
if mur and getattr(node, '_is_unused', False):
|
|
1395
|
+
error_message = ["The result of %s is not being used." % ast.unparse(node.func)]
|
|
1396
|
+
if isinstance(mur, str):
|
|
1397
|
+
error_message.append(mur)
|
|
1398
|
+
raise CompilationError(self.jit_fn.src, node, " ".join(error_message))
|
|
1399
|
+
|
|
1400
|
+
kws = dict(self.visit(keyword) for keyword in node.keywords)
|
|
1401
|
+
args = [self.visit(arg) for arg in node.args]
|
|
1402
|
+
args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
|
|
1403
|
+
|
|
1404
|
+
return self.call_Function(node, fn, args, kws)
|
|
1313
1405
|
|
|
1314
1406
|
def visit_Constant(self, node):
|
|
1315
1407
|
return constexpr(node.value)
|
|
@@ -1373,7 +1465,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1373
1465
|
if _is_triton_tensor(lhs) and node.attr == "T":
|
|
1374
1466
|
return self.semantic.permute(lhs, (1, 0))
|
|
1375
1467
|
# NOTE: special case ".value" for BC
|
|
1376
|
-
if isinstance(lhs, constexpr) and node.attr
|
|
1468
|
+
if isinstance(lhs, constexpr) and node.attr not in ("value", "type"):
|
|
1377
1469
|
lhs = lhs.value
|
|
1378
1470
|
attr = getattr(lhs, node.attr)
|
|
1379
1471
|
if _is_triton_value(lhs) and isinstance(attr, JITFunction):
|
|
@@ -1417,7 +1509,11 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1417
1509
|
last_loc = self.builder.get_loc()
|
|
1418
1510
|
self.cur_node = node
|
|
1419
1511
|
if hasattr(node, 'lineno') and hasattr(node, 'col_offset'):
|
|
1420
|
-
self.builder.
|
|
1512
|
+
here_loc = self.builder.create_loc(self.file_name, self.begin_line + node.lineno, node.col_offset)
|
|
1513
|
+
if self.name_loc_as_prefix is not None:
|
|
1514
|
+
self.builder.set_loc(self.builder.create_name_loc(self.name_loc_as_prefix, here_loc))
|
|
1515
|
+
else:
|
|
1516
|
+
self.builder.set_loc(here_loc)
|
|
1421
1517
|
last_loc = self.builder.get_loc()
|
|
1422
1518
|
try:
|
|
1423
1519
|
ret = super().visit(node)
|
|
@@ -1486,9 +1582,16 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1486
1582
|
|
|
1487
1583
|
def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None):
|
|
1488
1584
|
arg_types = [None] * len(fn.arg_names)
|
|
1489
|
-
|
|
1490
|
-
|
|
1491
|
-
|
|
1585
|
+
const_iter = iter(src.constants.items())
|
|
1586
|
+
kc, vc = next(const_iter, (None, None))
|
|
1587
|
+
|
|
1588
|
+
for i, (ks, v) in enumerate(src.signature.items()):
|
|
1589
|
+
idx = fn.arg_names.index(ks)
|
|
1590
|
+
cexpr = None
|
|
1591
|
+
if kc is not None and kc[0] == i:
|
|
1592
|
+
cexpr = vc
|
|
1593
|
+
kc, vc = next(const_iter, (None, None))
|
|
1594
|
+
arg_types[idx] = str_to_ty(v, cexpr)
|
|
1492
1595
|
prototype = ASTFunction([], arg_types, src.constants, src.attrs)
|
|
1493
1596
|
file_name, begin_line = get_jit_fn_file_line(fn)
|
|
1494
1597
|
# query function representation
|
|
@@ -1499,9 +1602,13 @@ def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None)
|
|
|
1499
1602
|
proxy = namedtuple("SpecializationProxy", ["constants", "signature"])(constants, signature)
|
|
1500
1603
|
generator = CodeGenerator(context, prototype, gscope=fn.get_capture_scope(), function_name=fn.repr(proxy),
|
|
1501
1604
|
jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
|
|
1502
|
-
codegen_fns=codegen_fns, module_map=module_map, module=module)
|
|
1605
|
+
codegen_fns=codegen_fns, module_map=module_map, module=module, is_gluon=fn.is_gluon())
|
|
1503
1606
|
generator.visit(fn.parse())
|
|
1504
|
-
|
|
1607
|
+
module = generator.module
|
|
1505
1608
|
# module takes ownership of the context
|
|
1506
|
-
|
|
1507
|
-
|
|
1609
|
+
module.context = context
|
|
1610
|
+
if not module.verify_with_diagnostics():
|
|
1611
|
+
if not fn.is_gluon():
|
|
1612
|
+
print(module)
|
|
1613
|
+
raise RuntimeError("error encountered during parsing")
|
|
1614
|
+
return module
|