triton-windows 3.2.0.post12__cp313-cp313-win_amd64.whl → 3.3.0a0.post12__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic; see the registry's advisory page for more details.

Files changed (68):
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +3 -3
  3. triton/_internal_testing.py +59 -4
  4. triton/_utils.py +35 -0
  5. triton/backends/amd/compiler.py +121 -74
  6. triton/backends/amd/driver.py +77 -43
  7. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
  8. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
  13. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
  15. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
  16. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
  17. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
  18. triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
  19. triton/backends/amd/include/hip/hip_ext.h +4 -2
  20. triton/backends/amd/include/hip/hip_fp8.h +33 -0
  21. triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
  22. triton/backends/amd/include/hip/hip_version.h +3 -3
  23. triton/backends/amd/include/hip/hiprtc.h +25 -25
  24. triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
  25. triton/backends/amd/include/hsa/hsa.h +11 -2
  26. triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
  27. triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
  28. triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
  29. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
  30. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
  31. triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
  32. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
  33. triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
  34. triton/backends/amd/lib/asanrtl.bc +0 -0
  35. triton/backends/compiler.py +25 -225
  36. triton/backends/driver.py +7 -2
  37. triton/backends/nvidia/bin/ptxas.exe +0 -0
  38. triton/backends/nvidia/compiler.py +135 -90
  39. triton/backends/nvidia/driver.c +0 -1
  40. triton/backends/nvidia/driver.py +135 -49
  41. triton/backends/nvidia/include/cuda.h +2162 -241
  42. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  43. triton/compiler/__init__.py +2 -2
  44. triton/compiler/code_generator.py +334 -231
  45. triton/compiler/compiler.py +77 -66
  46. triton/language/__init__.py +22 -5
  47. triton/language/core.py +448 -74
  48. triton/language/extra/cuda/_experimental_tma.py +3 -5
  49. triton/language/math.py +1 -1
  50. triton/language/random.py +2 -1
  51. triton/language/semantic.py +206 -52
  52. triton/language/standard.py +35 -18
  53. triton/runtime/_allocation.py +32 -0
  54. triton/runtime/autotuner.py +27 -32
  55. triton/runtime/build.py +1 -48
  56. triton/runtime/cache.py +6 -6
  57. triton/runtime/errors.py +10 -0
  58. triton/runtime/interpreter.py +179 -45
  59. triton/runtime/jit.py +149 -190
  60. triton/testing.py +39 -11
  61. triton/tools/compile.py +27 -20
  62. triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
  63. triton/tools/mxfp.py +301 -0
  64. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/METADATA +5 -2
  65. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/RECORD +68 -59
  66. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/top_level.txt +2 -0
  67. triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
  68. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/WHEEL +0 -0

Diff excerpt: triton/compiler/code_generator.py
@@ -1,23 +1,35 @@
1
1
  import ast
2
2
  import inspect
3
3
  import re
4
- import sys
5
4
  import warnings
6
5
  import os
7
6
  import textwrap
8
- from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
7
+ import itertools
8
+ from types import ModuleType
9
+ from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, Iterable, List
10
+
9
11
  from .. import language
10
12
  from .._C.libtriton import ir
11
- from ..language import constexpr, tensor, str_to_ty
12
- from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, _value
13
- from ..runtime.jit import _normalize_ty, get_jit_fn_file_line
13
+ from ..language import constexpr, semantic, str_to_ty, tensor
14
+ from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, base_value, base_type
15
+ from ..runtime.jit import get_jit_fn_file_line
14
16
  # ideally we wouldn't need any runtime component
15
17
  from ..runtime import JITFunction
18
+ from .._utils import find_paths_if, get_iterable_path, set_iterable_path
19
+
16
20
  from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
17
- from types import ModuleType
21
+
22
+
23
+ def check_identifier_legality(name, type):
24
+ pattern = r'^[a-zA-Z_][a-zA-Z0-9_]*$'
25
+ if not re.match(pattern, name):
26
+ raise CompilationError(f"invalid {type} identifier: {name}", name)
27
+ return name
18
28
 
19
29
 
20
30
  def mangle_ty(ty):
31
+ if ty.is_tuple():
32
+ return 'T' + '_'.join(map(mangle_ty, ty.types)) + 'T'
21
33
  if ty.is_ptr():
22
34
  return 'P' + mangle_ty(ty.element_ty)
23
35
  if ty.is_int():
@@ -48,7 +60,7 @@ def mangle_fn(name, arg_tys, constants):
48
60
 
49
61
 
50
62
  def _is_triton_value(o: Any) -> bool:
51
- return isinstance(o, _value)
63
+ return isinstance(o, base_value)
52
64
 
53
65
 
54
66
  def _is_triton_tensor(o: Any) -> bool:
@@ -56,7 +68,7 @@ def _is_triton_tensor(o: Any) -> bool:
56
68
 
57
69
 
58
70
  def _is_constexpr(o: Any) -> bool:
59
- return isinstance(o, constexpr)
71
+ return o is None or isinstance(o, (constexpr, language.core.dtype))
60
72
 
61
73
 
62
74
  def _is_triton_scalar(o: Any) -> bool:
@@ -77,6 +89,38 @@ def _check_fn_args(node, fn, args):
77
89
  )
78
90
 
79
91
 
92
+ def _is_namedtuple(val):
93
+ return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields")
94
+
95
+
96
+ def _apply_to_tuple_values(value, fn):
97
+ if _is_namedtuple(type(value)):
98
+ fields = value._fields
99
+ elif isinstance(value, language.tuple):
100
+ fields = value.type.fields
101
+ else:
102
+ assert False, f"Unsupported type {type(value)}"
103
+
104
+ vals = [fn(v) for v in value]
105
+ types = [v.type for v in vals]
106
+ return language.tuple(vals, language.tuple_type(types, fields))
107
+
108
+
109
+ def flatten_values_to_ir(values: Iterable[base_value]):
110
+ handles = []
111
+ for v in values:
112
+ v._flatten_ir(handles)
113
+ return handles
114
+
115
+
116
+ def unflatten_ir_values(handles: List[ir.value], types: List[base_type]):
117
+ cursor = 0
118
+ for ty in types:
119
+ value, cursor = ty._unflatten_ir(handles, cursor)
120
+ yield value
121
+ assert cursor == len(handles)
122
+
123
+
80
124
  _condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels
81
125
 
82
126
 
@@ -189,11 +233,70 @@ class ContainsReturnChecker(ast.NodeVisitor):
189
233
  return self.visit(node.func)
190
234
 
191
235
 
236
+ class ASTFunction:
237
+
238
+ def __init__(self, ret_types, arg_types, constants, attrs):
239
+ self.ret_types = ret_types
240
+ self.arg_types = arg_types
241
+ self.constants = constants
242
+ self.attrs = attrs
243
+
244
+ def return_types_ir(self, builder: ir.builder):
245
+ ret_types = []
246
+ for ret_ty in self.ret_types:
247
+ if ret_ty is None:
248
+ continue
249
+ ir_ty = ret_ty.to_ir(builder)
250
+ if isinstance(ir_ty, list):
251
+ ret_types.extend(ir_ty)
252
+ else:
253
+ ret_types.append(ir_ty)
254
+ return ret_types
255
+
256
+ def serialize(self, builder: ir.builder):
257
+ # fill up IR values in template
258
+ # > build function
259
+ is_val = lambda path, _: path not in self.constants and _ is not None
260
+ val_paths = list(find_paths_if(self.arg_types, is_val))
261
+ arg_types = [get_iterable_path(self.arg_types, path).to_ir(builder) for path in val_paths]
262
+ ret_types = self.return_types_ir(builder)
263
+ return builder.get_function_ty(arg_types, ret_types)
264
+
265
+ def deserialize(self, fn):
266
+ # create "template"
267
+ def make_template(ty):
268
+ if isinstance(ty, (list, tuple, language.tuple_type)):
269
+ return language.tuple([make_template(x) for x in ty], ty)
270
+ return language.constexpr(None)
271
+
272
+ vals = make_template(self.arg_types)
273
+ is_val = lambda path, _: path not in self.constants and _ is not None
274
+ val_paths = list(find_paths_if(self.arg_types, is_val))
275
+ # > set attributes
276
+ for attr_path, attr_specs in self.attrs.items():
277
+ for attr_name, attr_val in attr_specs:
278
+ if attr_path in val_paths:
279
+ fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val)
280
+ for i, path in enumerate(val_paths):
281
+ ty = get_iterable_path(self.arg_types, path)
282
+ if isinstance(ty, nv_tma_desc_type):
283
+ fn.set_arg_attr(i, "tt.nv_tma_desc", 1)
284
+ # > add IR values to the template
285
+ for i, path in enumerate(val_paths):
286
+ ty = get_iterable_path(self.arg_types, path)
287
+ set_iterable_path(vals, path, language.tensor(fn.args(i), ty))
288
+ # > add constexpr values to the template
289
+ constants = self.constants
290
+ for path, val in constants.items():
291
+ set_iterable_path(vals, path, language.constexpr(val))
292
+ return vals
293
+
294
+
192
295
  class CodeGenerator(ast.NodeVisitor):
193
296
 
194
- def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options,
195
- codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None,
196
- noinline=False, file_name: Optional[str] = None, begin_line=0):
297
+ def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns, module_map,
298
+ module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False,
299
+ file_name: Optional[str] = None, begin_line=0):
197
300
  self.context = context
198
301
  self.builder = ir.builder(context)
199
302
  self.file_name = file_name
@@ -223,9 +326,10 @@ class CodeGenerator(ast.NodeVisitor):
223
326
  self.gscope[k] = v
224
327
 
225
328
  self.lscope = {}
226
- self.attributes = attributes
227
- self.constants = constants
228
329
  self.jit_fn = jit_fn
330
+ # TODO: we currently generate illegal names for non-kernel functions involving constexprs!
331
+ if is_kernel:
332
+ function_name = check_identifier_legality(function_name, "function")
229
333
  self.function_name = function_name
230
334
  self.is_kernel = is_kernel
231
335
  self.cur_node = None
@@ -260,9 +364,6 @@ class CodeGenerator(ast.NodeVisitor):
260
364
  if _is_constexpr(val):
261
365
  return True
262
366
 
263
- if a := self.gscope.get("__annotations__", {}).get(name):
264
- return _normalize_ty(a) == "constexpr"
265
-
266
367
  return False
267
368
 
268
369
  def _define_name_lookup(self):
@@ -283,6 +384,7 @@ class CodeGenerator(ast.NodeVisitor):
283
384
  getattr(val, "__triton_builtin__", False), #
284
385
  getattr(val, "__module__", "").startswith("triton.language"), #
285
386
  isinstance(val, language.dtype), #
387
+ _is_namedtuple(val),
286
388
  self._is_constexpr_global(name), #
287
389
  # Allow accesses to globals while visiting an ast.arg
288
390
  # because you should be able to do
@@ -295,8 +397,8 @@ class CodeGenerator(ast.NodeVisitor):
295
397
  textwrap.dedent(f"""\
296
398
  Cannot access global variable {name} from within @jit'ed
297
399
  function. Triton kernels can only access global variables that
298
- are annotated as constexpr (`x: triton.language.constexpr = 42`
299
- or `x = triton.language.constexpr(42)`). Alternatively, set the
400
+ are instanstiated as constexpr (`x = triton.language.constexpr(42)`). Note that this is different from
401
+ annotating a variable as constexpr (`x: triton.language.constexpr = 42`), which is not supported. Alternatively, set the
300
402
  envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not
301
403
  promise to support this forever.""").replace("\n", " "))
302
404
 
@@ -312,7 +414,7 @@ class CodeGenerator(ast.NodeVisitor):
312
414
 
313
415
  return name_lookup
314
416
 
315
- def set_value(self, name: str, value: Union[tensor, constexpr]) -> None:
417
+ def set_value(self, name: str, value: Union[base_value, constexpr]) -> None:
316
418
  ''' This function:
317
419
  called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
318
420
  1. record local defined name (FIXME: should consider control flow)
@@ -342,7 +444,6 @@ class CodeGenerator(ast.NodeVisitor):
342
444
  stmts = [stmts]
343
445
  for stmt in stmts:
344
446
  self.visit(stmt)
345
-
346
447
  # Stop parsing as soon as we hit a `return` statement; everything
347
448
  # after this is dead code.
348
449
  if isinstance(stmt, ast.Return):
@@ -354,25 +455,30 @@ class CodeGenerator(ast.NodeVisitor):
354
455
  def visit_List(self, node):
355
456
  ctx = self.visit(node.ctx)
356
457
  assert ctx is None
357
- elts = [self.visit(elt) for elt in node.elts]
458
+ elts = language.tuple([self.visit(elt) for elt in node.elts])
358
459
  return elts
359
460
 
360
461
  # By design, only non-kernel functions can return
361
462
  def visit_Return(self, node):
362
463
  ret_value = self.visit(node.value)
464
+ handles = []
465
+
466
+ def decay(value):
467
+ if isinstance(value, language.tuple):
468
+ return _apply_to_tuple_values(value, decay)
469
+ elif isinstance(value, (language.constexpr, int, float)):
470
+ return semantic.to_tensor(value, self.builder)
471
+ return value
472
+
473
+ ret_value = decay(ret_value)
474
+
363
475
  if ret_value is None:
364
- self.builder.ret([])
365
476
  ret_ty = language.void
366
- elif isinstance(ret_value, tuple):
367
- ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value]
368
- ret_types = [v.type for v in ret_values]
369
- self.builder.ret([v.handle for v in ret_values])
370
- ret_ty = tuple(ret_types)
371
477
  else:
372
- ret = language.semantic.to_tensor(ret_value, self.builder)
373
- self.builder.ret([ret.handle])
374
- ret_ty = ret.type
375
-
478
+ assert isinstance(ret_value, language.core.base_value)
479
+ ret_value._flatten_ir(handles)
480
+ ret_ty = ret_value.type
481
+ self.builder.ret(handles)
376
482
  if self.ret_type is None:
377
483
  self.ret_type = ret_ty
378
484
  elif self.ret_type != ret_ty:
@@ -383,6 +489,11 @@ class CodeGenerator(ast.NodeVisitor):
383
489
  post_ret_block = self.builder.create_block()
384
490
  self.builder.set_insertion_point_to_end(post_ret_block)
385
491
 
492
+ def visit_Starred(self, node) -> Any:
493
+ args = self.visit(node.value)
494
+ assert isinstance(args, language.core.tuple)
495
+ return args.values
496
+
386
497
  def visit_FunctionDef(self, node):
387
498
  arg_names, kwarg_names = self.visit(node.args)
388
499
  if self.fn:
@@ -397,7 +508,6 @@ class CodeGenerator(ast.NodeVisitor):
397
508
  init_node = ast.Assign(targets=[st_target], value=default_value)
398
509
  else:
399
510
  init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
400
-
401
511
  try:
402
512
  assert not self.visiting_arg_default_value
403
513
  self.visiting_arg_default_value = True
@@ -407,34 +517,15 @@ class CodeGenerator(ast.NodeVisitor):
407
517
 
408
518
  # initialize function
409
519
  visibility = "public" if self.is_kernel else "private"
410
- self.fn = self.builder.get_or_insert_function(self.module, self.function_name,
411
- self.prototype.to_ir(self.builder), visibility, self.noinline)
520
+ fn_ty = self.prototype.serialize(self.builder)
521
+ self.fn = self.builder.get_or_insert_function(self.module, self.function_name, fn_ty, visibility, self.noinline)
412
522
  self.module.push_back(self.fn)
413
523
  entry = self.fn.add_entry_block()
414
- arg_values = []
415
- idx = 0
416
- for i in range(len(arg_names)):
417
- if i in self.constants:
418
- cst = self.constants[i]
419
- if not _is_constexpr(cst):
420
- cst = constexpr(self.constants[i])
421
- arg_values.append(cst)
422
- continue
423
- else:
424
- if i in self.attributes:
425
- for name, value in self.attributes[i]:
426
- self.fn.set_arg_attr(idx, name, value)
427
-
428
- # Mark this argument as a pass-by-value TMA descriptor (nvidia)
429
- if isinstance(self.prototype.param_types[idx], nv_tma_desc_type):
430
- self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1)
431
-
432
- arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx]))
433
- idx += 1
434
-
435
- insert_pt = self.builder.get_insertion_block()
524
+ arg_values = self.prototype.deserialize(self.fn)
525
+ # bind arguments to symbols
436
526
  for arg_name, arg_value in zip(arg_names, arg_values):
437
527
  self.set_value(arg_name, arg_value)
528
+ insert_pt = self.builder.get_insertion_block()
438
529
  self.builder.set_insertion_point_to_start(entry)
439
530
  # visit function body
440
531
  self.visit_compound_statement(node.body)
@@ -445,13 +536,12 @@ class CodeGenerator(ast.NodeVisitor):
445
536
  self.ret_type = language.void
446
537
  self.builder.ret([])
447
538
  else:
448
- self.prototype.ret_types = list(self.ret_type) if isinstance(self.ret_type, tuple) else [self.ret_type]
449
- self.fn.reset_type(self.prototype.to_ir(self.builder))
450
- self.builder.ret([
451
- self.builder.create_poison(ty.to_ir(self.builder))
452
- for ty in self.prototype.ret_types
453
- if self.ret_type is not None
454
- ])
539
+ if isinstance(self.ret_type, language.tuple_type):
540
+ self.prototype.ret_types = self.ret_type.types
541
+ else:
542
+ self.prototype.ret_types = [self.ret_type]
543
+ self.fn.reset_type(self.prototype.serialize(self.builder))
544
+ self.builder.ret([self.builder.create_poison(ty) for ty in self.prototype.return_types_ir(self.builder)])
455
545
  self.fn.finalize()
456
546
 
457
547
  if insert_pt:
@@ -478,37 +568,41 @@ class CodeGenerator(ast.NodeVisitor):
478
568
  if target in self.lscope:
479
569
  raise ValueError(f'{target} is already defined.'
480
570
  f' constexpr cannot be reassigned.')
481
- if not _is_constexpr(value):
482
- value = constexpr(value)
571
+ value = constexpr(value)
483
572
  self.lscope[target] = value
484
573
  return self.lscope[target]
485
574
  # default: call visit_Assign
486
575
  return self.visit_Assign(node)
487
576
 
577
+ def assignTarget(self, target, value):
578
+ if isinstance(target, ast.Subscript):
579
+ assert target.ctx.__class__.__name__ == "Store"
580
+ return self.visit_Subscript_Store(target, value)
581
+ if isinstance(target, ast.Tuple):
582
+ assert target.ctx.__class__.__name__ == "Store"
583
+ for i, name in enumerate(target.elts):
584
+ self.set_value(self.visit(name), value.values[i])
585
+ return
586
+ assert isinstance(target, ast.Name)
587
+ self.set_value(self.visit(target), value)
588
+
488
589
  def visit_Assign(self, node):
489
- _names = []
490
- if isinstance(node, ast.AnnAssign):
491
- _names += [self.visit(node.target)]
492
- else:
493
- for target in node.targets:
494
- _names += [self.visit(target)]
495
- if len(_names) > 1:
496
- raise self._unsupported(node, "simultaneous multiple assignment is not supported.")
497
- names = _names[0]
498
- values = self.visit(node.value)
499
- if not _is_list_like(names):
500
- names = [names]
501
- if not _is_list_like(values):
502
- values = [values]
503
- native_nontensor_types = (language.dtype, )
504
- for name, value in zip(names, values):
505
- # by default, constexpr are assigned into python variable
590
+ # construct values to assign
591
+ def _sanitize_value(value):
592
+ if isinstance(value, language.tuple):
593
+ return _apply_to_tuple_values(value, _sanitize_value)
594
+ native_nontensor_types = (language.dtype, language.tuple)
506
595
  value = _unwrap_if_constexpr(value)
507
596
  if value is not None and \
508
- not _is_triton_value(value) and \
509
- not isinstance(value, native_nontensor_types):
510
- value = language.semantic.to_tensor(value, self.builder)
511
- self.set_value(name, value)
597
+ not _is_triton_value(value) and \
598
+ not isinstance(value, native_nontensor_types):
599
+ value = semantic.to_tensor(value, self.builder)
600
+ return value
601
+
602
+ values = _sanitize_value(self.visit(node.value))
603
+ targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets
604
+ assert len(targets) == 1
605
+ self.assignTarget(targets[0], values)
512
606
 
513
607
  def visit_AugAssign(self, node):
514
608
  name = node.target.id
@@ -531,7 +625,7 @@ class CodeGenerator(ast.NodeVisitor):
531
625
 
532
626
  def visit_Tuple(self, node):
533
627
  args = [self.visit(x) for x in node.elts]
534
- return tuple(args)
628
+ return language.tuple(args)
535
629
 
536
630
  def _apply_binary_method(self, method_name, lhs, rhs):
537
631
  # TODO: raise something meaningful if getattr fails below, esp for reverse method
@@ -584,21 +678,17 @@ class CodeGenerator(ast.NodeVisitor):
584
678
 
585
679
  # update block arguments
586
680
  names = []
587
- ret_types = []
588
- ir_ret_types = []
589
681
  # variables in livein whose value is updated in `if`
590
682
  for name in liveins:
591
683
  # check type
592
684
  for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]:
593
685
  if name in defs:
594
- assert defs[name].type == liveins[name].type, \
595
- f'initial value for `{name}` is of type {liveins[name].type}, '\
596
- f'but the {block_name} block redefines it as {defs[name].type}'
686
+ type_equal = type(defs[name]) == type(liveins[name]) # noqa: E721
687
+ assert type_equal and defs[name].type == liveins[name].type, \
688
+ f'initial value for `{name}` is of type {liveins[name]}, '\
689
+ f'but the {block_name} block redefines it as {defs[name]}'
597
690
  if name in then_defs or name in else_defs:
598
691
  names.append(name)
599
- ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type)
600
- ir_ret_types.append(then_defs[name].handle.get_type() if name in
601
- then_defs else else_defs[name].handle.get_type())
602
692
  # variable defined in then but not in else
603
693
  if name in then_defs and name not in else_defs:
604
694
  else_defs[name] = liveins[name]
@@ -610,16 +700,17 @@ class CodeGenerator(ast.NodeVisitor):
610
700
  for name in sorted(then_defs.keys() & else_defs.keys()):
611
701
  if name in names:
612
702
  continue
613
- then_ty = then_defs[name].type
614
- else_ty = else_defs[name].type
615
- assert then_ty == else_ty, \
703
+ then_val = then_defs[name]
704
+ then_ty = then_val.type
705
+ else_val = else_defs[name]
706
+ else_ty = else_val.type
707
+ type_equal = type(then_val) == type(else_val) # noqa: E721
708
+ assert type_equal and then_ty == else_ty, \
616
709
  f'Mismatched type for {name} between then block ({then_ty}) '\
617
710
  f'and else block ({else_ty})'
618
711
  names.append(name)
619
- ret_types.append(then_ty)
620
- ir_ret_types.append(then_defs[name].handle.get_type())
621
712
 
622
- return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types
713
+ return then_defs, else_defs, then_block, else_block, names
623
714
 
624
715
  def visit_if_top_level(self, cond, node):
625
716
  with enter_sub_region(self) as sr:
@@ -630,27 +721,34 @@ class CodeGenerator(ast.NodeVisitor):
630
721
  self.builder.set_insertion_point_to_end(ip_block)
631
722
  self.builder.create_cond_branch(cond.handle, then_block, else_block)
632
723
  # visit then and else blocks
633
- then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types = \
724
+ then_defs, else_defs, then_block, else_block, names = \
634
725
  self.visit_then_else_blocks(node, liveins, then_block, else_block)
635
726
  # create basic-block after conditional
636
727
  endif_block = self.builder.create_block()
637
728
  # then terminator
638
729
  self.builder.set_insertion_point_to_end(then_block)
639
730
  assert not then_block.has_terminator(), f"{then_block}"
640
- self.builder.create_branch(endif_block, [then_defs[n].handle for n in names])
731
+ then_handles = flatten_values_to_ir(then_defs[name] for name in names)
732
+ self.builder.create_branch(endif_block, then_handles)
641
733
  # else terminator
642
734
  self.builder.set_insertion_point_to_end(else_block)
643
735
  assert not else_block.has_terminator(), f"{else_block}"
644
- self.builder.create_branch(endif_block, [else_defs[n].handle for n in names])
645
- for ty in ir_ret_types:
736
+ else_handles = flatten_values_to_ir(else_defs[name] for name in names)
737
+ self.builder.create_branch(endif_block, else_handles)
738
+ assert len(then_handles) == len(else_handles)
739
+ for then_h, else_h in zip(then_handles, else_handles):
740
+ ty = then_h.get_type()
741
+ assert ty == else_h.get_type()
646
742
  endif_block.add_argument(ty)
647
743
 
648
744
  # change block
649
745
  self.builder.set_insertion_point_to_start(endif_block)
650
746
  # update value
651
- for i, name in enumerate(names):
652
- new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i])
653
- self.set_value(name, new_tensor)
747
+ res_handles = [endif_block.arg(i) for i in range(len(then_handles))]
748
+ types = [then_defs[name].type for name in names]
749
+ new_values = unflatten_ir_values(res_handles, types)
750
+ for name, new_value in zip(names, new_values):
751
+ self.set_value(name, new_value)
654
752
 
655
753
  # TODO: refactor
656
754
  def visit_if_scf(self, cond, node):
@@ -659,26 +757,30 @@ class CodeGenerator(ast.NodeVisitor):
659
757
  ip, last_loc = self._get_insertion_point_and_loc()
660
758
  then_block = self.builder.create_block()
661
759
  else_block = self.builder.create_block() if node.orelse else None
662
- then_defs, else_defs, then_block, else_block, names, ret_types, _ = \
760
+ then_defs, else_defs, then_block, else_block, names = \
663
761
  self.visit_then_else_blocks(node, liveins, then_block, else_block)
664
762
  # create if op
763
+ then_handles = flatten_values_to_ir(then_defs[name] for name in names)
665
764
  self._set_insertion_point_and_loc(ip, last_loc)
666
- if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
765
+ if_op = self.builder.create_if_op([h.get_type() for h in then_handles], cond.handle, True)
667
766
  then_block.merge_block_before(if_op.get_then_block())
668
767
  self.builder.set_insertion_point_to_end(if_op.get_then_block())
669
768
  if len(names) > 0:
670
- self.builder.create_yield_op([then_defs[n].handle for n in names])
769
+ self.builder.create_yield_op(then_handles)
671
770
  if not node.orelse:
672
771
  else_block = if_op.get_else_block()
673
772
  else:
674
773
  else_block.merge_block_before(if_op.get_else_block())
675
774
  self.builder.set_insertion_point_to_end(if_op.get_else_block())
676
775
  if len(names) > 0:
677
- self.builder.create_yield_op([else_defs[n].handle for n in names])
776
+ else_handles = flatten_values_to_ir(else_defs[name] for name in names)
777
+ self.builder.create_yield_op(else_handles)
678
778
  # update values
679
- for i, name in enumerate(names):
680
- new_tensor = language.core.tensor(if_op.get_result(i), ret_types[i])
681
- self.set_value(name, new_tensor)
779
+ res_handles = [if_op.get_result(i) for i in range(len(then_handles))]
780
+ types = [then_defs[name].type for name in names]
781
+ new_values = unflatten_ir_values(res_handles, types)
782
+ for name, new_value in zip(names, new_values):
783
+ self.set_value(name, new_value)
682
784
 
683
785
  def visit_If(self, node):
684
786
  cond = self.visit(node.test)
@@ -717,14 +819,14 @@ class CodeGenerator(ast.NodeVisitor):
717
819
 
718
820
  then_block = self.builder.create_block()
719
821
  self.builder.set_insertion_point_to_start(then_block)
720
- then_val = language.semantic.to_tensor(self.visit(node.body), self.builder)
822
+ then_val = semantic.to_tensor(self.visit(node.body), self.builder)
721
823
  then_block = self.builder.get_insertion_block()
722
824
 
723
825
  else_block = self.builder.create_block()
724
826
  self.builder.set_insertion_point_to_start(else_block)
725
827
  # do not need to reset lscope since
726
828
  # ternary expressions cannot define new variables
727
- else_val = language.semantic.to_tensor(self.visit(node.orelse), self.builder)
829
+ else_val = semantic.to_tensor(self.visit(node.orelse), self.builder)
728
830
  else_block = self.builder.get_insertion_block()
729
831
 
730
832
  self._set_insertion_point_and_loc(ip, last_loc)
@@ -804,7 +906,7 @@ class CodeGenerator(ast.NodeVisitor):
804
906
  def _verify_loop_carried_variable(self, name, loop_val, live_val):
805
907
  assert _is_triton_value(loop_val), f'cannot reassign constxpr {name} in the loop'
806
908
  assert _is_triton_value(live_val), f'cannot reasign constexpr {name} in the loop'
807
- assert type(loop_val) == type(live_val), f'Loop carried variable {name} changed type'
909
+ assert type(loop_val) is type(live_val), f'Loop carried variable {name} changed type'
808
910
  assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \
809
911
  f'Loop-carried variable {name} has initial type {live_val.type} '\
810
912
  f'but is re-assigned to {loop_val.type} in loop! '\
@@ -827,7 +929,6 @@ class CodeGenerator(ast.NodeVisitor):
827
929
 
828
930
  # collect loop-carried values
829
931
  names = []
830
- ret_types = []
831
932
  init_args = []
832
933
  for name in loop_defs:
833
934
  if name in liveins:
@@ -838,32 +939,35 @@ class CodeGenerator(ast.NodeVisitor):
838
939
 
839
940
  # these are loop-carried values
840
941
  names.append(name)
841
- ret_types.append(loop_val.type)
842
942
  init_args.append(live_val)
843
943
 
944
+ init_handles = flatten_values_to_ir(init_args)
945
+ init_tys = [h.get_type() for h in init_handles]
946
+ init_fe_tys = [a.type for a in init_args]
844
947
  self._set_insertion_point_and_loc(ip, last_loc)
845
- while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
846
- [arg.handle for arg in init_args])
948
+ while_op = self.builder.create_while_op(init_tys, init_handles)
847
949
  # merge the condition region
848
- before_block = self.builder.create_block_with_parent(while_op.get_before(),
849
- [ty.to_ir(self.builder) for ty in ret_types])
950
+ before_block = self.builder.create_block_with_parent(while_op.get_before(), init_tys)
850
951
  self.builder.set_insertion_point_to_start(before_block)
851
- for i, name in enumerate(names):
852
- self.lscope[name] = language.core.tensor(before_block.arg(i), ret_types[i])
853
- self.local_defs[name] = self.lscope[name]
952
+ block_args = [before_block.arg(i) for i in range(len(init_handles))]
953
+ condition_args = unflatten_ir_values(block_args, init_fe_tys)
954
+ for name, val in zip(names, condition_args):
955
+ self.lscope[name] = val
956
+ self.local_defs[name] = val
854
957
  cond = self.visit(node.test)
855
958
  self.builder.set_insertion_point_to_end(before_block)
856
959
  # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
857
- self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
960
+ self.builder.create_condition_op(cond.handle, block_args)
858
961
  # merge the loop body
859
- after_block = self.builder.create_block_with_parent(while_op.get_after(),
860
- [ty.to_ir(self.builder) for ty in ret_types])
962
+ after_block = self.builder.create_block_with_parent(while_op.get_after(), init_tys)
861
963
 
862
964
  # generate loop body
863
965
  self.builder.set_insertion_point_to_start(after_block)
864
- for i, name in enumerate(names):
865
- self.lscope[name] = language.core.tensor(after_block.arg(i), ret_types[i])
866
- self.local_defs[name] = self.lscope[name]
966
+ body_handles = [after_block.arg(i) for i in range(len(init_handles))]
967
+ body_args = unflatten_ir_values(body_handles, init_fe_tys)
968
+ for name, val in zip(names, body_args):
969
+ self.lscope[name] = val
970
+ self.local_defs[name] = val
867
971
  self.scf_stack.append(node)
868
972
  self.visit_compound_statement(node.body)
869
973
  self.scf_stack.pop()
@@ -871,12 +975,14 @@ class CodeGenerator(ast.NodeVisitor):
871
975
  yields = []
872
976
  for name in loop_defs:
873
977
  if name in liveins:
874
- yields.append(loop_defs[name])
875
- self.builder.create_yield_op([y.handle for y in yields])
978
+ loop_defs[name]._flatten_ir(yields)
979
+
980
+ self.builder.create_yield_op(yields)
876
981
 
877
982
  # WhileOp defines new values, update the symbol table (lscope, local_defs)
878
- for i, name in enumerate(names):
879
- new_def = language.core.tensor(while_op.get_result(i), ret_types[i])
983
+ result_handles = [while_op.get_result(i) for i in range(len(init_handles))]
984
+ result_vals = unflatten_ir_values(result_handles, init_fe_tys)
985
+ for name, new_def in zip(names, result_vals):
880
986
  self.lscope[name] = new_def
881
987
  self.local_defs[name] = new_def
882
988
 
@@ -884,7 +990,7 @@ class CodeGenerator(ast.NodeVisitor):
884
990
  assert False, "Not implemented"
885
991
  ast.NodeVisitor.generic_visit(self, stmt)
886
992
 
887
- def visit_Subscript(self, node):
993
+ def visit_Subscript_Load(self, node):
888
994
  assert node.ctx.__class__.__name__ == "Load"
889
995
  lhs = self.visit(node.value)
890
996
  slices = self.visit(node.slice)
@@ -892,6 +998,16 @@ class CodeGenerator(ast.NodeVisitor):
892
998
  return lhs.__getitem__(slices, _builder=self.builder)
893
999
  return lhs[slices]
894
1000
 
1001
+ def visit_Subscript_Store(self, node, value):
1002
+ assert node.ctx.__class__.__name__ == "Store"
1003
+ lhs = self.visit(node.value)
1004
+ slices = self.visit(node.slice)
1005
+ assert isinstance(lhs, language.tuple)
1006
+ lhs.__setitem__(slices, value)
1007
+
1008
+ def visit_Subscript(self, node):
1009
+ return self.visit_Subscript_Load(node)
1010
+
895
1011
  def visit_ExtSlice(self, node):
896
1012
  return [self.visit(dim) for dim in node.dims]
897
1013
 
@@ -910,6 +1026,8 @@ class CodeGenerator(ast.NodeVisitor):
910
1026
  return
911
1027
  num_stages = None
912
1028
  loop_unroll_factor = None
1029
+ disallow_acc_multi_buffer = False
1030
+ flatten = False
913
1031
  if IteratorClass is language.range:
914
1032
  iterator = IteratorClass(*iter_args, **iter_kwargs)
915
1033
  # visit iterator arguments
@@ -920,6 +1038,8 @@ class CodeGenerator(ast.NodeVisitor):
920
1038
  step = iterator.step
921
1039
  num_stages = iterator.num_stages
922
1040
  loop_unroll_factor = iterator.loop_unroll_factor
1041
+ disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer
1042
+ flatten = iterator.flatten
923
1043
  elif IteratorClass is range:
924
1044
  # visit iterator arguments
925
1045
  # note: only `range` iterator is supported now
@@ -935,14 +1055,14 @@ class CodeGenerator(ast.NodeVisitor):
935
1055
  step = constexpr(-step.value)
936
1056
  negative_step = True
937
1057
  lb, ub = ub, lb
938
- lb = language.semantic.to_tensor(lb, self.builder)
939
- ub = language.semantic.to_tensor(ub, self.builder)
940
- step = language.semantic.to_tensor(step, self.builder)
1058
+ lb = semantic.to_tensor(lb, self.builder)
1059
+ ub = semantic.to_tensor(ub, self.builder)
1060
+ step = semantic.to_tensor(step, self.builder)
941
1061
  # induction variable type
942
1062
  if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int():
943
1063
  raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})")
944
- iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype)
945
- iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype)
1064
+ iv_type = semantic.integer_promote_impl(lb.dtype, ub.dtype)
1065
+ iv_type = semantic.integer_promote_impl(iv_type, step.dtype)
946
1066
  iv_ir_type = iv_type.to_ir(self.builder)
947
1067
  iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED
948
1068
  # lb/ub/step might be constexpr, we need to cast them to tensor
@@ -987,34 +1107,47 @@ class CodeGenerator(ast.NodeVisitor):
987
1107
 
988
1108
  # create ForOp
989
1109
  self._set_insertion_point_and_loc(ip, last_loc)
990
- for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
991
- if num_stages is not None:
1110
+ init_handles = flatten_values_to_ir(init_args)
1111
+ init_tys = [v.type for v in init_args]
1112
+ for_op = self.builder.create_for_op(lb, ub, step, init_handles)
1113
+ if _unwrap_if_constexpr(num_stages) is not None:
992
1114
  for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages))
993
- if loop_unroll_factor is not None:
1115
+ if _unwrap_if_constexpr(loop_unroll_factor) is not None:
994
1116
  for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor))
1117
+ if disallow_acc_multi_buffer:
1118
+ for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr())
1119
+ if flatten:
1120
+ for_op.set_attr("tt.flatten", self.builder.get_unit_attr())
995
1121
 
996
1122
  self.scf_stack.append(node)
997
- self.builder.set_insertion_point_to_start(for_op.get_body(0))
1123
+ for_op_body = for_op.get_body(0)
1124
+ self.builder.set_insertion_point_to_start(for_op_body)
998
1125
  # reset local scope to not pick up local defs from the previous dry run.
999
1126
  self.lscope = liveins.copy()
1000
1127
  self.local_defs = {}
1001
- for i, name in enumerate(names):
1002
- self.set_value(name, language.core.tensor(for_op.get_body(0).arg(i + 1), yields[i].type))
1128
+ block_handles = [for_op_body.arg(i + 1) for i in range(len(init_handles))]
1129
+ block_args = unflatten_ir_values(block_handles, init_tys)
1130
+ for name, val in zip(names, block_args):
1131
+ self.set_value(name, val)
1003
1132
  self.visit_compound_statement(node.body)
1004
1133
  self.scf_stack.pop()
1005
1134
  yields = []
1006
1135
  for name in self.local_defs:
1007
1136
  if name in liveins:
1008
- yields.append(language.semantic.to_tensor(self.local_defs[name], self.builder))
1137
+ local = self.local_defs[name]
1138
+ if isinstance(local, constexpr):
1139
+ local = semantic.to_tensor(local, self.builder)
1140
+ yields.append(local)
1009
1141
 
1010
1142
  # create YieldOp
1011
1143
  if len(yields) > 0:
1012
- self.builder.create_yield_op([y.handle for y in yields])
1013
- for_op_region = for_op.get_body(0).get_parent()
1144
+ yield_handles = flatten_values_to_ir(yields)
1145
+ self.builder.create_yield_op(yield_handles)
1146
+ for_op_region = for_op_body.get_parent()
1014
1147
  assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
1015
1148
 
1016
1149
  # update induction variable with actual value, and replace all uses
1017
- self.builder.set_insertion_point_to_start(for_op.get_body(0))
1150
+ self.builder.set_insertion_point_to_start(for_op_body)
1018
1151
  iv = for_op.get_induction_var()
1019
1152
  if negative_step:
1020
1153
  iv = self.builder.create_sub(ub, iv)
@@ -1023,8 +1156,10 @@ class CodeGenerator(ast.NodeVisitor):
1023
1156
  self.set_value(node.target.id, language.core.tensor(iv, iv_type))
1024
1157
 
1025
1158
  # update lscope & local_defs (ForOp defines new values)
1026
- for i, name in enumerate(names):
1027
- self.set_value(name, language.core.tensor(for_op.get_result(i), yields[i].type))
1159
+ result_handles = [for_op.get_result(i) for i in range(len(init_handles))]
1160
+ result_values = unflatten_ir_values(result_handles, init_tys)
1161
+ for name, val in zip(names, result_values):
1162
+ self.set_value(name, val)
1028
1163
 
1029
1164
  for stmt in node.orelse:
1030
1165
  assert False, "Don't know what to do with else after for"
@@ -1034,7 +1169,7 @@ class CodeGenerator(ast.NodeVisitor):
1034
1169
  lower = self.visit(node.lower)
1035
1170
  upper = self.visit(node.upper)
1036
1171
  step = self.visit(node.step)
1037
- return slice(lower, upper, step)
1172
+ return language.slice(lower, upper, step)
1038
1173
 
1039
1174
  def visit_Index(self, node):
1040
1175
  return self.visit(node.value)
@@ -1050,24 +1185,28 @@ class CodeGenerator(ast.NodeVisitor):
1050
1185
  def call_JitFunction(self, fn: JITFunction, args, kwargs):
1051
1186
  args = inspect.getcallargs(fn.fn, *args, **kwargs)
1052
1187
  args = [args[name] for name in fn.arg_names]
1053
- args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args]
1054
- # generate function def
1055
- attributes = {}
1056
- constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
1057
- constants = {i: args[i] for i in constexprs}
1058
- # generate call
1059
- args = [None if i in constexprs else arg for i, arg in enumerate(args)]
1060
- arg_vals = [arg.handle for arg in args if arg is not None]
1061
- arg_types = [arg.type for arg in args if arg is not None]
1062
- fn_name = mangle_fn(fn.__name__, arg_types, constants)
1188
+ for i, arg in enumerate(args):
1189
+ if isinstance(arg, (language.dtype, float, int, bool, JITFunction)):
1190
+ args[i] = language.core.constexpr(arg)
1191
+ args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x))
1192
+ args_cst = {path: get_iterable_path(args, path) for path in args_cst}
1193
+ args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
1194
+ args_val = [get_iterable_path(args, path) for path in args_path]
1195
+ # mangle
1196
+ fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst)
1063
1197
  # generate function def if necessary
1064
1198
  if not self.module.has_function(fn_name):
1065
- prototype = language.function_type([], arg_types)
1066
1199
  gscope = fn.__globals__
1067
1200
  # If the callee is not set, we use the same debug setting as the caller
1068
1201
  file_name, begin_line = get_jit_fn_file_line(fn)
1069
- generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
1070
- jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types,
1202
+ arg_types = [
1203
+ language.core.constexpr if arg is None or isinstance(arg,
1204
+ (bool, int, language.core.dtype)) else arg.type
1205
+ for arg in args
1206
+ ]
1207
+ prototype = ASTFunction([], arg_types, args_cst, dict())
1208
+ generator = CodeGenerator(self.context, prototype, gscope, module=self.module, jit_fn=fn,
1209
+ function_name=fn_name, function_types=self.function_ret_types,
1071
1210
  noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
1072
1211
  options=self.builder.options, codegen_fns=self.builder.codegen_fns,
1073
1212
  module_map=self.builder.module_map)
@@ -1082,17 +1221,12 @@ class CodeGenerator(ast.NodeVisitor):
1082
1221
  else:
1083
1222
  callee_ret_type = self.function_ret_types[fn_name]
1084
1223
  symbol = self.module.get_function(fn_name)
1085
- call_op = self.builder.call(symbol, arg_vals)
1086
- if call_op.get_num_results() == 0 or callee_ret_type is None:
1224
+ args_val = [arg.handle for arg in args_val]
1225
+ call_op = self.builder.call(symbol, args_val)
1226
+ if callee_ret_type == language.void:
1087
1227
  return None
1088
- elif call_op.get_num_results() == 1:
1089
- return tensor(call_op.get_result(0), callee_ret_type)
1090
- else:
1091
- # should return a tuple of tl.tensor
1092
- results = []
1093
- for i in range(call_op.get_num_results()):
1094
- results.append(tensor(call_op.get_result(i), callee_ret_type[i]))
1095
- return tuple(results)
1228
+ handles = [call_op.get_result(i) for i in range(call_op.get_num_results())]
1229
+ return next(unflatten_ir_values(handles, [callee_ret_type]))
1096
1230
 
1097
1231
  def visit_Call(self, node):
1098
1232
  fn = _unwrap_if_constexpr(self.visit(node.func))
@@ -1102,6 +1236,7 @@ class CodeGenerator(ast.NodeVisitor):
1102
1236
 
1103
1237
  kws = dict(self.visit(keyword) for keyword in node.keywords)
1104
1238
  args = [self.visit(arg) for arg in node.args]
1239
+ args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
1105
1240
  if isinstance(fn, JITFunction):
1106
1241
  _check_fn_args(node, fn, args)
1107
1242
  return self.call_JitFunction(fn, args, kws)
@@ -1111,7 +1246,11 @@ class CodeGenerator(ast.NodeVisitor):
1111
1246
  if '_generator' in sig.parameters:
1112
1247
  extra_kwargs['_generator'] = self
1113
1248
  try:
1114
- return fn(*args, **extra_kwargs, **kws)
1249
+ ret = fn(*args, **extra_kwargs, **kws)
1250
+ # builtin functions return plain tuples for readability
1251
+ if isinstance(ret, tuple):
1252
+ ret = language.tuple(ret)
1253
+ return ret
1115
1254
  except Exception as e:
1116
1255
  # Normally when we raise a CompilationError, we raise it as
1117
1256
  # `from None`, because the original fileline from the exception
@@ -1123,7 +1262,8 @@ class CodeGenerator(ast.NodeVisitor):
1123
1262
 
1124
1263
  if fn in self.builtin_namespace.values():
1125
1264
  args = map(_unwrap_if_constexpr, args)
1126
- return fn(*args, **kws)
1265
+ ret = fn(*args, **kws)
1266
+ return _apply_to_tuple_values(ret, lambda x: x) if _is_namedtuple(type(ret)) else ret
1127
1267
 
1128
1268
  def visit_Constant(self, node):
1129
1269
  return constexpr(node.value)
@@ -1142,21 +1282,10 @@ class CodeGenerator(ast.NodeVisitor):
1142
1282
 
1143
1283
  _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
1144
1284
 
1145
- if sys.version_info < (3, 8):
1146
-
1147
- def visit_NameConstant(self, node):
1148
- return constexpr(node.value)
1149
-
1150
- def visit_Num(self, node):
1151
- return constexpr(node.n)
1152
-
1153
- def visit_Str(self, node):
1154
- return constexpr(ast.literal_eval(node))
1155
-
1156
1285
  def visit_Attribute(self, node):
1157
1286
  lhs = self.visit(node.value)
1158
1287
  if _is_triton_tensor(lhs) and node.attr == "T":
1159
- return language.semantic.permute(lhs, (1, 0), builder=self.builder)
1288
+ return semantic.permute(lhs, (1, 0), builder=self.builder)
1160
1289
  return getattr(lhs, node.attr)
1161
1290
 
1162
1291
  def visit_Expr(self, node):
@@ -1257,46 +1386,20 @@ class CodeGenerator(ast.NodeVisitor):
1257
1386
  }
1258
1387
 
1259
1388
 
1260
- def kernel_suffix(signature, specialization):
1261
- # suffix format:
1262
- # <argid><'c' if equal to 1><'d' if divisible by 16><'e' if divisible by 8>
1263
- suffix = ''
1264
- for i, _ in enumerate(signature):
1265
- suffix += str(i)
1266
- if i in specialization.equal_to_1:
1267
- suffix += 'c'
1268
- if i in specialization.divisibility_16:
1269
- suffix += 'd'
1270
- return suffix
1271
-
1272
-
1273
- def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map):
1274
- attrs = specialization.attrs
1275
- # create kernel prototype
1276
- cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i
1277
- constants = {cst_key(key): value for key, value in specialization.constants.items()}
1278
- # visit kernel AST
1279
- gscope = fn.__globals__.copy()
1280
- function_name = fn.repr(specialization)
1281
- tys = list(specialization.signature.values())
1282
- new_constants = attrs.get_constants()
1283
- for k in new_constants:
1284
- if k in tys and tys[k] == "i1" and new_constants[k] == 1:
1285
- new_constants[k] = True
1286
-
1287
- new_attrs = attrs.filter_out_constants()
1288
- fn_attrs = new_attrs.get_fn_attrs()
1289
- all_constants = constants.copy()
1290
- all_constants.update(new_constants)
1291
- arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants]
1389
+ def ast_to_ttir(fn, src, context, options, codegen_fns, module_map):
1390
+ arg_types = list(map(str_to_ty, src.signature.values()))
1391
+ prototype = ASTFunction([], arg_types, src.constants, src.attrs)
1292
1392
  file_name, begin_line = get_jit_fn_file_line(fn)
1293
-
1294
- prototype = language.function_type([], arg_types)
1295
- generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
1296
- jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name,
1297
- begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map)
1393
+ # query function representation
1394
+ from collections import namedtuple
1395
+ leaves = filter(lambda v: len(v) == 1, src.constants)
1396
+ constants = {fn.arg_names[i[0]]: src.constants[i] for i in leaves}
1397
+ signature = src.signature
1398
+ proxy = namedtuple("SpecializationProxy", ["constants", "signature"])(constants, signature)
1399
+ generator = CodeGenerator(context, prototype, gscope=fn.__globals__.copy(), function_name=fn.repr(proxy), jit_fn=fn,
1400
+ is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
1401
+ codegen_fns=codegen_fns, module_map=module_map)
1298
1402
  generator.visit(fn.parse())
1299
-
1300
1403
  ret = generator.module
1301
1404
  # module takes ownership of the context
1302
1405
  ret.context = context