triton-windows 3.4.0.post20__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (107)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +8 -2
  3. triton/_filecheck.py +24 -14
  4. triton/_internal_testing.py +70 -4
  5. triton/_utils.py +3 -1
  6. triton/backends/amd/compiler.py +68 -60
  7. triton/backends/amd/driver.c +113 -44
  8. triton/backends/amd/driver.py +133 -57
  9. triton/backends/driver.py +13 -0
  10. triton/backends/nvidia/compiler.py +80 -22
  11. triton/backends/nvidia/driver.c +88 -15
  12. triton/backends/nvidia/driver.py +130 -123
  13. triton/compiler/__init__.py +5 -2
  14. triton/compiler/code_generator.py +270 -163
  15. triton/compiler/compiler.py +45 -62
  16. triton/experimental/gluon/__init__.py +3 -2
  17. triton/experimental/gluon/_runtime.py +9 -6
  18. triton/experimental/gluon/language/__init__.py +117 -16
  19. triton/experimental/gluon/language/_core.py +246 -68
  20. triton/experimental/gluon/language/_layouts.py +398 -45
  21. triton/experimental/gluon/language/_math.py +17 -9
  22. triton/experimental/gluon/language/_semantic.py +130 -37
  23. triton/experimental/gluon/language/_standard.py +55 -22
  24. triton/experimental/gluon/language/amd/__init__.py +4 -0
  25. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  26. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  27. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  28. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  29. triton/experimental/gluon/language/extra/__init__.py +3 -0
  30. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  31. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  32. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  33. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
  34. triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
  35. triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
  36. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
  37. triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
  38. triton/experimental/gluon/nvidia/hopper.py +6 -1
  39. triton/knobs.py +132 -67
  40. triton/language/__init__.py +16 -10
  41. triton/language/core.py +163 -83
  42. triton/language/extra/cuda/gdc.py +6 -6
  43. triton/language/extra/hip/__init__.py +3 -1
  44. triton/language/extra/hip/libdevice.py +7 -0
  45. triton/language/extra/hip/utils.py +35 -0
  46. triton/language/extra/libdevice.py +4 -0
  47. triton/language/semantic.py +76 -23
  48. triton/language/standard.py +14 -14
  49. triton/language/target_info.py +54 -0
  50. triton/runtime/_allocation.py +15 -3
  51. triton/runtime/_async_compile.py +55 -0
  52. triton/runtime/autotuner.py +4 -5
  53. triton/runtime/build.py +11 -9
  54. triton/runtime/cache.py +44 -1
  55. triton/runtime/driver.py +16 -41
  56. triton/runtime/interpreter.py +31 -23
  57. triton/runtime/jit.py +318 -157
  58. triton/runtime/tcc/include/_mingw.h +8 -10
  59. triton/runtime/tcc/include/assert.h +5 -0
  60. triton/runtime/tcc/include/errno.h +1 -1
  61. triton/runtime/tcc/include/float.h +21 -3
  62. triton/runtime/tcc/include/iso646.h +36 -0
  63. triton/runtime/tcc/include/limits.h +5 -0
  64. triton/runtime/tcc/include/malloc.h +2 -2
  65. triton/runtime/tcc/include/math.h +21 -261
  66. triton/runtime/tcc/include/stdalign.h +16 -0
  67. triton/runtime/tcc/include/stdarg.h +5 -70
  68. triton/runtime/tcc/include/stdatomic.h +171 -0
  69. triton/runtime/tcc/include/stddef.h +7 -19
  70. triton/runtime/tcc/include/stdlib.h +15 -4
  71. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  72. triton/runtime/tcc/include/sys/stat.h +2 -2
  73. triton/runtime/tcc/include/sys/types.h +5 -0
  74. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  75. triton/runtime/tcc/include/tccdefs.h +342 -0
  76. triton/runtime/tcc/include/tgmath.h +89 -0
  77. triton/runtime/tcc/include/uchar.h +33 -0
  78. triton/runtime/tcc/include/unistd.h +1 -0
  79. triton/runtime/tcc/include/winapi/qos.h +72 -0
  80. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  81. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  82. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  83. triton/runtime/tcc/include/winapi/windows.h +1 -1
  84. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  85. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  86. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  87. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  88. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  89. triton/runtime/tcc/lib/libtcc1.a +0 -0
  90. triton/runtime/tcc/lib/python314.def +1800 -0
  91. triton/runtime/tcc/lib/python314t.def +1809 -0
  92. triton/runtime/tcc/libtcc.dll +0 -0
  93. triton/runtime/tcc/tcc.exe +0 -0
  94. triton/tools/compile.py +62 -14
  95. triton/tools/extra/cuda/compile.c +1 -0
  96. triton/tools/extra/hip/compile.cpp +66 -0
  97. triton/tools/extra/hip/compile.h +13 -0
  98. triton/tools/ragged_tma.py +92 -0
  99. triton/tools/tensor_descriptor.py +7 -9
  100. triton/windows_utils.py +42 -79
  101. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
  102. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
  103. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  104. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
  105. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
  106. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
  107. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/runtime/jit.py CHANGED
@@ -4,6 +4,7 @@ import copy
4
4
  import hashlib
5
5
  import inspect
6
6
  import itertools
7
+ import threading
7
8
  import re
8
9
  import textwrap
9
10
  from collections import defaultdict
@@ -14,10 +15,14 @@ from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overlo
14
15
  from triton.tools.tensor_descriptor import TensorDescriptor
15
16
  from types import ModuleType
16
17
  from .. import knobs
17
- from ..runtime.driver import driver
18
+ from .driver import driver
19
+ from . import _async_compile
18
20
  from .._utils import find_paths_if, get_iterable_path, type_canonicalisation_dict, canonicalize_dtype
21
+ from .cache import get_cache_key
22
+ from triton._C.libtriton import get_cache_invalidating_env_vars
19
23
 
20
- TRITON_MODULE = __name__[:-len(".runtime.jit")]
24
+ TRITON_MODULE = "triton.language"
25
+ GLUON_MODULE = "triton.experimental.gluon.language"
21
26
 
22
27
  T = TypeVar("T")
23
28
 
@@ -60,6 +65,12 @@ class DependenciesFinder(ast.NodeVisitor):
60
65
  'print',
61
66
  'range',
62
67
  }
68
+ self.supported_modules = {
69
+ GLUON_MODULE,
70
+ TRITON_MODULE,
71
+ "copy",
72
+ "math",
73
+ }
63
74
 
64
75
  # used_global_vals tells us which global variables are used by this
65
76
  # function and all those it transitively calls, plus the values of those
@@ -86,22 +97,56 @@ class DependenciesFinder(ast.NodeVisitor):
86
97
  return module.startswith(TRITON_MODULE)
87
98
 
88
99
  def _update_hash(self, func):
89
- if isinstance(func, JITFunction):
90
- # Merge our used_global_vals with those of the called function,
91
- # after checking that all overlapping values are consistent.
92
- for k in self.used_global_vals.keys() & func.used_global_vals.keys():
93
- var_name, _ = k
94
- v1, _ = self.used_global_vals[k]
95
- v2, _ = func.used_global_vals[k]
96
- if v1 != v2:
97
- raise RuntimeError(
98
- f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed."
99
- )
100
- self.used_global_vals.update(func.used_global_vals)
101
- # update hash
102
- func_key = func.cache_key
103
- func_key += str(getattr(func, "noinline", False))
104
- self.hasher.update(func_key.encode("utf-8"))
100
+ assert isinstance(func, JITCallable)
101
+ # Merge our used_global_vals with those of the called function,
102
+ # after checking that all overlapping values are consistent.
103
+ for k in self.used_global_vals.keys() & func.used_global_vals.keys():
104
+ var_name, _ = k
105
+ v1, _ = self.used_global_vals[k]
106
+ v2, _ = func.used_global_vals[k]
107
+ if v1 != v2:
108
+ raise RuntimeError(
109
+ f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed."
110
+ )
111
+ self.used_global_vals.update(func.used_global_vals)
112
+ # update hash
113
+ func_key = func.cache_key
114
+ func_key += str(getattr(func, "noinline", False))
115
+ self.hasher.update(func_key.encode("utf-8"))
116
+
117
+ def record_reference(self, val, var_dict=None, name=None):
118
+ from ..language.core import constexpr
119
+ # Only keep track of "interesting" global variables, that non-evil users
120
+ # might change. Don't consider functions, modules, builtins, etc. This
121
+ # helps keep the list of vars we have to check small.
122
+ if val is None or type(val) is ModuleType:
123
+ return
124
+
125
+ if getattr(val, "__triton_builtin__", False):
126
+ return
127
+
128
+ # Stubs that aren't real functions
129
+ if getattr(val, "__module__", "") == "triton.language.extra.libdevice":
130
+ return
131
+
132
+ if isinstance(val, JITCallable):
133
+ self._update_hash(val)
134
+ return
135
+
136
+ if callable(val) and not isinstance(val, type) and not isinstance(val, constexpr):
137
+ raise RuntimeError(f"Unsupported function referenced: {val}")
138
+
139
+ # Python default arguments are resolved only once, when the
140
+ # function is defined. So if you do `foo(a=A)` and the value of
141
+ # A changes, foo will still use the old value of A.
142
+ # It would be pretty evil if someone did `import x` and then
143
+ # `x = blah`.
144
+ if self.visiting_arg_default_value:
145
+ return
146
+
147
+ if var_dict is not None:
148
+ self.used_global_vals[(name, id(var_dict))] = (copy.deepcopy(val), var_dict)
149
+ return
105
150
 
106
151
  def visit_Name(self, node):
107
152
  if type(node.ctx) is ast.Store:
@@ -121,25 +166,10 @@ class DependenciesFinder(ast.NodeVisitor):
121
166
  return None, None
122
167
 
123
168
  val, var_dict = name_lookup(node.id)
169
+ if node.id in self.supported_python_builtins:
170
+ return val
124
171
 
125
- # Only keep track of "interesting" global variables, that non-evil users
126
- # might change. Don't consider functions, modules, builtins, etc. This
127
- # helps keep the list of vars we have to check small.
128
- if (val is not None #
129
- # Python default arguments are resolved only once, when the
130
- # function is defined. So if you do `foo(a=A)` and the value of
131
- # A changes, foo will still use the old value of A.
132
- and not self.visiting_arg_default_value
133
- # It would be pretty evil if someone did `import x` and then
134
- # `x = blah`.
135
- and type(val) is not ModuleType
136
- # It would be pretty evil if we used function `foo` inside of
137
- # `bar` and then someone did `foo = baz`.
138
- and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) #
139
- and node.id not in self.supported_python_builtins):
140
- self.used_global_vals[(node.id, id(var_dict))] = (copy.copy(val), var_dict)
141
-
142
- self._update_hash(val)
172
+ self.record_reference(val, var_dict, node.id)
143
173
  return val
144
174
 
145
175
  def visit_Tuple(self, node):
@@ -151,10 +181,11 @@ class DependenciesFinder(ast.NodeVisitor):
151
181
  lhs = self.visit(node.value)
152
182
  while isinstance(lhs, ast.Attribute):
153
183
  lhs = self.visit(lhs.value)
154
- if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE):
184
+ lhs_name = getattr(lhs, "__name__", "")
185
+ if lhs is None or lhs_name in self.supported_modules:
155
186
  return None
156
187
  ret = getattr(lhs, node.attr)
157
- self._update_hash(ret)
188
+ self.record_reference(ret)
158
189
  return ret
159
190
 
160
191
  def visit_FunctionDef(self, node):
@@ -345,12 +376,10 @@ def create_specialize_impl(specialize_extra):
345
376
  dtype2str[dsk] = res
346
377
  key = specialize_extra(arg, "tensor", align=align) if specialize_value else None
347
378
  return (res, key)
348
- elif isinstance(arg, JITFunction):
379
+ elif isinstance(arg, JITCallable):
349
380
  return ("constexpr", arg.cache_key)
350
381
  elif isinstance(arg, constexpr):
351
382
  return ("constexpr", arg)
352
- elif hasattr(arg, "tma_desc_cpu_ptr"):
353
- return ("nvTmaDesc", None)
354
383
  elif isinstance(arg, tuple):
355
384
  spec = [specialize_impl(x) for x in arg]
356
385
  make_tuple = lambda vals: type(arg)(*vals) if hasattr(arg, "_fields") else tuple(vals)
@@ -451,7 +480,7 @@ def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options
451
480
  if param.default is not inspect.Parameter.empty
452
481
  }
453
482
 
454
- func_namespace["JITFunction"] = JITFunction
483
+ func_namespace["JITCallable"] = JITCallable
455
484
  func_namespace["specialize_impl"] = create_specialize_impl(backend.get_arg_specialization)
456
485
 
457
486
  # Execute the function string in func_namespace to create the function
@@ -465,6 +494,104 @@ def get_full_name(fn):
465
494
  return f"{fn.__module__}.{fn.__qualname__}"
466
495
 
467
496
 
497
+ class JITCallable:
498
+
499
+ def __init__(self, fn):
500
+ self.fn = fn
501
+ self.signature = inspect.signature(fn)
502
+ try:
503
+ self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
504
+ except OSError as e:
505
+ raise ValueError("@jit functions should be defined in a Python file") from e
506
+ self._fn_name = get_full_name(fn)
507
+ self._hash_lock = threading.RLock()
508
+
509
+ # function source code (without decorators)
510
+ src = textwrap.dedent("".join(self.raw_src))
511
+ src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():]
512
+ self._src = src
513
+ self.hash = None
514
+
515
+ # Map of global variables used by the function and any functions it
516
+ # transitively calls, plus their values. The values are collected when
517
+ # the function is first compiled. Then every time we run the function,
518
+ # we check that the values of the globals match what's expected,
519
+ # otherwise we raise an error.
520
+ #
521
+ # Different functions can have different __globals__ maps, so the map
522
+ # key is actually (var name, id(__globals__)), and the map value is
523
+ # (value, __globals__).
524
+ self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}
525
+
526
+ # reuse docs of wrapped function
527
+ self.__doc__ = fn.__doc__
528
+ self.__name__ = fn.__name__
529
+ self.__qualname__ = fn.__qualname__
530
+ self.__globals__ = fn.__globals__
531
+ self.__module__ = fn.__module__
532
+
533
+ def get_capture_scope(self):
534
+ return self.__globals__ | inspect.getclosurevars(self.fn).nonlocals
535
+
536
+ @property
537
+ def cache_key(self):
538
+ # TODO : hash should be attribute of `self`
539
+ with self._hash_lock:
540
+ if self.hash is not None:
541
+ return self.hash
542
+ # Set a placeholder hash to break recursion in case the function
543
+ # transitively calls itself. The full hash is set after.
544
+ self.hash = f"recursion:{self._fn_name}"
545
+ nonlocals = inspect.getclosurevars(self.fn).nonlocals
546
+ dependencies_finder = DependenciesFinder(name=self._fn_name, globals=self.__globals__, nonlocals=nonlocals,
547
+ src=self.src)
548
+ dependencies_finder.visit(self.parse())
549
+ self.hash = dependencies_finder.ret + str(self.starting_line_number)
550
+ self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items()))
551
+
552
+ from triton.language.core import constexpr
553
+ self.hash += str([(name, val)
554
+ for (name, _), (val, _) in self.used_global_vals.items()
555
+ if isinstance(val, constexpr)])
556
+ self.hash = hashlib.sha256(self.hash.encode("utf-8")).hexdigest()
557
+ return self.hash
558
+
559
+ # we do not parse `src` in the constructor because
560
+ # the user might want to monkey-patch self.src dynamically.
561
+ # Our unit tests do this, for example.
562
+ def parse(self):
563
+ tree = ast.parse(self._src)
564
+ assert isinstance(tree, ast.Module)
565
+ assert len(tree.body) == 1
566
+ assert isinstance(tree.body[0], ast.FunctionDef)
567
+ return tree
568
+
569
+ @property
570
+ def type(self):
571
+ from triton.language.core import constexpr_type
572
+ return constexpr_type(self)
573
+
574
+ def _unsafe_update_src(self, new_src):
575
+ """
576
+ The only method allowed to modify src.
577
+ Bypasses the __setattr__ restriction by calling super().__setattr__ directly.
578
+
579
+ Note that it is the caller's responsibility to make sure any triton functions that call this function have the `.hash` value reset to None.
580
+ """
581
+ self.hash = None
582
+ self._src = new_src
583
+
584
+ def _set_src(self):
585
+ raise AttributeError("Cannot set attribute 'src' directly. "
586
+ "Use '_unsafe_update_src()' and manually clear `.hash` of all callers"
587
+ "instead.")
588
+
589
+ def _get_src(self):
590
+ return self._src
591
+
592
+ src = property(fget=_get_src, fset=_set_src)
593
+
594
+
468
595
  @dataclass
469
596
  class JitFunctionInfo:
470
597
  module: ModuleType
@@ -472,7 +599,18 @@ class JitFunctionInfo:
472
599
  jit_function: JITFunction
473
600
 
474
601
 
475
- class JITFunction(KernelInterface[T]):
602
+ def compute_cache_key(kernel_key_cache, specialization, options):
603
+ key = (tuple(specialization), str(options))
604
+ cache_key = kernel_key_cache.get(key, None)
605
+ if cache_key is not None:
606
+ return cache_key
607
+
608
+ cache_key = str(specialization) + str(options)
609
+ kernel_key_cache[key] = cache_key
610
+ return cache_key
611
+
612
+
613
+ class JITFunction(JITCallable, KernelInterface[T]):
476
614
 
477
615
  def is_gluon(self):
478
616
  return False
@@ -542,7 +680,31 @@ class JITFunction(KernelInterface[T]):
542
680
  self.compile = compile
543
681
  self.ASTSource = ASTSource
544
682
  binder = create_function_from_signature(self.signature, self.params, backend)
545
- return {}, target, backend, binder
683
+ return {}, {}, target, backend, binder
684
+
685
+ def _pack_args(self, backend, kwargs, bound_args, specialization, options):
686
+ # options
687
+ options = backend.parse_options(kwargs)
688
+ # signature
689
+ sigkeys = [x.name for x in self.params]
690
+ sigvals = [x[0] for x in specialization]
691
+ signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
692
+ # check arguments
693
+ assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used"
694
+ assert "device" not in kwargs, "device option is deprecated; current device will be used"
695
+ assert "stream" not in kwargs, "stream option is deprecated; current stream will be used"
696
+ for k in kwargs:
697
+ if k not in options.__dict__ and k not in sigkeys:
698
+ raise KeyError("Keyword argument %s was specified but unrecognised" % k)
699
+ # constexprs
700
+ constexprs = find_paths_if(sigvals, lambda _, val: val == "constexpr")
701
+ constexprs = {path: get_iterable_path(list(bound_args.values()), path) for path in constexprs}
702
+ # attributes
703
+ attrvals = [x[1] for x in specialization]
704
+ attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
705
+ attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs}
706
+
707
+ return options, signature, constexprs, attrs
546
708
 
547
709
  def run(self, *args, grid, warmup, **kwargs):
548
710
  kwargs["debug"] = kwargs.get("debug", self.debug) or knobs.runtime.debug
@@ -555,46 +717,22 @@ class JITFunction(KernelInterface[T]):
555
717
  for hook in self.pre_run_hooks:
556
718
  hook(*args, **kwargs)
557
719
 
558
- kernel_cache, target, backend, binder = self.device_caches[device]
720
+ kernel_cache, kernel_key_cache, target, backend, binder = self.device_caches[device]
559
721
  # specialization is list[tuple[str, Any]], where first element of tuple is
560
722
  # the type and the second parameter is the 'specialization' value.
561
723
  bound_args, specialization, options = binder(*args, **kwargs)
562
724
 
563
- # compute cache key
564
- key = str(specialization) + str(options)
725
+ key = compute_cache_key(kernel_key_cache, specialization, options)
565
726
  kernel = kernel_cache.get(key, None)
566
727
 
567
728
  # Kernel is not cached; we have to compile.
568
729
  if kernel is None:
569
- # options
570
- options = backend.parse_options(kwargs)
571
- # signature
572
- sigkeys = [x.name for x in self.params]
573
- sigvals = [x[0] for x in specialization]
574
- signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
575
- # check arguments
576
- assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used"
577
- assert "device" not in kwargs, "device option is deprecated; current device will be used"
578
- assert "stream" not in kwargs, "stream option is deprecated; current stream will be used"
579
- for k in kwargs:
580
- if k not in options.__dict__ and k not in sigkeys:
581
- raise KeyError("Keyword argument %s was specified but unrecognised" % k)
582
- # constexprs
583
- constexprs = find_paths_if(sigvals, lambda _, val: val == "constexpr")
584
- constexprs = {path: get_iterable_path(list(bound_args.values()), path) for path in constexprs}
585
- # attributes
586
- attrvals = [x[1] for x in specialization]
587
- attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
588
- attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs}
589
- if self._call_hook(knobs.runtime.jit_cache_hook, key, signature, device, constexprs, options, [attrs],
590
- warmup):
730
+ options, signature, constexprs, attrs = self._pack_args(backend, kwargs, bound_args, specialization,
731
+ options)
732
+
733
+ kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
734
+ if kernel is None:
591
735
  return None
592
- # compile the kernel
593
- src = self.ASTSource(self, signature, constexprs, attrs)
594
- kernel = self.compile(src, target=target, options=options.__dict__)
595
- kernel_cache[key] = kernel
596
- self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, device, constexprs, options, [attrs],
597
- warmup)
598
736
 
599
737
  # Check that used global values have not changed.
600
738
  not_present = object()
@@ -612,6 +750,8 @@ class JITFunction(KernelInterface[T]):
612
750
  grid_0 = grid[0]
613
751
  grid_1 = grid[1] if grid_size > 1 else 1
614
752
  grid_2 = grid[2] if grid_size > 2 else 1
753
+ if hasattr(kernel, "result"):
754
+ kernel = kernel.result()
615
755
  # launch kernel
616
756
  launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
617
757
  kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
@@ -626,15 +766,12 @@ class JITFunction(KernelInterface[T]):
626
766
  do_not_specialize = do_not_specialize if do_not_specialize else []
627
767
  do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else []
628
768
 
629
- self.fn = fn
769
+ super().__init__(fn)
630
770
  self.module = fn.__module__
631
771
  self.version = version
632
- self.signature = inspect.signature(fn)
633
772
  self.do_not_specialize = do_not_specialize
634
773
  self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
635
- self.starting_line_number = inspect.getsourcelines(fn)[1]
636
774
  self._repr = repr
637
- self._fn_name = get_full_name(fn)
638
775
  self.launch_metadata = launch_metadata
639
776
 
640
777
  self.params = []
@@ -643,24 +780,8 @@ class JITFunction(KernelInterface[T]):
643
780
  dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment
644
781
  self.params.append(KernelParam(i, param, dns, dns_oa))
645
782
 
646
- # function source code (without decorators)
647
- src = textwrap.dedent(inspect.getsource(fn))
648
- src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():]
649
- self._unsafe_update_src(src)
650
783
  # cache of just-in-time compiled kernels
651
784
  self.device_caches = defaultdict(self.create_binder)
652
- self.hash = None
653
-
654
- # Map of global variables used by the function and any functions it
655
- # transitively calls, plus their values. The values are collected when
656
- # the function is first compiled. Then every time we run the function,
657
- # we check that the values of the globals match what's expected,
658
- # otherwise we raise an error.
659
- #
660
- # Different functions can have different __globals__ maps, so the map
661
- # key is actually (var name, id(__globals__)), and the map value is
662
- # (value, __globals__).
663
- self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}
664
785
 
665
786
  # JITFunction can be instantiated as kernel
666
787
  # when called with a grid using __getitem__
@@ -676,38 +797,10 @@ class JITFunction(KernelInterface[T]):
676
797
  # Hooks that will be called prior to executing "run"
677
798
  self.pre_run_hooks = []
678
799
 
679
- # reuse docs of wrapped function
680
- self.__doc__ = fn.__doc__
681
- self.__name__ = fn.__name__
682
- self.__qualname__ = fn.__qualname__
683
- self.__globals__ = fn.__globals__
684
- self.__module__ = fn.__module__
685
-
686
- def get_capture_scope(self):
687
- return self.__globals__ | inspect.getclosurevars(self.fn).nonlocals
688
-
689
- @property
690
- def cache_key(self):
691
- # TODO : hash should be attribute of `self`
692
- if self.hash is None:
693
- nonlocals = inspect.getclosurevars(self.fn).nonlocals
694
- dependencies_finder = DependenciesFinder(name=self._fn_name, globals=self.__globals__, nonlocals=nonlocals,
695
- src=self.src)
696
- dependencies_finder.visit(self.parse())
697
- self.hash = dependencies_finder.ret + str(self.starting_line_number)
698
- self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items()))
699
- return self.hash
700
-
701
- @property
702
- def type(self):
703
- from triton.language.core import constexpr
704
- return constexpr
705
-
706
800
  def warmup(self, *args, grid, **kwargs):
707
801
  return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs)
708
802
 
709
803
  def preload(self, specialization_data):
710
- from ..compiler import compile, ASTSource
711
804
  import json
712
805
  import triton.language as tl
713
806
  device = driver.active.get_current_device()
@@ -717,7 +810,7 @@ class JITFunction(KernelInterface[T]):
717
810
  f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self._fn_name}")
718
811
  constant_keys = map(tuple, deserialized_obj['constant_keys'])
719
812
  constant_vals = deserialized_obj['constant_vals']
720
- constants = {
813
+ constexprs = {
721
814
  key: tl.dtype(value) if tl.dtype.is_dtype(value) else value
722
815
  for key, value in zip(constant_keys, constant_vals)
723
816
  }
@@ -725,44 +818,54 @@ class JITFunction(KernelInterface[T]):
725
818
  attrs_vals = deserialized_obj['attrs_vals']
726
819
  attrs = dict(zip(attrs_keys, attrs_vals))
727
820
  signature = dict(deserialized_obj['signature'].items())
728
- src = ASTSource(self, signature, constants, attrs)
729
821
  options = {
730
822
  key: tuple(value) if isinstance(value, list) else value
731
823
  for key, value in deserialized_obj['options'].items()
732
824
  }
733
825
  key = deserialized_obj['key']
734
- kernel = compile(src, None, options)
735
- self.device_caches[device][0][key] = kernel
736
- return kernel
826
+ _, _, _, backend, _ = self.device_caches[device]
827
+ options = backend.parse_options(options)
828
+ return self._do_compile(
829
+ key,
830
+ signature,
831
+ device,
832
+ constexprs,
833
+ options,
834
+ attrs,
835
+ warmup=True,
836
+ )
737
837
 
738
- # we do not parse `src` in the constructor because
739
- # the user might want to monkey-patch self.src dynamically.
740
- # Our unit tests do this, for example.
741
- def parse(self):
742
- tree = ast.parse(self.src)
743
- assert isinstance(tree, ast.Module)
744
- assert len(tree.body) == 1
745
- assert isinstance(tree.body[0], ast.FunctionDef)
746
- return tree
838
+ def _do_compile(self, key, signature, device, constexprs, options, attrs, warmup):
839
+ kernel_cache, _, target, backend, _ = self.device_caches[device]
747
840
 
748
- def __call__(self, *args, **kwargs):
749
- raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
841
+ if self._call_hook(knobs.runtime.jit_cache_hook, key, signature, device, constexprs, options, [attrs], warmup):
842
+ return None
843
+ src = self.ASTSource(self, signature, constexprs, attrs)
750
844
 
751
- def __setattr__(self, name, value):
752
- # - when `.src` attribute is set, cache key of all callers need to be re-computed
753
- if name == "src":
754
- raise AttributeError(f"Cannot set attribute '{name}' directly. "
755
- f"Use '_unsafe_update_src()' and manually clear `.hash` of all callers"
756
- f"instead.")
757
- super(JITFunction, self).__setattr__(name, value)
845
+ async_mode = _async_compile.active_mode.get()
846
+ if async_mode is not None:
758
847
 
759
- def _unsafe_update_src(self, new_src):
760
- """
761
- The only method allowed to modify src.
762
- Bypasses the __setattr__ restriction by calling super().__setattr__ directly.
763
- """
764
- self.hash = None
765
- super().__setattr__('src', new_src)
848
+ env_vars = get_cache_invalidating_env_vars()
849
+ cache_key = get_cache_key(src, backend, options, env_vars)
850
+
851
+ def async_compile():
852
+ return self.compile(src, target=target, options=options.__dict__, _env_vars=env_vars)
853
+
854
+ def finalize_compile(kernel):
855
+ kernel_cache[key] = kernel
856
+ self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, device, constexprs, options,
857
+ [attrs], warmup)
858
+
859
+ kernel = async_mode.submit(cache_key, async_compile, finalize_compile)
860
+ else:
861
+ kernel = self.compile(src, target=target, options=options.__dict__)
862
+ kernel_cache[key] = kernel
863
+ self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, device, constexprs, options, [attrs],
864
+ warmup)
865
+ return kernel
866
+
867
+ def __call__(self, *args, **kwargs):
868
+ raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
766
869
 
767
870
  def __repr__(self):
768
871
  return f"JITFunction({self.module}:{self.fn.__qualname__})"
@@ -864,8 +967,17 @@ class MockTensor:
864
967
  return MockTensor(arg)
865
968
  return arg
866
969
 
867
- def __init__(self, dtype):
970
+ def __init__(self, dtype, shape=None):
971
+ if shape is None:
972
+ shape = [1]
868
973
  self.dtype = dtype
974
+ self.shape = shape
975
+
976
+ def stride(self):
977
+ strides = [1]
978
+ for size in self.shape[1:]:
979
+ strides.append(strides[-1] * size)
980
+ return tuple(reversed(strides))
869
981
 
870
982
  @staticmethod
871
983
  def data_ptr():
@@ -930,17 +1042,66 @@ def reinterpret(tensor, dtype):
930
1042
 
931
1043
  def get_jit_fn_file_line(fn):
932
1044
  base_fn = fn
933
- while not isinstance(base_fn, JITFunction):
1045
+ while not isinstance(base_fn, JITCallable):
934
1046
  base_fn = base_fn.fn
935
1047
  file_name = base_fn.fn.__code__.co_filename
936
- lines, begin_line = inspect.getsourcelines(base_fn.fn)
1048
+ begin_line = base_fn.starting_line_number
937
1049
  # Match the following pattern:
938
1050
  # @triton.autotune(...) <- foo.__code__.co_firstlineno
939
1051
  # @triton.heuristics(...)
940
1052
  # @triton.jit
941
1053
  # def foo(...): <- this line is the first line
942
- for idx, line in enumerate(lines):
1054
+ for idx, line in enumerate(base_fn.raw_src):
943
1055
  if line.strip().startswith("def "):
944
1056
  begin_line += idx
945
1057
  break
946
1058
  return file_name, begin_line
1059
+
1060
+
1061
+ class BoundConstexprFunction(JITCallable):
1062
+
1063
+ def __init__(self, instance, fn):
1064
+ self.__self__ = instance
1065
+ self.__func__ = fn
1066
+
1067
+ def __call__(self, *args, **kwargs):
1068
+ return self.__func__(self.__self__, *args, **kwargs)
1069
+
1070
+
1071
+ class ConstexprFunction(JITCallable):
1072
+
1073
+ def __init__(self, fn):
1074
+ super().__init__(fn)
1075
+
1076
+ def __get__(self, obj, objclass):
1077
+ # Create a bound function to support constexpr_function methods
1078
+ if obj is not None:
1079
+ return BoundConstexprFunction(obj, self)
1080
+ return self
1081
+
1082
+ def __call__(self, *args, _semantic=None, **kwargs):
1083
+ from triton.language.core import _unwrap_if_constexpr, constexpr
1084
+ # de-constexpr arguments and discard the _semantic keyword argument:
1085
+ args = [_unwrap_if_constexpr(x) for x in args]
1086
+ kwargs = {k: _unwrap_if_constexpr(v) for (k, v) in kwargs.items()}
1087
+
1088
+ # call the raw Python function f:
1089
+ res = self.fn(*args, **kwargs)
1090
+
1091
+ if _semantic is None:
1092
+ # Not called by triton code generator, e.g. in host code, another constexpr function, or even an aggregate's __init__ function
1093
+ return res
1094
+
1095
+ # convert result back to a Triton constexpr:
1096
+ if knobs.runtime.interpret:
1097
+ return res # No constexpr in interpreter
1098
+ return constexpr(res)
1099
+
1100
+
1101
+ def constexpr_function(fn):
1102
+ """
1103
+ Wraps an arbitrary Python function so that it can be called at
1104
+ compile-time on constexpr arguments in a Triton function and
1105
+ returns a constexpr result.
1106
+ """
1107
+ return ConstexprFunction(fn)