triton-windows 3.2.0.post11__cp312-cp312-win_amd64.whl → 3.3.0a0.post11__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +3 -3
- triton/_internal_testing.py +59 -4
- triton/_utils.py +35 -0
- triton/backends/amd/compiler.py +121 -74
- triton/backends/amd/driver.py +77 -43
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
- triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
- triton/backends/amd/include/hip/hip_ext.h +4 -2
- triton/backends/amd/include/hip/hip_fp8.h +33 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
- triton/backends/amd/include/hip/hip_version.h +3 -3
- triton/backends/amd/include/hip/hiprtc.h +25 -25
- triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
- triton/backends/amd/include/hsa/hsa.h +11 -2
- triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/compiler.py +25 -225
- triton/backends/driver.py +7 -2
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +135 -90
- triton/backends/nvidia/driver.c +0 -1
- triton/backends/nvidia/driver.py +135 -49
- triton/backends/nvidia/include/cuda.h +2162 -241
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +2 -2
- triton/compiler/code_generator.py +334 -231
- triton/compiler/compiler.py +77 -66
- triton/language/__init__.py +22 -5
- triton/language/core.py +448 -74
- triton/language/extra/cuda/_experimental_tma.py +3 -5
- triton/language/math.py +1 -1
- triton/language/random.py +2 -1
- triton/language/semantic.py +206 -52
- triton/language/standard.py +35 -18
- triton/runtime/_allocation.py +32 -0
- triton/runtime/autotuner.py +27 -32
- triton/runtime/build.py +1 -48
- triton/runtime/cache.py +6 -6
- triton/runtime/errors.py +10 -0
- triton/runtime/interpreter.py +179 -45
- triton/runtime/jit.py +149 -190
- triton/testing.py +39 -11
- triton/tools/compile.py +27 -20
- triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
- triton/tools/mxfp.py +301 -0
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/METADATA +5 -2
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/RECORD +68 -59
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/top_level.txt +2 -0
- /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/WHEEL +0 -0
|
@@ -1,23 +1,35 @@
|
|
|
1
1
|
import ast
|
|
2
2
|
import inspect
|
|
3
3
|
import re
|
|
4
|
-
import sys
|
|
5
4
|
import warnings
|
|
6
5
|
import os
|
|
7
6
|
import textwrap
|
|
8
|
-
|
|
7
|
+
import itertools
|
|
8
|
+
from types import ModuleType
|
|
9
|
+
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, Iterable, List
|
|
10
|
+
|
|
9
11
|
from .. import language
|
|
10
12
|
from .._C.libtriton import ir
|
|
11
|
-
from ..language import constexpr,
|
|
12
|
-
from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type,
|
|
13
|
-
from ..runtime.jit import
|
|
13
|
+
from ..language import constexpr, semantic, str_to_ty, tensor
|
|
14
|
+
from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, base_value, base_type
|
|
15
|
+
from ..runtime.jit import get_jit_fn_file_line
|
|
14
16
|
# ideally we wouldn't need any runtime component
|
|
15
17
|
from ..runtime import JITFunction
|
|
18
|
+
from .._utils import find_paths_if, get_iterable_path, set_iterable_path
|
|
19
|
+
|
|
16
20
|
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
|
|
17
|
-
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def check_identifier_legality(name, type):
    """Validate that *name* is a plain ASCII identifier and return it unchanged.

    The whole string must match ``[a-zA-Z_][a-zA-Z0-9_]*``. ``re.fullmatch`` is
    used instead of ``re.match`` with a ``$`` anchor: ``$`` also matches just
    before a trailing newline, which would let a name ending in a newline slip
    through.

    :param name: candidate identifier (used for kernel function names).
    :param type: kind of identifier (e.g. ``"function"``); only used in the
        error message. The parameter name shadows the builtin but is kept for
        caller compatibility.
    :return: *name*, unchanged.
    :raises CompilationError: if *name* is not a legal identifier.
    """
    if not re.fullmatch(r'[a-zA-Z_][a-zA-Z0-9_]*', name):
        raise CompilationError(f"invalid {type} identifier: {name}", name)
    return name
|
|
18
28
|
|
|
19
29
|
|
|
20
30
|
def mangle_ty(ty):
|
|
31
|
+
if ty.is_tuple():
|
|
32
|
+
return 'T' + '_'.join(map(mangle_ty, ty.types)) + 'T'
|
|
21
33
|
if ty.is_ptr():
|
|
22
34
|
return 'P' + mangle_ty(ty.element_ty)
|
|
23
35
|
if ty.is_int():
|
|
@@ -48,7 +60,7 @@ def mangle_fn(name, arg_tys, constants):
|
|
|
48
60
|
|
|
49
61
|
|
|
50
62
|
def _is_triton_value(o: Any) -> bool:
|
|
51
|
-
return isinstance(o,
|
|
63
|
+
return isinstance(o, base_value)
|
|
52
64
|
|
|
53
65
|
|
|
54
66
|
def _is_triton_tensor(o: Any) -> bool:
|
|
@@ -56,7 +68,7 @@ def _is_triton_tensor(o: Any) -> bool:
|
|
|
56
68
|
|
|
57
69
|
|
|
58
70
|
def _is_constexpr(o: Any) -> bool:
|
|
59
|
-
return isinstance(o, constexpr)
|
|
71
|
+
return o is None or isinstance(o, (constexpr, language.core.dtype))
|
|
60
72
|
|
|
61
73
|
|
|
62
74
|
def _is_triton_scalar(o: Any) -> bool:
|
|
@@ -77,6 +89,38 @@ def _check_fn_args(node, fn, args):
|
|
|
77
89
|
)
|
|
78
90
|
|
|
79
91
|
|
|
92
|
+
def _is_namedtuple(val):
|
|
93
|
+
return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields")
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _apply_to_tuple_values(value, fn):
    """Return a new ``language.tuple`` whose elements are ``fn(element)``.

    Field names are carried over from *value*: a namedtuple instance supplies
    its ``_fields``; a ``language.tuple`` supplies ``value.type.fields``. The
    result's ``language.tuple_type`` is rebuilt from the ``.type`` of each
    mapped element, so *fn* must return values exposing ``.type``.

    :raises TypeError: if *value* is neither a namedtuple instance nor a
        ``language.tuple``.
    """
    if _is_namedtuple(type(value)):
        fields = value._fields
    elif isinstance(value, language.tuple):
        fields = value.type.fields
    else:
        # Explicit raise rather than ``assert False``: asserts are stripped
        # under ``python -O``, which would leave ``fields`` unbound below.
        raise TypeError(f"Unsupported type {type(value)}")

    vals = [fn(v) for v in value]
    types = [v.type for v in vals]
    return language.tuple(vals, language.tuple_type(types, fields))
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def flatten_values_to_ir(values: Iterable[base_value]):
    """Concatenate the IR handles of *values* into a single list.

    Every value is asked to contribute its handle(s) via ``_flatten_ir`` on a
    shared accumulator list; the values are visited in iteration order.
    """
    flat_handles = []
    for item in values:
        item._flatten_ir(flat_handles)
    return flat_handles
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def unflatten_ir_values(handles: List[ir.value], types: List[base_type]):
    """Rebuild frontend values from a flat list of IR handles.

    Each type in *types* consumes its share of *handles* through
    ``ty._unflatten_ir(handles, cursor)``, which returns the rebuilt value and
    the advanced cursor. Values are yielded in the order of *types*.

    All values are rebuilt before the first one is yielded so that the
    "every handle was consumed" check runs as soon as iteration starts.
    Callers iterate this via ``zip(names, unflatten_ir_values(...))``; ``zip``
    stops after the last name without resuming the generator, so a check placed
    after the final ``yield`` would never execute.
    """
    values = []
    cursor = 0
    for ty in types:
        value, cursor = ty._unflatten_ir(handles, cursor)
        values.append(value)
    assert cursor == len(handles), \
        f"unflattened {cursor} IR handles but {len(handles)} were provided"
    yield from values
|
|
122
|
+
|
|
123
|
+
|
|
80
124
|
_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels
|
|
81
125
|
|
|
82
126
|
|
|
@@ -189,11 +233,70 @@ class ContainsReturnChecker(ast.NodeVisitor):
|
|
|
189
233
|
return self.visit(node.func)
|
|
190
234
|
|
|
191
235
|
|
|
236
|
+
class ASTFunction:
    """Signature of a function being lowered to IR.

    ``arg_types`` may be nested (``list``/``tuple``/``language.tuple_type``);
    every argument leaf is addressed by its *path* into that nesting. A leaf is
    passed as an IR function argument unless its path is a key of
    ``constants`` or the leaf itself is ``None``; IR argument ``i`` is the
    ``i``-th such path in ``find_paths_if`` order.

    :param ret_types: frontend return types; ``None`` entries produce no IR result.
    :param arg_types: (possibly nested) frontend argument types.
    :param constants: maps argument path -> compile-time value.
    :param attrs: maps argument path -> iterable of ``(attr_name, attr_value)``.
    """

    def __init__(self, ret_types, arg_types, constants, attrs):
        self.ret_types = ret_types
        self.arg_types = arg_types
        self.constants = constants
        self.attrs = attrs

    def _value_paths(self):
        """Return, in IR-argument order, the paths of arguments passed as IR values.

        Shared by ``serialize`` and ``deserialize`` so both agree on which IR
        argument index each path maps to.
        """

        def is_val(path, ty):
            return path not in self.constants and ty is not None

        return list(find_paths_if(self.arg_types, is_val))

    def return_types_ir(self, builder: ir.builder):
        """Lower ``ret_types`` to a flat list of IR types.

        ``None`` entries are skipped; a type whose ``to_ir`` returns a list
        (presumably a tuple type) contributes each of its IR types in order.
        """
        ret_types = []
        for ret_ty in self.ret_types:
            if ret_ty is None:
                continue
            ir_ty = ret_ty.to_ir(builder)
            if isinstance(ir_ty, list):
                ret_types.extend(ir_ty)
            else:
                ret_types.append(ir_ty)
        return ret_types

    def serialize(self, builder: ir.builder):
        """Build the IR function type for this signature.

        Argument types come from the value paths (constexpr and ``None`` leaves
        are omitted); result types come from ``return_types_ir``.
        """
        val_paths = self._value_paths()
        arg_types = [get_iterable_path(self.arg_types, path).to_ir(builder) for path in val_paths]
        ret_types = self.return_types_ir(builder)
        return builder.get_function_ty(arg_types, ret_types)

    def deserialize(self, fn):
        """Wrap the IR arguments of *fn* as frontend values nested like ``arg_types``.

        Side effects on *fn*: applies ``attrs`` entries to the matching IR
        arguments (entries on constexpr/``None`` paths are dropped) and sets
        ``tt.nv_tma_desc = 1`` on arguments of ``nv_tma_desc_type``.

        In the returned template, value paths hold ``language.tensor`` wrappers
        of ``fn.args(i)``, constexpr paths hold ``language.constexpr(value)``,
        and any remaining leaves stay ``language.constexpr(None)``.
        """

        def make_template(ty):
            # Mirror the nesting of ``arg_types``; leaves start as constexpr(None)
            # and are overwritten below for value and constexpr paths.
            if isinstance(ty, (list, tuple, language.tuple_type)):
                return language.tuple([make_template(x) for x in ty], ty)
            return language.constexpr(None)

        vals = make_template(self.arg_types)
        val_paths = self._value_paths()
        # path -> IR argument index, first occurrence wins (like ``list.index``),
        # giving O(1) lookup per attribute path instead of a linear scan.
        arg_index = {}
        for i, path in enumerate(val_paths):
            arg_index.setdefault(path, i)
        # > set attributes
        for attr_path, attr_specs in self.attrs.items():
            idx = arg_index.get(attr_path)
            for attr_name, attr_val in attr_specs:
                if idx is not None:
                    fn.set_arg_attr(idx, attr_name, attr_val)
        # > mark pass-by-value TMA descriptors
        for i, path in enumerate(val_paths):
            ty = get_iterable_path(self.arg_types, path)
            if isinstance(ty, nv_tma_desc_type):
                fn.set_arg_attr(i, "tt.nv_tma_desc", 1)
        # > add IR values to the template
        for i, path in enumerate(val_paths):
            ty = get_iterable_path(self.arg_types, path)
            set_iterable_path(vals, path, language.tensor(fn.args(i), ty))
        # > add constexpr values to the template
        for path, val in self.constants.items():
            set_iterable_path(vals, path, language.constexpr(val))
        return vals
|
|
293
|
+
|
|
294
|
+
|
|
192
295
|
class CodeGenerator(ast.NodeVisitor):
|
|
193
296
|
|
|
194
|
-
def __init__(self, context, prototype, gscope,
|
|
195
|
-
|
|
196
|
-
|
|
297
|
+
def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns, module_map,
|
|
298
|
+
module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False,
|
|
299
|
+
file_name: Optional[str] = None, begin_line=0):
|
|
197
300
|
self.context = context
|
|
198
301
|
self.builder = ir.builder(context)
|
|
199
302
|
self.file_name = file_name
|
|
@@ -223,9 +326,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
223
326
|
self.gscope[k] = v
|
|
224
327
|
|
|
225
328
|
self.lscope = {}
|
|
226
|
-
self.attributes = attributes
|
|
227
|
-
self.constants = constants
|
|
228
329
|
self.jit_fn = jit_fn
|
|
330
|
+
# TODO: we currently generate illegal names for non-kernel functions involving constexprs!
|
|
331
|
+
if is_kernel:
|
|
332
|
+
function_name = check_identifier_legality(function_name, "function")
|
|
229
333
|
self.function_name = function_name
|
|
230
334
|
self.is_kernel = is_kernel
|
|
231
335
|
self.cur_node = None
|
|
@@ -260,9 +364,6 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
260
364
|
if _is_constexpr(val):
|
|
261
365
|
return True
|
|
262
366
|
|
|
263
|
-
if a := self.gscope.get("__annotations__", {}).get(name):
|
|
264
|
-
return _normalize_ty(a) == "constexpr"
|
|
265
|
-
|
|
266
367
|
return False
|
|
267
368
|
|
|
268
369
|
def _define_name_lookup(self):
|
|
@@ -283,6 +384,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
283
384
|
getattr(val, "__triton_builtin__", False), #
|
|
284
385
|
getattr(val, "__module__", "").startswith("triton.language"), #
|
|
285
386
|
isinstance(val, language.dtype), #
|
|
387
|
+
_is_namedtuple(val),
|
|
286
388
|
self._is_constexpr_global(name), #
|
|
287
389
|
# Allow accesses to globals while visiting an ast.arg
|
|
288
390
|
# because you should be able to do
|
|
@@ -295,8 +397,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
295
397
|
textwrap.dedent(f"""\
|
|
296
398
|
Cannot access global variable {name} from within @jit'ed
|
|
297
399
|
function. Triton kernels can only access global variables that
|
|
298
|
-
are
|
|
299
|
-
|
|
400
|
+
are instanstiated as constexpr (`x = triton.language.constexpr(42)`). Note that this is different from
|
|
401
|
+
annotating a variable as constexpr (`x: triton.language.constexpr = 42`), which is not supported. Alternatively, set the
|
|
300
402
|
envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not
|
|
301
403
|
promise to support this forever.""").replace("\n", " "))
|
|
302
404
|
|
|
@@ -312,7 +414,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
312
414
|
|
|
313
415
|
return name_lookup
|
|
314
416
|
|
|
315
|
-
def set_value(self, name: str, value: Union[
|
|
417
|
+
def set_value(self, name: str, value: Union[base_value, constexpr]) -> None:
|
|
316
418
|
''' This function:
|
|
317
419
|
called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
|
|
318
420
|
1. record local defined name (FIXME: should consider control flow)
|
|
@@ -342,7 +444,6 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
342
444
|
stmts = [stmts]
|
|
343
445
|
for stmt in stmts:
|
|
344
446
|
self.visit(stmt)
|
|
345
|
-
|
|
346
447
|
# Stop parsing as soon as we hit a `return` statement; everything
|
|
347
448
|
# after this is dead code.
|
|
348
449
|
if isinstance(stmt, ast.Return):
|
|
@@ -354,25 +455,30 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
354
455
|
def visit_List(self, node):
|
|
355
456
|
ctx = self.visit(node.ctx)
|
|
356
457
|
assert ctx is None
|
|
357
|
-
elts = [self.visit(elt) for elt in node.elts]
|
|
458
|
+
elts = language.tuple([self.visit(elt) for elt in node.elts])
|
|
358
459
|
return elts
|
|
359
460
|
|
|
360
461
|
# By design, only non-kernel functions can return
|
|
361
462
|
def visit_Return(self, node):
|
|
362
463
|
ret_value = self.visit(node.value)
|
|
464
|
+
handles = []
|
|
465
|
+
|
|
466
|
+
def decay(value):
|
|
467
|
+
if isinstance(value, language.tuple):
|
|
468
|
+
return _apply_to_tuple_values(value, decay)
|
|
469
|
+
elif isinstance(value, (language.constexpr, int, float)):
|
|
470
|
+
return semantic.to_tensor(value, self.builder)
|
|
471
|
+
return value
|
|
472
|
+
|
|
473
|
+
ret_value = decay(ret_value)
|
|
474
|
+
|
|
363
475
|
if ret_value is None:
|
|
364
|
-
self.builder.ret([])
|
|
365
476
|
ret_ty = language.void
|
|
366
|
-
elif isinstance(ret_value, tuple):
|
|
367
|
-
ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value]
|
|
368
|
-
ret_types = [v.type for v in ret_values]
|
|
369
|
-
self.builder.ret([v.handle for v in ret_values])
|
|
370
|
-
ret_ty = tuple(ret_types)
|
|
371
477
|
else:
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
ret_ty =
|
|
375
|
-
|
|
478
|
+
assert isinstance(ret_value, language.core.base_value)
|
|
479
|
+
ret_value._flatten_ir(handles)
|
|
480
|
+
ret_ty = ret_value.type
|
|
481
|
+
self.builder.ret(handles)
|
|
376
482
|
if self.ret_type is None:
|
|
377
483
|
self.ret_type = ret_ty
|
|
378
484
|
elif self.ret_type != ret_ty:
|
|
@@ -383,6 +489,11 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
383
489
|
post_ret_block = self.builder.create_block()
|
|
384
490
|
self.builder.set_insertion_point_to_end(post_ret_block)
|
|
385
491
|
|
|
492
|
+
def visit_Starred(self, node) -> Any:
    """Expand ``*expr``: evaluate *expr* and hand back the elements of the Triton tuple it yields."""
    starred = self.visit(node.value)
    assert isinstance(starred, language.core.tuple)
    return starred.values
|
|
496
|
+
|
|
386
497
|
def visit_FunctionDef(self, node):
|
|
387
498
|
arg_names, kwarg_names = self.visit(node.args)
|
|
388
499
|
if self.fn:
|
|
@@ -397,7 +508,6 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
397
508
|
init_node = ast.Assign(targets=[st_target], value=default_value)
|
|
398
509
|
else:
|
|
399
510
|
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
|
|
400
|
-
|
|
401
511
|
try:
|
|
402
512
|
assert not self.visiting_arg_default_value
|
|
403
513
|
self.visiting_arg_default_value = True
|
|
@@ -407,34 +517,15 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
407
517
|
|
|
408
518
|
# initialize function
|
|
409
519
|
visibility = "public" if self.is_kernel else "private"
|
|
410
|
-
|
|
411
|
-
|
|
520
|
+
fn_ty = self.prototype.serialize(self.builder)
|
|
521
|
+
self.fn = self.builder.get_or_insert_function(self.module, self.function_name, fn_ty, visibility, self.noinline)
|
|
412
522
|
self.module.push_back(self.fn)
|
|
413
523
|
entry = self.fn.add_entry_block()
|
|
414
|
-
arg_values =
|
|
415
|
-
|
|
416
|
-
for i in range(len(arg_names)):
|
|
417
|
-
if i in self.constants:
|
|
418
|
-
cst = self.constants[i]
|
|
419
|
-
if not _is_constexpr(cst):
|
|
420
|
-
cst = constexpr(self.constants[i])
|
|
421
|
-
arg_values.append(cst)
|
|
422
|
-
continue
|
|
423
|
-
else:
|
|
424
|
-
if i in self.attributes:
|
|
425
|
-
for name, value in self.attributes[i]:
|
|
426
|
-
self.fn.set_arg_attr(idx, name, value)
|
|
427
|
-
|
|
428
|
-
# Mark this argument as a pass-by-value TMA descriptor (nvidia)
|
|
429
|
-
if isinstance(self.prototype.param_types[idx], nv_tma_desc_type):
|
|
430
|
-
self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1)
|
|
431
|
-
|
|
432
|
-
arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx]))
|
|
433
|
-
idx += 1
|
|
434
|
-
|
|
435
|
-
insert_pt = self.builder.get_insertion_block()
|
|
524
|
+
arg_values = self.prototype.deserialize(self.fn)
|
|
525
|
+
# bind arguments to symbols
|
|
436
526
|
for arg_name, arg_value in zip(arg_names, arg_values):
|
|
437
527
|
self.set_value(arg_name, arg_value)
|
|
528
|
+
insert_pt = self.builder.get_insertion_block()
|
|
438
529
|
self.builder.set_insertion_point_to_start(entry)
|
|
439
530
|
# visit function body
|
|
440
531
|
self.visit_compound_statement(node.body)
|
|
@@ -445,13 +536,12 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
445
536
|
self.ret_type = language.void
|
|
446
537
|
self.builder.ret([])
|
|
447
538
|
else:
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
self.
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
])
|
|
539
|
+
if isinstance(self.ret_type, language.tuple_type):
|
|
540
|
+
self.prototype.ret_types = self.ret_type.types
|
|
541
|
+
else:
|
|
542
|
+
self.prototype.ret_types = [self.ret_type]
|
|
543
|
+
self.fn.reset_type(self.prototype.serialize(self.builder))
|
|
544
|
+
self.builder.ret([self.builder.create_poison(ty) for ty in self.prototype.return_types_ir(self.builder)])
|
|
455
545
|
self.fn.finalize()
|
|
456
546
|
|
|
457
547
|
if insert_pt:
|
|
@@ -478,37 +568,41 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
478
568
|
if target in self.lscope:
|
|
479
569
|
raise ValueError(f'{target} is already defined.'
|
|
480
570
|
f' constexpr cannot be reassigned.')
|
|
481
|
-
|
|
482
|
-
value = constexpr(value)
|
|
571
|
+
value = constexpr(value)
|
|
483
572
|
self.lscope[target] = value
|
|
484
573
|
return self.lscope[target]
|
|
485
574
|
# default: call visit_Assign
|
|
486
575
|
return self.visit_Assign(node)
|
|
487
576
|
|
|
577
|
+
def assignTarget(self, target, value):
    """Bind *value* to an assignment target.

    - ``ast.Subscript`` (``t[i] = value``) is delegated to ``visit_Subscript_Store``.
    - ``ast.Tuple`` (``a, b = value``) unpacks ``value.values`` element-wise and
      recurses, so nested targets such as ``a, (b, c) = ...`` or ``a, t[0] = ...``
      are handled exactly like top-level ones.
    - ``ast.Name`` binds the name through ``set_value``.

    :raises ValueError: if a tuple target and *value* differ in length (instead
        of an ``IndexError`` or silently dropping surplus values).
    """
    if isinstance(target, ast.Subscript):
        assert target.ctx.__class__.__name__ == "Store"
        return self.visit_Subscript_Store(target, value)
    if isinstance(target, ast.Tuple):
        assert target.ctx.__class__.__name__ == "Store"
        elements = value.values
        if len(target.elts) != len(elements):
            raise ValueError(f"cannot unpack {len(elements)} values into "
                             f"{len(target.elts)} assignment targets")
        for elt, elt_value in zip(target.elts, elements):
            self.assignTarget(elt, elt_value)
        return
    assert isinstance(target, ast.Name)
    self.set_value(self.visit(target), value)
|
|
588
|
+
|
|
488
589
|
def visit_Assign(self, node):
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
_names += [self.visit(target)]
|
|
495
|
-
if len(_names) > 1:
|
|
496
|
-
raise self._unsupported(node, "simultaneous multiple assignment is not supported.")
|
|
497
|
-
names = _names[0]
|
|
498
|
-
values = self.visit(node.value)
|
|
499
|
-
if not _is_list_like(names):
|
|
500
|
-
names = [names]
|
|
501
|
-
if not _is_list_like(values):
|
|
502
|
-
values = [values]
|
|
503
|
-
native_nontensor_types = (language.dtype, )
|
|
504
|
-
for name, value in zip(names, values):
|
|
505
|
-
# by default, constexpr are assigned into python variable
|
|
590
|
+
# construct values to assign
|
|
591
|
+
def _sanitize_value(value):
|
|
592
|
+
if isinstance(value, language.tuple):
|
|
593
|
+
return _apply_to_tuple_values(value, _sanitize_value)
|
|
594
|
+
native_nontensor_types = (language.dtype, language.tuple)
|
|
506
595
|
value = _unwrap_if_constexpr(value)
|
|
507
596
|
if value is not None and \
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
value =
|
|
511
|
-
|
|
597
|
+
not _is_triton_value(value) and \
|
|
598
|
+
not isinstance(value, native_nontensor_types):
|
|
599
|
+
value = semantic.to_tensor(value, self.builder)
|
|
600
|
+
return value
|
|
601
|
+
|
|
602
|
+
values = _sanitize_value(self.visit(node.value))
|
|
603
|
+
targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets
|
|
604
|
+
assert len(targets) == 1
|
|
605
|
+
self.assignTarget(targets[0], values)
|
|
512
606
|
|
|
513
607
|
def visit_AugAssign(self, node):
|
|
514
608
|
name = node.target.id
|
|
@@ -531,7 +625,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
531
625
|
|
|
532
626
|
def visit_Tuple(self, node):
|
|
533
627
|
args = [self.visit(x) for x in node.elts]
|
|
534
|
-
return tuple(args)
|
|
628
|
+
return language.tuple(args)
|
|
535
629
|
|
|
536
630
|
def _apply_binary_method(self, method_name, lhs, rhs):
|
|
537
631
|
# TODO: raise something meaningful if getattr fails below, esp for reverse method
|
|
@@ -584,21 +678,17 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
584
678
|
|
|
585
679
|
# update block arguments
|
|
586
680
|
names = []
|
|
587
|
-
ret_types = []
|
|
588
|
-
ir_ret_types = []
|
|
589
681
|
# variables in livein whose value is updated in `if`
|
|
590
682
|
for name in liveins:
|
|
591
683
|
# check type
|
|
592
684
|
for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]:
|
|
593
685
|
if name in defs:
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
f'
|
|
686
|
+
type_equal = type(defs[name]) == type(liveins[name]) # noqa: E721
|
|
687
|
+
assert type_equal and defs[name].type == liveins[name].type, \
|
|
688
|
+
f'initial value for `{name}` is of type {liveins[name]}, '\
|
|
689
|
+
f'but the {block_name} block redefines it as {defs[name]}'
|
|
597
690
|
if name in then_defs or name in else_defs:
|
|
598
691
|
names.append(name)
|
|
599
|
-
ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type)
|
|
600
|
-
ir_ret_types.append(then_defs[name].handle.get_type() if name in
|
|
601
|
-
then_defs else else_defs[name].handle.get_type())
|
|
602
692
|
# variable defined in then but not in else
|
|
603
693
|
if name in then_defs and name not in else_defs:
|
|
604
694
|
else_defs[name] = liveins[name]
|
|
@@ -610,16 +700,17 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
610
700
|
for name in sorted(then_defs.keys() & else_defs.keys()):
|
|
611
701
|
if name in names:
|
|
612
702
|
continue
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
703
|
+
then_val = then_defs[name]
|
|
704
|
+
then_ty = then_val.type
|
|
705
|
+
else_val = else_defs[name]
|
|
706
|
+
else_ty = else_val.type
|
|
707
|
+
type_equal = type(then_val) == type(else_val) # noqa: E721
|
|
708
|
+
assert type_equal and then_ty == else_ty, \
|
|
616
709
|
f'Mismatched type for {name} between then block ({then_ty}) '\
|
|
617
710
|
f'and else block ({else_ty})'
|
|
618
711
|
names.append(name)
|
|
619
|
-
ret_types.append(then_ty)
|
|
620
|
-
ir_ret_types.append(then_defs[name].handle.get_type())
|
|
621
712
|
|
|
622
|
-
return then_defs, else_defs, then_block, else_block, names
|
|
713
|
+
return then_defs, else_defs, then_block, else_block, names
|
|
623
714
|
|
|
624
715
|
def visit_if_top_level(self, cond, node):
|
|
625
716
|
with enter_sub_region(self) as sr:
|
|
@@ -630,27 +721,34 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
630
721
|
self.builder.set_insertion_point_to_end(ip_block)
|
|
631
722
|
self.builder.create_cond_branch(cond.handle, then_block, else_block)
|
|
632
723
|
# visit then and else blocks
|
|
633
|
-
then_defs, else_defs, then_block, else_block, names
|
|
724
|
+
then_defs, else_defs, then_block, else_block, names = \
|
|
634
725
|
self.visit_then_else_blocks(node, liveins, then_block, else_block)
|
|
635
726
|
# create basic-block after conditional
|
|
636
727
|
endif_block = self.builder.create_block()
|
|
637
728
|
# then terminator
|
|
638
729
|
self.builder.set_insertion_point_to_end(then_block)
|
|
639
730
|
assert not then_block.has_terminator(), f"{then_block}"
|
|
640
|
-
|
|
731
|
+
then_handles = flatten_values_to_ir(then_defs[name] for name in names)
|
|
732
|
+
self.builder.create_branch(endif_block, then_handles)
|
|
641
733
|
# else terminator
|
|
642
734
|
self.builder.set_insertion_point_to_end(else_block)
|
|
643
735
|
assert not else_block.has_terminator(), f"{else_block}"
|
|
644
|
-
|
|
645
|
-
|
|
736
|
+
else_handles = flatten_values_to_ir(else_defs[name] for name in names)
|
|
737
|
+
self.builder.create_branch(endif_block, else_handles)
|
|
738
|
+
assert len(then_handles) == len(else_handles)
|
|
739
|
+
for then_h, else_h in zip(then_handles, else_handles):
|
|
740
|
+
ty = then_h.get_type()
|
|
741
|
+
assert ty == else_h.get_type()
|
|
646
742
|
endif_block.add_argument(ty)
|
|
647
743
|
|
|
648
744
|
# change block
|
|
649
745
|
self.builder.set_insertion_point_to_start(endif_block)
|
|
650
746
|
# update value
|
|
651
|
-
for i
|
|
652
|
-
|
|
653
|
-
|
|
747
|
+
res_handles = [endif_block.arg(i) for i in range(len(then_handles))]
|
|
748
|
+
types = [then_defs[name].type for name in names]
|
|
749
|
+
new_values = unflatten_ir_values(res_handles, types)
|
|
750
|
+
for name, new_value in zip(names, new_values):
|
|
751
|
+
self.set_value(name, new_value)
|
|
654
752
|
|
|
655
753
|
# TODO: refactor
|
|
656
754
|
def visit_if_scf(self, cond, node):
|
|
@@ -659,26 +757,30 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
659
757
|
ip, last_loc = self._get_insertion_point_and_loc()
|
|
660
758
|
then_block = self.builder.create_block()
|
|
661
759
|
else_block = self.builder.create_block() if node.orelse else None
|
|
662
|
-
then_defs, else_defs, then_block, else_block, names
|
|
760
|
+
then_defs, else_defs, then_block, else_block, names = \
|
|
663
761
|
self.visit_then_else_blocks(node, liveins, then_block, else_block)
|
|
664
762
|
# create if op
|
|
763
|
+
then_handles = flatten_values_to_ir(then_defs[name] for name in names)
|
|
665
764
|
self._set_insertion_point_and_loc(ip, last_loc)
|
|
666
|
-
if_op = self.builder.create_if_op([
|
|
765
|
+
if_op = self.builder.create_if_op([h.get_type() for h in then_handles], cond.handle, True)
|
|
667
766
|
then_block.merge_block_before(if_op.get_then_block())
|
|
668
767
|
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
|
669
768
|
if len(names) > 0:
|
|
670
|
-
self.builder.create_yield_op(
|
|
769
|
+
self.builder.create_yield_op(then_handles)
|
|
671
770
|
if not node.orelse:
|
|
672
771
|
else_block = if_op.get_else_block()
|
|
673
772
|
else:
|
|
674
773
|
else_block.merge_block_before(if_op.get_else_block())
|
|
675
774
|
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
|
676
775
|
if len(names) > 0:
|
|
677
|
-
|
|
776
|
+
else_handles = flatten_values_to_ir(else_defs[name] for name in names)
|
|
777
|
+
self.builder.create_yield_op(else_handles)
|
|
678
778
|
# update values
|
|
679
|
-
for i
|
|
680
|
-
|
|
681
|
-
|
|
779
|
+
res_handles = [if_op.get_result(i) for i in range(len(then_handles))]
|
|
780
|
+
types = [then_defs[name].type for name in names]
|
|
781
|
+
new_values = unflatten_ir_values(res_handles, types)
|
|
782
|
+
for name, new_value in zip(names, new_values):
|
|
783
|
+
self.set_value(name, new_value)
|
|
682
784
|
|
|
683
785
|
def visit_If(self, node):
|
|
684
786
|
cond = self.visit(node.test)
|
|
@@ -717,14 +819,14 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
717
819
|
|
|
718
820
|
then_block = self.builder.create_block()
|
|
719
821
|
self.builder.set_insertion_point_to_start(then_block)
|
|
720
|
-
then_val =
|
|
822
|
+
then_val = semantic.to_tensor(self.visit(node.body), self.builder)
|
|
721
823
|
then_block = self.builder.get_insertion_block()
|
|
722
824
|
|
|
723
825
|
else_block = self.builder.create_block()
|
|
724
826
|
self.builder.set_insertion_point_to_start(else_block)
|
|
725
827
|
# do not need to reset lscope since
|
|
726
828
|
# ternary expressions cannot define new variables
|
|
727
|
-
else_val =
|
|
829
|
+
else_val = semantic.to_tensor(self.visit(node.orelse), self.builder)
|
|
728
830
|
else_block = self.builder.get_insertion_block()
|
|
729
831
|
|
|
730
832
|
self._set_insertion_point_and_loc(ip, last_loc)
|
|
@@ -804,7 +906,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
804
906
|
def _verify_loop_carried_variable(self, name, loop_val, live_val):
|
|
805
907
|
assert _is_triton_value(loop_val), f'cannot reassign constxpr {name} in the loop'
|
|
806
908
|
assert _is_triton_value(live_val), f'cannot reasign constexpr {name} in the loop'
|
|
807
|
-
assert type(loop_val)
|
|
909
|
+
assert type(loop_val) is type(live_val), f'Loop carried variable {name} changed type'
|
|
808
910
|
assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \
|
|
809
911
|
f'Loop-carried variable {name} has initial type {live_val.type} '\
|
|
810
912
|
f'but is re-assigned to {loop_val.type} in loop! '\
|
|
@@ -827,7 +929,6 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
827
929
|
|
|
828
930
|
# collect loop-carried values
|
|
829
931
|
names = []
|
|
830
|
-
ret_types = []
|
|
831
932
|
init_args = []
|
|
832
933
|
for name in loop_defs:
|
|
833
934
|
if name in liveins:
|
|
@@ -838,32 +939,35 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
838
939
|
|
|
839
940
|
# these are loop-carried values
|
|
840
941
|
names.append(name)
|
|
841
|
-
ret_types.append(loop_val.type)
|
|
842
942
|
init_args.append(live_val)
|
|
843
943
|
|
|
944
|
+
init_handles = flatten_values_to_ir(init_args)
|
|
945
|
+
init_tys = [h.get_type() for h in init_handles]
|
|
946
|
+
init_fe_tys = [a.type for a in init_args]
|
|
844
947
|
self._set_insertion_point_and_loc(ip, last_loc)
|
|
845
|
-
while_op = self.builder.create_while_op(
|
|
846
|
-
[arg.handle for arg in init_args])
|
|
948
|
+
while_op = self.builder.create_while_op(init_tys, init_handles)
|
|
847
949
|
# merge the condition region
|
|
848
|
-
before_block = self.builder.create_block_with_parent(while_op.get_before(),
|
|
849
|
-
[ty.to_ir(self.builder) for ty in ret_types])
|
|
950
|
+
before_block = self.builder.create_block_with_parent(while_op.get_before(), init_tys)
|
|
850
951
|
self.builder.set_insertion_point_to_start(before_block)
|
|
851
|
-
for i
|
|
852
|
-
|
|
853
|
-
|
|
952
|
+
block_args = [before_block.arg(i) for i in range(len(init_handles))]
|
|
953
|
+
condition_args = unflatten_ir_values(block_args, init_fe_tys)
|
|
954
|
+
for name, val in zip(names, condition_args):
|
|
955
|
+
self.lscope[name] = val
|
|
956
|
+
self.local_defs[name] = val
|
|
854
957
|
cond = self.visit(node.test)
|
|
855
958
|
self.builder.set_insertion_point_to_end(before_block)
|
|
856
959
|
# create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
|
|
857
|
-
self.builder.create_condition_op(cond.handle,
|
|
960
|
+
self.builder.create_condition_op(cond.handle, block_args)
|
|
858
961
|
# merge the loop body
|
|
859
|
-
after_block = self.builder.create_block_with_parent(while_op.get_after(),
|
|
860
|
-
[ty.to_ir(self.builder) for ty in ret_types])
|
|
962
|
+
after_block = self.builder.create_block_with_parent(while_op.get_after(), init_tys)
|
|
861
963
|
|
|
862
964
|
# generate loop body
|
|
863
965
|
self.builder.set_insertion_point_to_start(after_block)
|
|
864
|
-
for i
|
|
865
|
-
|
|
866
|
-
|
|
966
|
+
body_handles = [after_block.arg(i) for i in range(len(init_handles))]
|
|
967
|
+
body_args = unflatten_ir_values(body_handles, init_fe_tys)
|
|
968
|
+
for name, val in zip(names, body_args):
|
|
969
|
+
self.lscope[name] = val
|
|
970
|
+
self.local_defs[name] = val
|
|
867
971
|
self.scf_stack.append(node)
|
|
868
972
|
self.visit_compound_statement(node.body)
|
|
869
973
|
self.scf_stack.pop()
|
|
@@ -871,12 +975,14 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
871
975
|
yields = []
|
|
872
976
|
for name in loop_defs:
|
|
873
977
|
if name in liveins:
|
|
874
|
-
|
|
875
|
-
|
|
978
|
+
loop_defs[name]._flatten_ir(yields)
|
|
979
|
+
|
|
980
|
+
self.builder.create_yield_op(yields)
|
|
876
981
|
|
|
877
982
|
# WhileOp defines new values, update the symbol table (lscope, local_defs)
|
|
878
|
-
for i
|
|
879
|
-
|
|
983
|
+
result_handles = [while_op.get_result(i) for i in range(len(init_handles))]
|
|
984
|
+
result_vals = unflatten_ir_values(result_handles, init_fe_tys)
|
|
985
|
+
for name, new_def in zip(names, result_vals):
|
|
880
986
|
self.lscope[name] = new_def
|
|
881
987
|
self.local_defs[name] = new_def
|
|
882
988
|
|
|
@@ -884,7 +990,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
884
990
|
assert False, "Not implemented"
|
|
885
991
|
ast.NodeVisitor.generic_visit(self, stmt)
|
|
886
992
|
|
|
887
|
-
def
|
|
993
|
+
def visit_Subscript_Load(self, node):
|
|
888
994
|
assert node.ctx.__class__.__name__ == "Load"
|
|
889
995
|
lhs = self.visit(node.value)
|
|
890
996
|
slices = self.visit(node.slice)
|
|
@@ -892,6 +998,16 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
892
998
|
return lhs.__getitem__(slices, _builder=self.builder)
|
|
893
999
|
return lhs[slices]
|
|
894
1000
|
|
|
1001
|
+
def visit_Subscript_Store(self, node, value):
|
|
1002
|
+
assert node.ctx.__class__.__name__ == "Store"
|
|
1003
|
+
lhs = self.visit(node.value)
|
|
1004
|
+
slices = self.visit(node.slice)
|
|
1005
|
+
assert isinstance(lhs, language.tuple)
|
|
1006
|
+
lhs.__setitem__(slices, value)
|
|
1007
|
+
|
|
1008
|
+
def visit_Subscript(self, node):
|
|
1009
|
+
return self.visit_Subscript_Load(node)
|
|
1010
|
+
|
|
895
1011
|
def visit_ExtSlice(self, node):
|
|
896
1012
|
return [self.visit(dim) for dim in node.dims]
|
|
897
1013
|
|
|
@@ -910,6 +1026,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
910
1026
|
return
|
|
911
1027
|
num_stages = None
|
|
912
1028
|
loop_unroll_factor = None
|
|
1029
|
+
disallow_acc_multi_buffer = False
|
|
1030
|
+
flatten = False
|
|
913
1031
|
if IteratorClass is language.range:
|
|
914
1032
|
iterator = IteratorClass(*iter_args, **iter_kwargs)
|
|
915
1033
|
# visit iterator arguments
|
|
@@ -920,6 +1038,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
920
1038
|
step = iterator.step
|
|
921
1039
|
num_stages = iterator.num_stages
|
|
922
1040
|
loop_unroll_factor = iterator.loop_unroll_factor
|
|
1041
|
+
disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer
|
|
1042
|
+
flatten = iterator.flatten
|
|
923
1043
|
elif IteratorClass is range:
|
|
924
1044
|
# visit iterator arguments
|
|
925
1045
|
# note: only `range` iterator is supported now
|
|
@@ -935,14 +1055,14 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
935
1055
|
step = constexpr(-step.value)
|
|
936
1056
|
negative_step = True
|
|
937
1057
|
lb, ub = ub, lb
|
|
938
|
-
lb =
|
|
939
|
-
ub =
|
|
940
|
-
step =
|
|
1058
|
+
lb = semantic.to_tensor(lb, self.builder)
|
|
1059
|
+
ub = semantic.to_tensor(ub, self.builder)
|
|
1060
|
+
step = semantic.to_tensor(step, self.builder)
|
|
941
1061
|
# induction variable type
|
|
942
1062
|
if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int():
|
|
943
1063
|
raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})")
|
|
944
|
-
iv_type =
|
|
945
|
-
iv_type =
|
|
1064
|
+
iv_type = semantic.integer_promote_impl(lb.dtype, ub.dtype)
|
|
1065
|
+
iv_type = semantic.integer_promote_impl(iv_type, step.dtype)
|
|
946
1066
|
iv_ir_type = iv_type.to_ir(self.builder)
|
|
947
1067
|
iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED
|
|
948
1068
|
# lb/ub/step might be constexpr, we need to cast them to tensor
|
|
@@ -987,34 +1107,47 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
987
1107
|
|
|
988
1108
|
# create ForOp
|
|
989
1109
|
self._set_insertion_point_and_loc(ip, last_loc)
|
|
990
|
-
|
|
991
|
-
|
|
1110
|
+
init_handles = flatten_values_to_ir(init_args)
|
|
1111
|
+
init_tys = [v.type for v in init_args]
|
|
1112
|
+
for_op = self.builder.create_for_op(lb, ub, step, init_handles)
|
|
1113
|
+
if _unwrap_if_constexpr(num_stages) is not None:
|
|
992
1114
|
for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages))
|
|
993
|
-
if loop_unroll_factor is not None:
|
|
1115
|
+
if _unwrap_if_constexpr(loop_unroll_factor) is not None:
|
|
994
1116
|
for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor))
|
|
1117
|
+
if disallow_acc_multi_buffer:
|
|
1118
|
+
for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr())
|
|
1119
|
+
if flatten:
|
|
1120
|
+
for_op.set_attr("tt.flatten", self.builder.get_unit_attr())
|
|
995
1121
|
|
|
996
1122
|
self.scf_stack.append(node)
|
|
997
|
-
|
|
1123
|
+
for_op_body = for_op.get_body(0)
|
|
1124
|
+
self.builder.set_insertion_point_to_start(for_op_body)
|
|
998
1125
|
# reset local scope to not pick up local defs from the previous dry run.
|
|
999
1126
|
self.lscope = liveins.copy()
|
|
1000
1127
|
self.local_defs = {}
|
|
1001
|
-
for i
|
|
1002
|
-
|
|
1128
|
+
block_handles = [for_op_body.arg(i + 1) for i in range(len(init_handles))]
|
|
1129
|
+
block_args = unflatten_ir_values(block_handles, init_tys)
|
|
1130
|
+
for name, val in zip(names, block_args):
|
|
1131
|
+
self.set_value(name, val)
|
|
1003
1132
|
self.visit_compound_statement(node.body)
|
|
1004
1133
|
self.scf_stack.pop()
|
|
1005
1134
|
yields = []
|
|
1006
1135
|
for name in self.local_defs:
|
|
1007
1136
|
if name in liveins:
|
|
1008
|
-
|
|
1137
|
+
local = self.local_defs[name]
|
|
1138
|
+
if isinstance(local, constexpr):
|
|
1139
|
+
local = semantic.to_tensor(local, self.builder)
|
|
1140
|
+
yields.append(local)
|
|
1009
1141
|
|
|
1010
1142
|
# create YieldOp
|
|
1011
1143
|
if len(yields) > 0:
|
|
1012
|
-
|
|
1013
|
-
|
|
1144
|
+
yield_handles = flatten_values_to_ir(yields)
|
|
1145
|
+
self.builder.create_yield_op(yield_handles)
|
|
1146
|
+
for_op_region = for_op_body.get_parent()
|
|
1014
1147
|
assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
|
|
1015
1148
|
|
|
1016
1149
|
# update induction variable with actual value, and replace all uses
|
|
1017
|
-
self.builder.set_insertion_point_to_start(
|
|
1150
|
+
self.builder.set_insertion_point_to_start(for_op_body)
|
|
1018
1151
|
iv = for_op.get_induction_var()
|
|
1019
1152
|
if negative_step:
|
|
1020
1153
|
iv = self.builder.create_sub(ub, iv)
|
|
@@ -1023,8 +1156,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1023
1156
|
self.set_value(node.target.id, language.core.tensor(iv, iv_type))
|
|
1024
1157
|
|
|
1025
1158
|
# update lscope & local_defs (ForOp defines new values)
|
|
1026
|
-
for i
|
|
1027
|
-
|
|
1159
|
+
result_handles = [for_op.get_result(i) for i in range(len(init_handles))]
|
|
1160
|
+
result_values = unflatten_ir_values(result_handles, init_tys)
|
|
1161
|
+
for name, val in zip(names, result_values):
|
|
1162
|
+
self.set_value(name, val)
|
|
1028
1163
|
|
|
1029
1164
|
for stmt in node.orelse:
|
|
1030
1165
|
assert False, "Don't know what to do with else after for"
|
|
@@ -1034,7 +1169,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1034
1169
|
lower = self.visit(node.lower)
|
|
1035
1170
|
upper = self.visit(node.upper)
|
|
1036
1171
|
step = self.visit(node.step)
|
|
1037
|
-
return slice(lower, upper, step)
|
|
1172
|
+
return language.slice(lower, upper, step)
|
|
1038
1173
|
|
|
1039
1174
|
def visit_Index(self, node):
|
|
1040
1175
|
return self.visit(node.value)
|
|
@@ -1050,24 +1185,28 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1050
1185
|
def call_JitFunction(self, fn: JITFunction, args, kwargs):
|
|
1051
1186
|
args = inspect.getcallargs(fn.fn, *args, **kwargs)
|
|
1052
1187
|
args = [args[name] for name in fn.arg_names]
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
fn_name = mangle_fn(fn.__name__, arg_types, constants)
|
|
1188
|
+
for i, arg in enumerate(args):
|
|
1189
|
+
if isinstance(arg, (language.dtype, float, int, bool, JITFunction)):
|
|
1190
|
+
args[i] = language.core.constexpr(arg)
|
|
1191
|
+
args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x))
|
|
1192
|
+
args_cst = {path: get_iterable_path(args, path) for path in args_cst}
|
|
1193
|
+
args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
|
|
1194
|
+
args_val = [get_iterable_path(args, path) for path in args_path]
|
|
1195
|
+
# mangle
|
|
1196
|
+
fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst)
|
|
1063
1197
|
# generate function def if necessary
|
|
1064
1198
|
if not self.module.has_function(fn_name):
|
|
1065
|
-
prototype = language.function_type([], arg_types)
|
|
1066
1199
|
gscope = fn.__globals__
|
|
1067
1200
|
# If the callee is not set, we use the same debug setting as the caller
|
|
1068
1201
|
file_name, begin_line = get_jit_fn_file_line(fn)
|
|
1069
|
-
|
|
1070
|
-
|
|
1202
|
+
arg_types = [
|
|
1203
|
+
language.core.constexpr if arg is None or isinstance(arg,
|
|
1204
|
+
(bool, int, language.core.dtype)) else arg.type
|
|
1205
|
+
for arg in args
|
|
1206
|
+
]
|
|
1207
|
+
prototype = ASTFunction([], arg_types, args_cst, dict())
|
|
1208
|
+
generator = CodeGenerator(self.context, prototype, gscope, module=self.module, jit_fn=fn,
|
|
1209
|
+
function_name=fn_name, function_types=self.function_ret_types,
|
|
1071
1210
|
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
|
|
1072
1211
|
options=self.builder.options, codegen_fns=self.builder.codegen_fns,
|
|
1073
1212
|
module_map=self.builder.module_map)
|
|
@@ -1082,17 +1221,12 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1082
1221
|
else:
|
|
1083
1222
|
callee_ret_type = self.function_ret_types[fn_name]
|
|
1084
1223
|
symbol = self.module.get_function(fn_name)
|
|
1085
|
-
|
|
1086
|
-
|
|
1224
|
+
args_val = [arg.handle for arg in args_val]
|
|
1225
|
+
call_op = self.builder.call(symbol, args_val)
|
|
1226
|
+
if callee_ret_type == language.void:
|
|
1087
1227
|
return None
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
else:
|
|
1091
|
-
# should return a tuple of tl.tensor
|
|
1092
|
-
results = []
|
|
1093
|
-
for i in range(call_op.get_num_results()):
|
|
1094
|
-
results.append(tensor(call_op.get_result(i), callee_ret_type[i]))
|
|
1095
|
-
return tuple(results)
|
|
1228
|
+
handles = [call_op.get_result(i) for i in range(call_op.get_num_results())]
|
|
1229
|
+
return next(unflatten_ir_values(handles, [callee_ret_type]))
|
|
1096
1230
|
|
|
1097
1231
|
def visit_Call(self, node):
|
|
1098
1232
|
fn = _unwrap_if_constexpr(self.visit(node.func))
|
|
@@ -1102,6 +1236,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1102
1236
|
|
|
1103
1237
|
kws = dict(self.visit(keyword) for keyword in node.keywords)
|
|
1104
1238
|
args = [self.visit(arg) for arg in node.args]
|
|
1239
|
+
args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
|
|
1105
1240
|
if isinstance(fn, JITFunction):
|
|
1106
1241
|
_check_fn_args(node, fn, args)
|
|
1107
1242
|
return self.call_JitFunction(fn, args, kws)
|
|
@@ -1111,7 +1246,11 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1111
1246
|
if '_generator' in sig.parameters:
|
|
1112
1247
|
extra_kwargs['_generator'] = self
|
|
1113
1248
|
try:
|
|
1114
|
-
|
|
1249
|
+
ret = fn(*args, **extra_kwargs, **kws)
|
|
1250
|
+
# builtin functions return plain tuples for readability
|
|
1251
|
+
if isinstance(ret, tuple):
|
|
1252
|
+
ret = language.tuple(ret)
|
|
1253
|
+
return ret
|
|
1115
1254
|
except Exception as e:
|
|
1116
1255
|
# Normally when we raise a CompilationError, we raise it as
|
|
1117
1256
|
# `from None`, because the original fileline from the exception
|
|
@@ -1123,7 +1262,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1123
1262
|
|
|
1124
1263
|
if fn in self.builtin_namespace.values():
|
|
1125
1264
|
args = map(_unwrap_if_constexpr, args)
|
|
1126
|
-
|
|
1265
|
+
ret = fn(*args, **kws)
|
|
1266
|
+
return _apply_to_tuple_values(ret, lambda x: x) if _is_namedtuple(type(ret)) else ret
|
|
1127
1267
|
|
|
1128
1268
|
def visit_Constant(self, node):
|
|
1129
1269
|
return constexpr(node.value)
|
|
@@ -1142,21 +1282,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1142
1282
|
|
|
1143
1283
|
_method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
|
|
1144
1284
|
|
|
1145
|
-
if sys.version_info < (3, 8):
|
|
1146
|
-
|
|
1147
|
-
def visit_NameConstant(self, node):
|
|
1148
|
-
return constexpr(node.value)
|
|
1149
|
-
|
|
1150
|
-
def visit_Num(self, node):
|
|
1151
|
-
return constexpr(node.n)
|
|
1152
|
-
|
|
1153
|
-
def visit_Str(self, node):
|
|
1154
|
-
return constexpr(ast.literal_eval(node))
|
|
1155
|
-
|
|
1156
1285
|
def visit_Attribute(self, node):
|
|
1157
1286
|
lhs = self.visit(node.value)
|
|
1158
1287
|
if _is_triton_tensor(lhs) and node.attr == "T":
|
|
1159
|
-
return
|
|
1288
|
+
return semantic.permute(lhs, (1, 0), builder=self.builder)
|
|
1160
1289
|
return getattr(lhs, node.attr)
|
|
1161
1290
|
|
|
1162
1291
|
def visit_Expr(self, node):
|
|
@@ -1257,46 +1386,20 @@ class CodeGenerator(ast.NodeVisitor):
|
|
|
1257
1386
|
}
|
|
1258
1387
|
|
|
1259
1388
|
|
|
1260
|
-
def
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
suffix = ''
|
|
1264
|
-
for i, _ in enumerate(signature):
|
|
1265
|
-
suffix += str(i)
|
|
1266
|
-
if i in specialization.equal_to_1:
|
|
1267
|
-
suffix += 'c'
|
|
1268
|
-
if i in specialization.divisibility_16:
|
|
1269
|
-
suffix += 'd'
|
|
1270
|
-
return suffix
|
|
1271
|
-
|
|
1272
|
-
|
|
1273
|
-
def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map):
|
|
1274
|
-
attrs = specialization.attrs
|
|
1275
|
-
# create kernel prototype
|
|
1276
|
-
cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i
|
|
1277
|
-
constants = {cst_key(key): value for key, value in specialization.constants.items()}
|
|
1278
|
-
# visit kernel AST
|
|
1279
|
-
gscope = fn.__globals__.copy()
|
|
1280
|
-
function_name = fn.repr(specialization)
|
|
1281
|
-
tys = list(specialization.signature.values())
|
|
1282
|
-
new_constants = attrs.get_constants()
|
|
1283
|
-
for k in new_constants:
|
|
1284
|
-
if k in tys and tys[k] == "i1" and new_constants[k] == 1:
|
|
1285
|
-
new_constants[k] = True
|
|
1286
|
-
|
|
1287
|
-
new_attrs = attrs.filter_out_constants()
|
|
1288
|
-
fn_attrs = new_attrs.get_fn_attrs()
|
|
1289
|
-
all_constants = constants.copy()
|
|
1290
|
-
all_constants.update(new_constants)
|
|
1291
|
-
arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants]
|
|
1389
|
+
def ast_to_ttir(fn, src, context, options, codegen_fns, module_map):
|
|
1390
|
+
arg_types = list(map(str_to_ty, src.signature.values()))
|
|
1391
|
+
prototype = ASTFunction([], arg_types, src.constants, src.attrs)
|
|
1292
1392
|
file_name, begin_line = get_jit_fn_file_line(fn)
|
|
1293
|
-
|
|
1294
|
-
|
|
1295
|
-
|
|
1296
|
-
|
|
1297
|
-
|
|
1393
|
+
# query function representation
|
|
1394
|
+
from collections import namedtuple
|
|
1395
|
+
leaves = filter(lambda v: len(v) == 1, src.constants)
|
|
1396
|
+
constants = {fn.arg_names[i[0]]: src.constants[i] for i in leaves}
|
|
1397
|
+
signature = src.signature
|
|
1398
|
+
proxy = namedtuple("SpecializationProxy", ["constants", "signature"])(constants, signature)
|
|
1399
|
+
generator = CodeGenerator(context, prototype, gscope=fn.__globals__.copy(), function_name=fn.repr(proxy), jit_fn=fn,
|
|
1400
|
+
is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
|
|
1401
|
+
codegen_fns=codegen_fns, module_map=module_map)
|
|
1298
1402
|
generator.visit(fn.parse())
|
|
1299
|
-
|
|
1300
1403
|
ret = generator.module
|
|
1301
1404
|
# module takes ownership of the context
|
|
1302
1405
|
ret.context = context
|