triton-windows 3.4.0.post20__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (107):
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +8 -2
  3. triton/_filecheck.py +24 -14
  4. triton/_internal_testing.py +70 -4
  5. triton/_utils.py +3 -1
  6. triton/backends/amd/compiler.py +68 -60
  7. triton/backends/amd/driver.c +113 -44
  8. triton/backends/amd/driver.py +133 -57
  9. triton/backends/driver.py +13 -0
  10. triton/backends/nvidia/compiler.py +80 -22
  11. triton/backends/nvidia/driver.c +88 -15
  12. triton/backends/nvidia/driver.py +130 -123
  13. triton/compiler/__init__.py +5 -2
  14. triton/compiler/code_generator.py +270 -163
  15. triton/compiler/compiler.py +45 -62
  16. triton/experimental/gluon/__init__.py +3 -2
  17. triton/experimental/gluon/_runtime.py +9 -6
  18. triton/experimental/gluon/language/__init__.py +117 -16
  19. triton/experimental/gluon/language/_core.py +246 -68
  20. triton/experimental/gluon/language/_layouts.py +398 -45
  21. triton/experimental/gluon/language/_math.py +17 -9
  22. triton/experimental/gluon/language/_semantic.py +130 -37
  23. triton/experimental/gluon/language/_standard.py +55 -22
  24. triton/experimental/gluon/language/amd/__init__.py +4 -0
  25. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  26. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  27. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  28. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  29. triton/experimental/gluon/language/extra/__init__.py +3 -0
  30. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  31. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  32. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  33. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
  34. triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
  35. triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
  36. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
  37. triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
  38. triton/experimental/gluon/nvidia/hopper.py +6 -1
  39. triton/knobs.py +132 -67
  40. triton/language/__init__.py +16 -10
  41. triton/language/core.py +163 -83
  42. triton/language/extra/cuda/gdc.py +6 -6
  43. triton/language/extra/hip/__init__.py +3 -1
  44. triton/language/extra/hip/libdevice.py +7 -0
  45. triton/language/extra/hip/utils.py +35 -0
  46. triton/language/extra/libdevice.py +4 -0
  47. triton/language/semantic.py +76 -23
  48. triton/language/standard.py +14 -14
  49. triton/language/target_info.py +54 -0
  50. triton/runtime/_allocation.py +15 -3
  51. triton/runtime/_async_compile.py +55 -0
  52. triton/runtime/autotuner.py +4 -5
  53. triton/runtime/build.py +11 -9
  54. triton/runtime/cache.py +44 -1
  55. triton/runtime/driver.py +16 -41
  56. triton/runtime/interpreter.py +31 -23
  57. triton/runtime/jit.py +318 -157
  58. triton/runtime/tcc/include/_mingw.h +8 -10
  59. triton/runtime/tcc/include/assert.h +5 -0
  60. triton/runtime/tcc/include/errno.h +1 -1
  61. triton/runtime/tcc/include/float.h +21 -3
  62. triton/runtime/tcc/include/iso646.h +36 -0
  63. triton/runtime/tcc/include/limits.h +5 -0
  64. triton/runtime/tcc/include/malloc.h +2 -2
  65. triton/runtime/tcc/include/math.h +21 -261
  66. triton/runtime/tcc/include/stdalign.h +16 -0
  67. triton/runtime/tcc/include/stdarg.h +5 -70
  68. triton/runtime/tcc/include/stdatomic.h +171 -0
  69. triton/runtime/tcc/include/stddef.h +7 -19
  70. triton/runtime/tcc/include/stdlib.h +15 -4
  71. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  72. triton/runtime/tcc/include/sys/stat.h +2 -2
  73. triton/runtime/tcc/include/sys/types.h +5 -0
  74. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  75. triton/runtime/tcc/include/tccdefs.h +342 -0
  76. triton/runtime/tcc/include/tgmath.h +89 -0
  77. triton/runtime/tcc/include/uchar.h +33 -0
  78. triton/runtime/tcc/include/unistd.h +1 -0
  79. triton/runtime/tcc/include/winapi/qos.h +72 -0
  80. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  81. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  82. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  83. triton/runtime/tcc/include/winapi/windows.h +1 -1
  84. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  85. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  86. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  87. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  88. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  89. triton/runtime/tcc/lib/libtcc1.a +0 -0
  90. triton/runtime/tcc/lib/python314.def +1800 -0
  91. triton/runtime/tcc/lib/python314t.def +1809 -0
  92. triton/runtime/tcc/libtcc.dll +0 -0
  93. triton/runtime/tcc/tcc.exe +0 -0
  94. triton/tools/compile.py +62 -14
  95. triton/tools/extra/cuda/compile.c +1 -0
  96. triton/tools/extra/hip/compile.cpp +66 -0
  97. triton/tools/extra/hip/compile.h +13 -0
  98. triton/tools/ragged_tma.py +92 -0
  99. triton/tools/tensor_descriptor.py +7 -9
  100. triton/windows_utils.py +42 -79
  101. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
  102. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
  103. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  104. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
  105. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
  106. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
  107. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,6 @@
1
1
  import ast
2
+ import builtins
3
+ import contextlib
2
4
  import copy
3
5
  import inspect
4
6
  import re
@@ -11,11 +13,10 @@ from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, Iterable,
11
13
 
12
14
  from .. import knobs, language
13
15
  from .._C.libtriton import ir, gluon_ir
14
- from ..language import constexpr, str_to_ty, tensor
16
+ from ..language import constexpr, str_to_ty, tensor, tuple as tl_tuple
15
17
  from ..language.core import _unwrap_if_constexpr, base_value, base_type
16
- from ..runtime.jit import get_jit_fn_file_line, get_full_name
17
18
  # ideally we wouldn't need any runtime component
18
- from ..runtime import JITFunction
19
+ from ..runtime.jit import get_jit_fn_file_line, get_full_name, JITCallable, BoundConstexprFunction, ConstexprFunction, JITFunction
19
20
  from .._utils import find_paths_if, get_iterable_path, set_iterable_path
20
21
 
21
22
  from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
@@ -28,7 +29,7 @@ def check_identifier_legality(name, type):
28
29
  return name
29
30
 
30
31
 
31
- def mangle_fn(name, arg_tys, constants):
32
+ def mangle_fn(name, arg_tys, constants, caller_context):
32
33
  # doesn't mangle ret type, which must be a function of arg tys
33
34
  mangled_arg_names = '_'.join([ty.mangle() for ty in arg_tys])
34
35
  mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)])
@@ -37,6 +38,8 @@ def mangle_fn(name, arg_tys, constants):
37
38
  # [ and ] are not allowed in LLVM identifiers
38
39
  mangled_constants = mangled_constants.replace('[', '_').replace(']', '_')
39
40
  ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
41
+ if caller_context is not None:
42
+ ret += caller_context.mangle()
40
43
  return ret
41
44
 
42
45
 
@@ -49,7 +52,7 @@ def _is_triton_tensor(o: Any) -> bool:
49
52
 
50
53
 
51
54
  def _is_constexpr(o: Any) -> bool:
52
- return o is None or isinstance(o, (constexpr, language.core.dtype, JITFunction))
55
+ return o is None or isinstance(o, (constexpr, language.core.dtype, JITCallable))
53
56
 
54
57
 
55
58
  def _is_non_scalar_tensor(o: Any) -> bool:
@@ -106,6 +109,17 @@ def unflatten_ir_values(handles: List[ir.value], types: List[base_type]):
106
109
  _condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels
107
110
 
108
111
 
112
+ def _clone_triton_value(val):
113
+ handles = []
114
+ val._flatten_ir(handles)
115
+ clone, _ = val.type._unflatten_ir(handles, 0)
116
+ return clone
117
+
118
+
119
def _clone_scope(scope):
    """Copy a name -> value mapping, cloning every triton value it holds.

    Entries for which ``_is_triton_value`` is true are passed through
    ``_clone_triton_value``; all other entries are carried over as the same
    objects. The input mapping itself is not modified.
    """
    cloned = {}
    for key, entry in scope.items():
        if _is_triton_value(entry):
            cloned[key] = _clone_triton_value(entry)
        else:
            cloned[key] = entry
    return cloned
121
+
122
+
109
123
  class enter_sub_region:
110
124
 
111
125
  def __init__(self, generator):
@@ -113,8 +127,8 @@ class enter_sub_region:
113
127
 
114
128
  def __enter__(self):
115
129
  # record lscope & local_defs in the parent scope
116
- self.liveins = self.generator.lscope.copy()
117
- self.prev_defs = self.generator.local_defs.copy()
130
+ self.liveins = _clone_scope(self.generator.lscope)
131
+ self.prev_defs = _clone_scope(self.generator.local_defs)
118
132
  self.generator.local_defs = {}
119
133
  self.insert_block = self.generator.builder.get_insertion_block()
120
134
  self.insert_point = self.generator.builder.get_insertion_point()
@@ -136,9 +150,9 @@ class ContainsReturnChecker(ast.NodeVisitor):
136
150
  return any(self.visit(s) for s in body)
137
151
 
138
152
  def _visit_function(self, fn) -> bool:
139
- # no need to check within the function as it won't cause an early return.
140
- # If the function itself has unstructured control flow we may not be able to inline it causing poor performance.
141
- # We should check for this and fail or emit a warning.
153
+ # No need to check within the function as it won't cause an early return.
154
+ # If the function itself has unstructured control flow we may not be able to inline it causing poor performance,
155
+ # we should check for this and emit a warning.
142
156
  return False
143
157
 
144
158
  def generic_visit(self, node) -> bool:
@@ -280,11 +294,12 @@ class BoundJITMethod:
280
294
 
281
295
  class CodeGenerator(ast.NodeVisitor):
282
296
 
283
- def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns, module_map,
284
- module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False,
285
- file_name: Optional[str] = None, begin_line=0):
297
+ def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, *, options, codegen_fns,
298
+ module_map, is_gluon, module=None, is_kernel=False, function_types: Optional[Dict] = None,
299
+ noinline=False, caller_context=None, file_name: Optional[str] = None, begin_line=0):
286
300
  self.context = context
287
- if jit_fn.is_gluon():
301
+ self.is_gluon = is_gluon
302
+ if is_gluon:
288
303
  from triton.experimental.gluon.language._semantic import GluonSemantic
289
304
  self.builder = gluon_ir.GluonOpBuilder(context)
290
305
  self.semantic = GluonSemantic(self.builder)
@@ -292,6 +307,8 @@ class CodeGenerator(ast.NodeVisitor):
292
307
  from triton.language.semantic import TritonSemantic
293
308
  self.builder = ir.builder(context)
294
309
  self.semantic = TritonSemantic(self.builder)
310
+
311
+ self.name_loc_as_prefix = None
295
312
  self.file_name = file_name
296
313
  # node.lineno starts from 1, so we need to subtract 1
297
314
  self.begin_line = begin_line - 1
@@ -328,6 +345,7 @@ class CodeGenerator(ast.NodeVisitor):
328
345
  self.is_kernel = is_kernel
329
346
  self.cur_node = None
330
347
  self.noinline = noinline
348
+ self.caller_context = caller_context
331
349
  self.scf_stack = []
332
350
  self.ret_type = None
333
351
  # SSA-construction
@@ -378,7 +396,7 @@ class CodeGenerator(ast.NodeVisitor):
378
396
  val is absent,
379
397
  name in self.builtin_namespace, #
380
398
  type(val) is ModuleType, #
381
- isinstance(val, JITFunction), #
399
+ isinstance(val, JITCallable), #
382
400
  getattr(val, "__triton_builtin__", False), #
383
401
  getattr(val, "__triton_aggregate__", False), #
384
402
  getattr(val, "__module__", "").startswith("triton.language"), #
@@ -414,6 +432,21 @@ class CodeGenerator(ast.NodeVisitor):
414
432
 
415
433
  return name_lookup
416
434
 
435
+ @contextlib.contextmanager
436
+ def _name_loc_prefix(self, prefix):
437
+ self.name_loc_as_prefix = prefix
438
+ yield
439
+ self.name_loc_as_prefix = None
440
+
441
def _maybe_set_loc_to_name(self, val, name):
    """Wrap the location of ``val``'s IR handle(s) in a name location for ``name``.

    A raw ``ir.value`` / ``ir.block_argument`` is updated directly. A triton
    value is flattened first and each resulting handle is updated. Any other
    kind of object is left untouched.
    """
    if isinstance(val, (ir.value, ir.block_argument)):
        targets = [val]
    elif _is_triton_value(val):
        targets = []
        val._flatten_ir(targets)
    else:
        # Nothing with an IR location to annotate.
        return

    for handle in targets:
        named_loc = self.builder.create_name_loc(name, handle.get_loc())
        handle.set_loc(named_loc)
449
+
417
450
  def set_value(self, name: str, value: Union[base_value, constexpr]) -> None:
418
451
  ''' This function:
419
452
  called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
@@ -435,6 +468,43 @@ class CodeGenerator(ast.NodeVisitor):
435
468
  self.builder.restore_insertion_point(ip)
436
469
  self.builder.set_loc(loc)
437
470
 
471
def _find_carries(self, node, liveins):
    """Dry-run a loop body to discover its loop-carried variables.

    The body of ``node`` is visited once inside a scratch block that is erased
    afterwards. Every triton value in ``liveins`` whose flattened IR handles
    differ from the value bound to the same name in ``self.lscope`` after the
    dry run is reported as loop-carried.

    Returns a tuple ``(names, init_handles, init_tys)``: the carried names,
    the flattened IR handles of their live-in values, and the live-in values'
    types (one entry per name).

    Side effects: ``self.lscope`` is reset to a copy of ``liveins`` and
    ``self.local_defs`` is cleared, so definitions made during the dry run are
    not picked up by the caller.
    """
    # Visit the loop body once in a throwaway block.
    scratch_block = self.builder.create_block()
    self.builder.set_insertion_point_to_start(scratch_block)
    self.scf_stack.append(node)
    self.visit_compound_statement(node.body)
    self.scf_stack.pop()
    scratch_block.erase()

    # A name whose value changed within the loop is loop-carried; the new and
    # old values must be of the same type.
    names, init_handles, init_tys = [], [], []
    for name, live_val in liveins.items():
        if not _is_triton_value(live_val):
            assert name not in self.local_defs, f'Loop carried variable {name} is not a triton value'
            continue

        loop_val = self.lscope[name]
        self._verify_loop_carried_variable(name, loop_val, live_val)

        handles_before = flatten_values_to_ir([live_val])
        handles_after = flatten_values_to_ir([loop_val])
        if handles_before != handles_after:
            names.append(name)
            init_tys.append(live_val.type)
            init_handles.extend(handles_before)

    # Drop everything the dry run defined.
    self.lscope = liveins.copy()
    self.local_defs = {}

    return names, init_handles, init_tys
507
+
438
508
  #
439
509
  # AST visitor
440
510
  #
@@ -458,6 +528,21 @@ class CodeGenerator(ast.NodeVisitor):
458
528
  elts = language.tuple([self.visit(elt) for elt in node.elts])
459
529
  return elts
460
530
 
531
def visit_ListComp(self, node: ast.ListComp):
    """Evaluate a list comprehension over a triton tuple and return a triton tuple.

    Only the form ``[elt for name in <tuple>]`` is supported:
    exactly one ``for`` clause, a plain name as the loop target, and no ``if``
    filters. The element expression is visited once per item of the tuple.

    Raises:
        ValueError: if the comprehension has more than one ``for`` clause.
        NotImplementedError: if it has ``if`` filters, a non-name target, or
            iterates over something that is not a triton tuple.

    NOTE: the loop target is bound with ``set_value``, so it stays bound to the
    last item after the comprehension has been evaluated.
    """
    if len(node.generators) != 1:
        raise ValueError("nested comprehensions are not supported")

    comp = node.generators[0]
    # Reject filters explicitly; skipping them would silently keep every element.
    if comp.ifs:
        raise NotImplementedError("comprehension conditions (`if` clauses) are not supported")
    # Destructuring targets such as `for a, b in ...` have no `.id`.
    if not isinstance(comp.target, ast.Name):
        raise NotImplementedError("only a single name is supported as a comprehension target")

    iterable = self.visit(comp.iter)
    if not isinstance(iterable, tl_tuple):
        raise NotImplementedError("only tuple comprehensions are supported")

    results = []
    for item in iterable:
        self.set_value(comp.target.id, item)
        results.append(self.visit(node.elt))
    return tl_tuple(results)
545
+
461
546
  # By design, only non-kernel functions can return
462
547
  def visit_Return(self, node):
463
548
  ret_value = self.visit(node.value)
@@ -522,8 +607,11 @@ class CodeGenerator(ast.NodeVisitor):
522
607
  self.module.push_back(self.fn)
523
608
  entry = self.fn.add_entry_block()
524
609
  arg_values = self.prototype.deserialize(self.fn)
610
+ if self.caller_context is not None:
611
+ self.caller_context.initialize_callee(self.fn, self.builder)
525
612
  # bind arguments to symbols
526
613
  for arg_name, arg_value in zip(arg_names, arg_values):
614
+ self._maybe_set_loc_to_name(arg_value, arg_name)
527
615
  self.set_value(arg_name, arg_value)
528
616
  insert_pt = self.builder.get_insertion_block()
529
617
  self.builder.set_insertion_point_to_start(entry)
@@ -583,9 +671,7 @@ class CodeGenerator(ast.NodeVisitor):
583
671
  self.assignTarget(target, value.values[i])
584
672
  return
585
673
  if isinstance(target, ast.Attribute):
586
- base = self.visit(target.value)
587
- setattr(base, target.attr, value)
588
- return
674
+ raise NotImplementedError("Attribute assignment is not supported in triton")
589
675
  assert isinstance(target, ast.Name)
590
676
  self.set_value(self.visit(target), value)
591
677
 
@@ -602,10 +688,15 @@ class CodeGenerator(ast.NodeVisitor):
602
688
  value = self.semantic.to_tensor(value)
603
689
  return value
604
690
 
605
- values = _sanitize_value(self.visit(node.value))
606
691
  targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets
607
692
  assert len(targets) == 1
608
- self.assignTarget(targets[0], values)
693
+ target = targets[0]
694
+ if isinstance(target, ast.Name):
695
+ with self._name_loc_prefix(target.id):
696
+ values = _sanitize_value(self.visit(node.value))
697
+ else:
698
+ values = _sanitize_value(self.visit(node.value))
699
+ self.assignTarget(target, values)
609
700
 
610
701
  def visit_AugAssign(self, node):
611
702
  lhs = copy.deepcopy(node.target)
@@ -671,8 +762,10 @@ class CodeGenerator(ast.NodeVisitor):
671
762
  self.visit_compound_statement(node.body)
672
763
  then_block = self.builder.get_insertion_block()
673
764
  then_defs = self.local_defs.copy()
765
+ then_vals = self.lscope.copy()
674
766
  # else block
675
767
  else_defs = {}
768
+ else_vals = liveins.copy()
676
769
  if node.orelse:
677
770
  self.builder.set_insertion_point_to_start(else_block)
678
771
  self.lscope = liveins.copy()
@@ -680,26 +773,29 @@ class CodeGenerator(ast.NodeVisitor):
680
773
  self.visit_compound_statement(node.orelse)
681
774
  else_defs = self.local_defs.copy()
682
775
  else_block = self.builder.get_insertion_block()
776
+ else_vals = self.lscope.copy()
683
777
 
684
778
  # update block arguments
685
779
  names = []
686
780
  # variables in livein whose value is updated in `if`
687
- for name in liveins:
781
+ for name, value in liveins.items():
782
+ # livein variable changed value in either then or else
783
+ if not _is_triton_value(value):
784
+ continue
785
+ then_handles = flatten_values_to_ir([then_vals[name]])
786
+ else_handles = flatten_values_to_ir([else_vals[name]])
787
+ if then_handles == else_handles:
788
+ continue
789
+ names.append(name)
790
+ then_defs[name] = then_vals[name]
791
+ else_defs[name] = else_vals[name]
688
792
  # check type
689
793
  for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]:
690
- if name in defs:
691
- type_equal = type(defs[name]) == type(liveins[name]) # noqa: E721
692
- assert type_equal and defs[name].type == liveins[name].type, \
693
- f'initial value for `{name}` is of type {liveins[name]}, '\
694
- f'but the {block_name} block redefines it as {defs[name]}'
695
- if name in then_defs or name in else_defs:
696
- names.append(name)
697
- # variable defined in then but not in else
698
- if name in then_defs and name not in else_defs:
699
- else_defs[name] = liveins[name]
700
- # variable defined in else but not in then
701
- if name in else_defs and name not in then_defs:
702
- then_defs[name] = liveins[name]
794
+ type_equal = type(defs[name]) == type(value) # noqa: E721
795
+ assert type_equal and defs[name].type == value.type, \
796
+ f'initial value for `{name}` is of type {value}, '\
797
+ f'but the {block_name} block redefines it as {defs[name]}'
798
+
703
799
  # variables that are both in then and else but not in liveins
704
800
  # TODO: could probably be cleaned up
705
801
  for name in sorted(then_defs.keys() & else_defs.keys()):
@@ -766,6 +862,8 @@ class CodeGenerator(ast.NodeVisitor):
766
862
  self.visit_then_else_blocks(node, liveins, then_block, else_block)
767
863
  # create if op
768
864
  then_handles = flatten_values_to_ir(then_defs[name] for name in names)
865
+ for name, val in zip(names, then_handles):
866
+ self._maybe_set_loc_to_name(val, name)
769
867
  self._set_insertion_point_and_loc(ip, last_loc)
770
868
  if_op = self.builder.create_if_op([h.get_type() for h in then_handles], cond.handle, True)
771
869
  then_block.merge_block_before(if_op.get_then_block())
@@ -779,6 +877,8 @@ class CodeGenerator(ast.NodeVisitor):
779
877
  self.builder.set_insertion_point_to_end(if_op.get_else_block())
780
878
  if len(names) > 0:
781
879
  else_handles = flatten_values_to_ir(else_defs[name] for name in names)
880
+ for name, val in zip(names, else_handles):
881
+ self._maybe_set_loc_to_name(val, name)
782
882
  self.builder.create_yield_op(else_handles)
783
883
  # update values
784
884
  res_handles = [if_op.get_result(i) for i in range(len(then_handles))]
@@ -799,13 +899,10 @@ class CodeGenerator(ast.NodeVisitor):
799
899
  % ast.unparse(node.test))
800
900
  cond = language.core._unsplat(cond, _semantic=self.semantic, _generator=self)
801
901
  cond = cond.to(language.int1, _semantic=self.semantic)
802
- contains_return = ContainsReturnChecker(self.gscope).visit(node)
803
- if contains_return:
902
+ if ContainsReturnChecker(self.gscope).visit(node):
804
903
  if self.scf_stack:
805
904
  raise self._unsupported(
806
- node, "Cannot have `return` statements inside `while` or `for` statements in triton "
807
- "(note that this also applies to `return` statements that are inside functions "
808
- "transitively called from within `while`/`for` statements)")
905
+ node, "Cannot have `return` statements inside `while` or `for` statements in triton.")
809
906
  self.visit_if_top_level(cond, node)
810
907
  else:
811
908
  self.visit_if_scf(cond, node)
@@ -874,6 +971,37 @@ class CodeGenerator(ast.NodeVisitor):
874
971
  else:
875
972
  return self.visit(node.orelse)
876
973
 
974
def visit_With(self, node):
    """Lower a ``with`` statement by driving its context managers by hand.

    A single-item ``with language.async_task(...)`` is run as a real Python
    ``with`` block, constructed with ``_builder=self.builder``. For every other
    statement each item's call is evaluated with ``_semantic=self.semantic``
    injected, the managers are entered in order (binding ``as`` targets via
    ``set_value``), the body is visited, and the managers are exited in reverse
    order with ``__exit__(None, None, None)``.

    Raises the error built by ``self._unsupported`` when an item is not a call
    expression, or when the statement contains a ``return``.

    NOTE(review): ``__exit__`` is only invoked on the success path; if visiting
    the body raises, the managers are not exited.
    """
    # Each item must look like `ctx(...)`; anything else has no `.func` / `.args`.
    for item in node.items:
        if not isinstance(item.context_expr, ast.Call):
            raise self._unsupported(node, "`with` items must be call expressions in triton")

    if len(node.items) == 1:  # Handle async_task
        context = node.items[0].context_expr
        withitemClass = self.visit(context.func)
        if withitemClass == language.async_task:
            args = [self.visit(arg) for arg in context.args]
            with withitemClass(*args, _builder=self.builder):
                self.visit_compound_statement(node.body)
            return

    # Reject `return` before any context manager is constructed or entered.
    if ContainsReturnChecker(self.gscope).visit(node):
        raise self._unsupported(node, "Cannot have `return` statements inside `with` statements in triton ")

    # Instantiate each context manager with semantic injection.
    cm_list = []
    for item in node.items:
        call = item.context_expr
        fn = self.visit(call.func)
        args = [self.visit(arg) for arg in call.args]
        kws = dict(self.visit(kw) for kw in call.keywords)
        cm_list.append(fn(*args, _semantic=self.semantic, **kws))

    for cm, item in zip(cm_list, node.items):
        res = cm.__enter__()
        if item.optional_vars is not None:
            var_name = self.visit(item.optional_vars)
            self.set_value(var_name, res)

    self.visit_compound_statement(node.body)

    for cm in reversed(cm_list):
        cm.__exit__(None, None, None)
1004
+
877
1005
  def visit_Pass(self, node):
878
1006
  pass
879
1007
 
@@ -918,9 +1046,10 @@ class CodeGenerator(ast.NodeVisitor):
918
1046
  }
919
1047
 
920
1048
  def _verify_loop_carried_variable(self, name, loop_val, live_val):
921
- assert _is_triton_value(loop_val), f'cannot reassign constxpr {name} in the loop'
922
- assert _is_triton_value(live_val), f'cannot reasign constexpr {name} in the loop'
923
- assert type(loop_val) is type(live_val), f'Loop carried variable {name} changed type'
1049
+ assert _is_triton_value(loop_val), f'cannot reassign constexpr {name} in the loop'
1050
+ assert _is_triton_value(live_val), f'cannot reassign constexpr {name} in the loop'
1051
+ assert type(loop_val) is type(live_val), (
1052
+ f'Loop carried variable {name} changed type, was {type(loop_val)} but is now {type(live_val)}')
924
1053
  assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \
925
1054
  f'Loop-carried variable {name} has initial type {live_val.type} '\
926
1055
  f'but is re-assigned to {loop_val.type} in loop! '\
@@ -929,49 +1058,14 @@ class CodeGenerator(ast.NodeVisitor):
929
1058
  def visit_withitem(self, node):
930
1059
  return self.visit(node.context_expr)
931
1060
 
932
- def visit_With(self, node):
933
- assert len(node.items) == 1
934
- context = node.items[0].context_expr
935
- withitemClass = self.visit(context.func)
936
- if withitemClass == language.async_task:
937
- args = [self.visit(arg) for arg in context.args]
938
- with withitemClass(*args, _builder=self.builder):
939
- self.visit_compound_statement(node.body)
940
- else:
941
- self.visit_compound_statement(node.body)
942
-
943
1061
  def visit_While(self, node):
944
1062
  with enter_sub_region(self) as sr:
945
1063
  liveins, insert_block = sr
946
1064
  ip, last_loc = self._get_insertion_point_and_loc()
947
1065
 
948
- # loop body (the after region)
949
- # loop_block = self.builder.create_block()
950
- dummy = self.builder.create_block()
951
- self.builder.set_insertion_point_to_start(dummy)
952
- self.scf_stack.append(node)
953
- self.visit_compound_statement(node.body)
954
- self.scf_stack.pop()
955
- loop_defs = self.local_defs
956
- dummy.erase()
957
-
958
- # collect loop-carried values
959
- names = []
960
- init_args = []
961
- for name in loop_defs:
962
- if name in liveins:
963
- # We should not def new constexpr
964
- loop_val = loop_defs[name]
965
- live_val = liveins[name]
966
- self._verify_loop_carried_variable(name, loop_val, live_val)
967
-
968
- # these are loop-carried values
969
- names.append(name)
970
- init_args.append(live_val)
1066
+ names, init_handles, init_fe_tys = self._find_carries(node, liveins)
971
1067
 
972
- init_handles = flatten_values_to_ir(init_args)
973
1068
  init_tys = [h.get_type() for h in init_handles]
974
- init_fe_tys = [a.type for a in init_args]
975
1069
  self._set_insertion_point_and_loc(ip, last_loc)
976
1070
  while_op = self.builder.create_while_op(init_tys, init_handles)
977
1071
  # merge the condition region
@@ -982,7 +1076,12 @@ class CodeGenerator(ast.NodeVisitor):
982
1076
  for name, val in zip(names, condition_args):
983
1077
  self.lscope[name] = val
984
1078
  self.local_defs[name] = val
1079
+ self._maybe_set_loc_to_name(val, name)
985
1080
  cond = self.visit(node.test)
1081
+ if isinstance(cond, language.condition):
1082
+ if cond.disable_licm:
1083
+ while_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr())
1084
+ cond = cond.condition
986
1085
  self.builder.set_insertion_point_to_end(before_block)
987
1086
  # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
988
1087
  self.builder.create_condition_op(cond.handle, block_args)
@@ -996,16 +1095,13 @@ class CodeGenerator(ast.NodeVisitor):
996
1095
  for name, val in zip(names, body_args):
997
1096
  self.lscope[name] = val
998
1097
  self.local_defs[name] = val
1098
+ self._maybe_set_loc_to_name(val, name)
999
1099
  self.scf_stack.append(node)
1000
1100
  self.visit_compound_statement(node.body)
1001
1101
  self.scf_stack.pop()
1002
- loop_defs = self.local_defs
1003
- yields = []
1004
- for name in loop_defs:
1005
- if name in liveins:
1006
- loop_defs[name]._flatten_ir(yields)
1007
1102
 
1008
- self.builder.create_yield_op(yields)
1103
+ yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)
1104
+ self.builder.create_yield_op(yield_handles)
1009
1105
 
1010
1106
  # WhileOp defines new values, update the symbol table (lscope, local_defs)
1011
1107
  result_handles = [while_op.get_result(i) for i in range(len(init_handles))]
@@ -1013,6 +1109,7 @@ class CodeGenerator(ast.NodeVisitor):
1013
1109
  for name, new_def in zip(names, result_vals):
1014
1110
  self.lscope[name] = new_def
1015
1111
  self.local_defs[name] = new_def
1112
+ self._maybe_set_loc_to_name(new_def, name)
1016
1113
 
1017
1114
  for stmt in node.orelse:
1018
1115
  assert False, "Not implemented"
@@ -1022,16 +1119,12 @@ class CodeGenerator(ast.NodeVisitor):
1022
1119
  assert isinstance(node.ctx, ast.Load)
1023
1120
  lhs = self.visit(node.value)
1024
1121
  slices = self.visit(node.slice)
1025
- if _is_triton_tensor(lhs):
1026
- return lhs.__getitem__(slices, _semantic=self.semantic)
1122
+ if _is_triton_value(lhs):
1123
+ return self.call_Method(node, lhs.__getitem__, lhs, [slices], {})
1027
1124
  return lhs[slices]
1028
1125
 
1029
1126
  def visit_Subscript_Store(self, node, value):
1030
- assert isinstance(node.ctx, ast.Store)
1031
- lhs = self.visit(node.value)
1032
- slices = self.visit(node.slice)
1033
- assert isinstance(lhs, language.tuple)
1034
- lhs.__setitem__(slices, value)
1127
+ raise NotImplementedError("__setitem__ is not supported in triton")
1035
1128
 
1036
1129
  def visit_Subscript(self, node):
1037
1130
  return self.visit_Subscript_Load(node)
@@ -1057,6 +1150,7 @@ class CodeGenerator(ast.NodeVisitor):
1057
1150
  disallow_acc_multi_buffer = False
1058
1151
  flatten = False
1059
1152
  warp_specialize = False
1153
+ disable_licm = False
1060
1154
  if IteratorClass is language.range:
1061
1155
  iterator = IteratorClass(*iter_args, **iter_kwargs)
1062
1156
  # visit iterator arguments
@@ -1070,6 +1164,7 @@ class CodeGenerator(ast.NodeVisitor):
1070
1164
  disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer
1071
1165
  flatten = iterator.flatten
1072
1166
  warp_specialize = iterator.warp_specialize
1167
+ disable_licm = iterator.disable_licm
1073
1168
  elif IteratorClass is range:
1074
1169
  # visit iterator arguments
1075
1170
  # note: only `range` iterator is supported now
@@ -1111,34 +1206,10 @@ class CodeGenerator(ast.NodeVisitor):
1111
1206
  liveins, insert_block = sr
1112
1207
  ip, last_loc = self._get_insertion_point_and_loc()
1113
1208
 
1114
- # create loop body block
1115
- block = self.builder.create_block()
1116
- self.builder.set_insertion_point_to_start(block)
1117
- # dry visit loop body
1118
- self.scf_stack.append(node)
1119
- self.visit_compound_statement(node.body)
1120
- self.scf_stack.pop()
1121
- block.erase()
1122
-
1123
- # If a variable (name) is defined in both its parent & itself, then it's
1124
- # a loop-carried variable. (They must be of the same type)
1125
- init_args = []
1126
- yields = []
1127
- names = []
1128
- for name in self.local_defs:
1129
- if name in liveins:
1130
- loop_val = self.local_defs[name]
1131
- live_val = liveins[name]
1132
- self._verify_loop_carried_variable(name, loop_val, live_val)
1133
-
1134
- names.append(name)
1135
- init_args.append(live_val)
1136
- yields.append(loop_val)
1209
+ names, init_handles, init_tys = self._find_carries(node, liveins)
1137
1210
 
1138
1211
  # create ForOp
1139
1212
  self._set_insertion_point_and_loc(ip, last_loc)
1140
- init_handles = flatten_values_to_ir(init_args)
1141
- init_tys = [v.type for v in init_args]
1142
1213
  for_op = self.builder.create_for_op(lb, ub, step, init_handles)
1143
1214
  if _unwrap_if_constexpr(num_stages) is not None:
1144
1215
  for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages))
@@ -1150,30 +1221,23 @@ class CodeGenerator(ast.NodeVisitor):
1150
1221
  for_op.set_attr("tt.flatten", self.builder.get_unit_attr())
1151
1222
  if warp_specialize:
1152
1223
  for_op.set_attr("tt.warp_specialize", self.builder.get_unit_attr())
1224
+ if disable_licm:
1225
+ for_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr())
1153
1226
 
1154
1227
  self.scf_stack.append(node)
1155
1228
  for_op_body = for_op.get_body(0)
1156
1229
  self.builder.set_insertion_point_to_start(for_op_body)
1157
- # reset local scope to not pick up local defs from the previous dry run.
1158
- self.lscope = liveins.copy()
1159
- self.local_defs = {}
1160
1230
  block_handles = [for_op_body.arg(i + 1) for i in range(len(init_handles))]
1161
1231
  block_args = unflatten_ir_values(block_handles, init_tys)
1162
1232
  for name, val in zip(names, block_args):
1233
+ self._maybe_set_loc_to_name(val, name)
1163
1234
  self.set_value(name, val)
1164
1235
  self.visit_compound_statement(node.body)
1165
1236
  self.scf_stack.pop()
1166
- yields = []
1167
- for name in self.local_defs:
1168
- if name in liveins:
1169
- local = self.local_defs[name]
1170
- if isinstance(local, constexpr):
1171
- local = self.semantic.to_tensor(local)
1172
- yields.append(local)
1237
+ yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)
1173
1238
 
1174
1239
  # create YieldOp
1175
- if len(yields) > 0:
1176
- yield_handles = flatten_values_to_ir(yields)
1240
+ if len(yield_handles) > 0:
1177
1241
  self.builder.create_yield_op(yield_handles)
1178
1242
  for_op_region = for_op_body.get_parent()
1179
1243
  assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
@@ -1186,12 +1250,14 @@ class CodeGenerator(ast.NodeVisitor):
1186
1250
  iv = self.builder.create_add(iv, lb)
1187
1251
  self.lscope[node.target.id].handle.replace_all_uses_with(iv)
1188
1252
  self.set_value(node.target.id, language.core.tensor(iv, iv_type))
1253
+ self._maybe_set_loc_to_name(iv, node.target.id)
1189
1254
 
1190
1255
  # update lscope & local_defs (ForOp defines new values)
1191
1256
  result_handles = [for_op.get_result(i) for i in range(len(init_handles))]
1192
1257
  result_values = unflatten_ir_values(result_handles, init_tys)
1193
1258
  for name, val in zip(names, result_values):
1194
1259
  self.set_value(name, val)
1260
+ self._maybe_set_loc_to_name(val, name)
1195
1261
 
1196
1262
  for stmt in node.orelse:
1197
1263
  assert False, "Don't know what to do with else after for"
@@ -1214,7 +1280,7 @@ class CodeGenerator(ast.NodeVisitor):
1214
1280
  msg = self.visit(node.msg) if node.msg is not None else ""
1215
1281
  return language.core.device_assert(test, msg, _semantic=self.semantic)
1216
1282
 
1217
- def call_JitFunction(self, fn: JITFunction, args, kwargs):
1283
+ def call_JitFunction(self, fn: JITFunction, args, kwargs, caller_context=None):
1218
1284
  args = inspect.getcallargs(fn.fn, *args, **kwargs)
1219
1285
  args = [args[name] for name in fn.arg_names]
1220
1286
  for i, arg in enumerate(args):
@@ -1225,7 +1291,8 @@ class CodeGenerator(ast.NodeVisitor):
1225
1291
  args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
1226
1292
  args_val = [get_iterable_path(args, path) for path in args_path]
1227
1293
  # mangle
1228
- fn_name = mangle_fn(get_full_name(fn), [arg.type for arg in args_val], args_cst)
1294
+ caller_context = caller_context or self.caller_context
1295
+ fn_name = mangle_fn(get_full_name(fn), [arg.type for arg in args_val], args_cst, caller_context)
1229
1296
  # generate function def if necessary
1230
1297
  if not self.module.has_function(fn_name):
1231
1298
  # If the callee is not set, we use the same debug setting as the caller
@@ -1240,7 +1307,8 @@ class CodeGenerator(ast.NodeVisitor):
1240
1307
  function_name=fn_name, function_types=self.function_ret_types,
1241
1308
  noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
1242
1309
  options=self.builder.options, codegen_fns=self.builder.codegen_fns,
1243
- module_map=self.builder.module_map)
1310
+ module_map=self.builder.module_map, caller_context=caller_context,
1311
+ is_gluon=self.is_gluon)
1244
1312
  try:
1245
1313
  generator.visit(fn.parse())
1246
1314
  except Exception as e:
@@ -1261,32 +1329,23 @@ class CodeGenerator(ast.NodeVisitor):
1261
1329
  handles = [call_op.get_result(i) for i in range(call_op.get_num_results())]
1262
1330
  return next(unflatten_ir_values(handles, [callee_ret_type]))
1263
1331
 
1264
- def visit_Call(self, node):
1265
- fn = _unwrap_if_constexpr(self.visit(node.func))
1266
- if not isinstance(fn, BoundJITMethod):
1267
- static_implementation = self.statically_implemented_functions.get(fn)
1268
- if static_implementation is not None:
1269
- return static_implementation(self, node)
1270
-
1271
- mur = getattr(fn, '_must_use_result', False)
1272
- if mur and getattr(node, '_is_unused', False):
1273
- error_message = ["The result of %s is not being used." % ast.unparse(node.func)]
1274
- if isinstance(mur, str):
1275
- error_message.append(mur)
1276
- raise CompilationError(self.jit_fn.src, node, " ".join(error_message))
1277
-
1278
- kws = dict(self.visit(keyword) for keyword in node.keywords)
1279
- args = [self.visit(arg) for arg in node.args]
1280
- args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
1281
- if isinstance(fn, BoundJITMethod):
1332
+ def call_Function(self, node, fn, args, kws):
1333
+ if isinstance(fn, (BoundJITMethod, BoundConstexprFunction)):
1282
1334
  args.insert(0, fn.__self__)
1283
1335
  fn = fn.__func__
1284
1336
  if isinstance(fn, JITFunction):
1285
1337
  _check_fn_args(node, fn, args)
1286
1338
  return self.call_JitFunction(fn, args, kws)
1287
- if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn):
1288
- extra_kwargs = {"_semantic": self.semantic}
1289
- sig = inspect.signature(fn)
1339
+ if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn) or isinstance(
1340
+ fn, ConstexprFunction):
1341
+ extra_kwargs = dict()
1342
+
1343
+ if isinstance(fn, ConstexprFunction):
1344
+ sig = inspect.signature(fn.__call__)
1345
+ else:
1346
+ sig = inspect.signature(fn)
1347
+ if '_semantic' in sig.parameters:
1348
+ extra_kwargs["_semantic"] = self.semantic
1290
1349
  if '_generator' in sig.parameters:
1291
1350
  extra_kwargs['_generator'] = self
1292
1351
  try:
@@ -1304,12 +1363,45 @@ class CodeGenerator(ast.NodeVisitor):
1304
1363
  # itself). But when calling a function, we raise as `from e` to
1305
1364
  # preserve the traceback of the original error, which may e.g.
1306
1365
  # be in core.py.
1307
- raise CompilationError(self.jit_fn.src, node, None) from e
1366
+ raise CompilationError(self.jit_fn.src, node, str(e)) from e
1308
1367
 
1309
1368
  if fn in self.builtin_namespace.values():
1310
1369
  args = map(_unwrap_if_constexpr, args)
1311
1370
  ret = fn(*args, **kws)
1312
- return _apply_to_tuple_values(ret, lambda x: x) if _is_namedtuple(type(ret)) else ret
1371
+
1372
+ def wrap_constexpr(x):
1373
+ if _is_triton_value(x):
1374
+ return x
1375
+ return constexpr(x)
1376
+
1377
+ if isinstance(ret, (builtins.tuple, language.tuple)):
1378
+ return _apply_to_tuple_values(ret, wrap_constexpr)
1379
+ return wrap_constexpr(ret)
1380
+
1381
+ def call_Method(self, node, fn, fn_self, args, kws):
1382
+ if isinstance(fn, JITFunction):
1383
+ args.insert(0, fn_self)
1384
+ return self.call_Function(node, fn, args, kws)
1385
+
1386
+ def visit_Call(self, node):
1387
+ fn = _unwrap_if_constexpr(self.visit(node.func))
1388
+ if not isinstance(fn, BoundJITMethod):
1389
+ static_implementation = self.statically_implemented_functions.get(fn)
1390
+ if static_implementation is not None:
1391
+ return static_implementation(self, node)
1392
+
1393
+ mur = getattr(fn, '_must_use_result', False)
1394
+ if mur and getattr(node, '_is_unused', False):
1395
+ error_message = ["The result of %s is not being used." % ast.unparse(node.func)]
1396
+ if isinstance(mur, str):
1397
+ error_message.append(mur)
1398
+ raise CompilationError(self.jit_fn.src, node, " ".join(error_message))
1399
+
1400
+ kws = dict(self.visit(keyword) for keyword in node.keywords)
1401
+ args = [self.visit(arg) for arg in node.args]
1402
+ args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
1403
+
1404
+ return self.call_Function(node, fn, args, kws)
1313
1405
 
1314
1406
  def visit_Constant(self, node):
1315
1407
  return constexpr(node.value)
@@ -1373,7 +1465,7 @@ class CodeGenerator(ast.NodeVisitor):
1373
1465
  if _is_triton_tensor(lhs) and node.attr == "T":
1374
1466
  return self.semantic.permute(lhs, (1, 0))
1375
1467
  # NOTE: special case ".value" for BC
1376
- if isinstance(lhs, constexpr) and node.attr != "value":
1468
+ if isinstance(lhs, constexpr) and node.attr not in ("value", "type"):
1377
1469
  lhs = lhs.value
1378
1470
  attr = getattr(lhs, node.attr)
1379
1471
  if _is_triton_value(lhs) and isinstance(attr, JITFunction):
@@ -1417,7 +1509,11 @@ class CodeGenerator(ast.NodeVisitor):
1417
1509
  last_loc = self.builder.get_loc()
1418
1510
  self.cur_node = node
1419
1511
  if hasattr(node, 'lineno') and hasattr(node, 'col_offset'):
1420
- self.builder.set_loc(self.file_name, self.begin_line + node.lineno, node.col_offset)
1512
+ here_loc = self.builder.create_loc(self.file_name, self.begin_line + node.lineno, node.col_offset)
1513
+ if self.name_loc_as_prefix is not None:
1514
+ self.builder.set_loc(self.builder.create_name_loc(self.name_loc_as_prefix, here_loc))
1515
+ else:
1516
+ self.builder.set_loc(here_loc)
1421
1517
  last_loc = self.builder.get_loc()
1422
1518
  try:
1423
1519
  ret = super().visit(node)
@@ -1486,9 +1582,16 @@ class CodeGenerator(ast.NodeVisitor):
1486
1582
 
1487
1583
  def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None):
1488
1584
  arg_types = [None] * len(fn.arg_names)
1489
- for k, v in src.signature.items():
1490
- idx = fn.arg_names.index(k)
1491
- arg_types[idx] = str_to_ty(v)
1585
+ const_iter = iter(src.constants.items())
1586
+ kc, vc = next(const_iter, (None, None))
1587
+
1588
+ for i, (ks, v) in enumerate(src.signature.items()):
1589
+ idx = fn.arg_names.index(ks)
1590
+ cexpr = None
1591
+ if kc is not None and kc[0] == i:
1592
+ cexpr = vc
1593
+ kc, vc = next(const_iter, (None, None))
1594
+ arg_types[idx] = str_to_ty(v, cexpr)
1492
1595
  prototype = ASTFunction([], arg_types, src.constants, src.attrs)
1493
1596
  file_name, begin_line = get_jit_fn_file_line(fn)
1494
1597
  # query function representation
@@ -1499,9 +1602,13 @@ def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None)
1499
1602
  proxy = namedtuple("SpecializationProxy", ["constants", "signature"])(constants, signature)
1500
1603
  generator = CodeGenerator(context, prototype, gscope=fn.get_capture_scope(), function_name=fn.repr(proxy),
1501
1604
  jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
1502
- codegen_fns=codegen_fns, module_map=module_map, module=module)
1605
+ codegen_fns=codegen_fns, module_map=module_map, module=module, is_gluon=fn.is_gluon())
1503
1606
  generator.visit(fn.parse())
1504
- ret = generator.module
1607
+ module = generator.module
1505
1608
  # module takes ownership of the context
1506
- ret.context = context
1507
- return ret
1609
+ module.context = context
1610
+ if not module.verify_with_diagnostics():
1611
+ if not fn.is_gluon():
1612
+ print(module)
1613
+ raise RuntimeError("error encountered during parsing")
1614
+ return module