triton-windows 3.4.0.post20__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +8 -2
- triton/_filecheck.py +24 -14
- triton/_internal_testing.py +70 -4
- triton/_utils.py +3 -1
- triton/backends/amd/compiler.py +68 -60
- triton/backends/amd/driver.c +113 -44
- triton/backends/amd/driver.py +133 -57
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/compiler.py +80 -22
- triton/backends/nvidia/driver.c +88 -15
- triton/backends/nvidia/driver.py +130 -123
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +270 -163
- triton/compiler/compiler.py +45 -62
- triton/experimental/gluon/__init__.py +3 -2
- triton/experimental/gluon/_runtime.py +9 -6
- triton/experimental/gluon/language/__init__.py +117 -16
- triton/experimental/gluon/language/_core.py +246 -68
- triton/experimental/gluon/language/_layouts.py +398 -45
- triton/experimental/gluon/language/_math.py +17 -9
- triton/experimental/gluon/language/_semantic.py +130 -37
- triton/experimental/gluon/language/_standard.py +55 -22
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
- triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
- triton/experimental/gluon/nvidia/hopper.py +6 -1
- triton/knobs.py +132 -67
- triton/language/__init__.py +16 -10
- triton/language/core.py +163 -83
- triton/language/extra/cuda/gdc.py +6 -6
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +7 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/semantic.py +76 -23
- triton/language/standard.py +14 -14
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +4 -5
- triton/runtime/build.py +11 -9
- triton/runtime/cache.py +44 -1
- triton/runtime/driver.py +16 -41
- triton/runtime/interpreter.py +31 -23
- triton/runtime/jit.py +318 -157
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/tools/compile.py +62 -14
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +7 -9
- triton/windows_utils.py +42 -79
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/runtime/jit.py
CHANGED
|
@@ -4,6 +4,7 @@ import copy
|
|
|
4
4
|
import hashlib
|
|
5
5
|
import inspect
|
|
6
6
|
import itertools
|
|
7
|
+
import threading
|
|
7
8
|
import re
|
|
8
9
|
import textwrap
|
|
9
10
|
from collections import defaultdict
|
|
@@ -14,10 +15,14 @@ from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overlo
|
|
|
14
15
|
from triton.tools.tensor_descriptor import TensorDescriptor
|
|
15
16
|
from types import ModuleType
|
|
16
17
|
from .. import knobs
|
|
17
|
-
from
|
|
18
|
+
from .driver import driver
|
|
19
|
+
from . import _async_compile
|
|
18
20
|
from .._utils import find_paths_if, get_iterable_path, type_canonicalisation_dict, canonicalize_dtype
|
|
21
|
+
from .cache import get_cache_key
|
|
22
|
+
from triton._C.libtriton import get_cache_invalidating_env_vars
|
|
19
23
|
|
|
20
|
-
TRITON_MODULE =
|
|
24
|
+
TRITON_MODULE = "triton.language"
|
|
25
|
+
GLUON_MODULE = "triton.experimental.gluon.language"
|
|
21
26
|
|
|
22
27
|
T = TypeVar("T")
|
|
23
28
|
|
|
@@ -60,6 +65,12 @@ class DependenciesFinder(ast.NodeVisitor):
|
|
|
60
65
|
'print',
|
|
61
66
|
'range',
|
|
62
67
|
}
|
|
68
|
+
self.supported_modules = {
|
|
69
|
+
GLUON_MODULE,
|
|
70
|
+
TRITON_MODULE,
|
|
71
|
+
"copy",
|
|
72
|
+
"math",
|
|
73
|
+
}
|
|
63
74
|
|
|
64
75
|
# used_global_vals tells us which global variables are used by this
|
|
65
76
|
# function and all those it transitively calls, plus the values of those
|
|
@@ -86,22 +97,56 @@ class DependenciesFinder(ast.NodeVisitor):
|
|
|
86
97
|
return module.startswith(TRITON_MODULE)
|
|
87
98
|
|
|
88
99
|
def _update_hash(self, func):
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
100
|
+
assert isinstance(func, JITCallable)
|
|
101
|
+
# Merge our used_global_vals with those of the called function,
|
|
102
|
+
# after checking that all overlapping values are consistent.
|
|
103
|
+
for k in self.used_global_vals.keys() & func.used_global_vals.keys():
|
|
104
|
+
var_name, _ = k
|
|
105
|
+
v1, _ = self.used_global_vals[k]
|
|
106
|
+
v2, _ = func.used_global_vals[k]
|
|
107
|
+
if v1 != v2:
|
|
108
|
+
raise RuntimeError(
|
|
109
|
+
f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed."
|
|
110
|
+
)
|
|
111
|
+
self.used_global_vals.update(func.used_global_vals)
|
|
112
|
+
# update hash
|
|
113
|
+
func_key = func.cache_key
|
|
114
|
+
func_key += str(getattr(func, "noinline", False))
|
|
115
|
+
self.hasher.update(func_key.encode("utf-8"))
|
|
116
|
+
|
|
117
|
+
def record_reference(self, val, var_dict=None, name=None):
|
|
118
|
+
from ..language.core import constexpr
|
|
119
|
+
# Only keep track of "interesting" global variables, that non-evil users
|
|
120
|
+
# might change. Don't consider functions, modules, builtins, etc. This
|
|
121
|
+
# helps keep the list of vars we have to check small.
|
|
122
|
+
if val is None or type(val) is ModuleType:
|
|
123
|
+
return
|
|
124
|
+
|
|
125
|
+
if getattr(val, "__triton_builtin__", False):
|
|
126
|
+
return
|
|
127
|
+
|
|
128
|
+
# Stubs that aren't real functions
|
|
129
|
+
if getattr(val, "__module__", "") == "triton.language.extra.libdevice":
|
|
130
|
+
return
|
|
131
|
+
|
|
132
|
+
if isinstance(val, JITCallable):
|
|
133
|
+
self._update_hash(val)
|
|
134
|
+
return
|
|
135
|
+
|
|
136
|
+
if callable(val) and not isinstance(val, type) and not isinstance(val, constexpr):
|
|
137
|
+
raise RuntimeError(f"Unsupported function referenced: {val}")
|
|
138
|
+
|
|
139
|
+
# Python default arguments are resolved only once, when the
|
|
140
|
+
# function is defined. So if you do `foo(a=A)` and the value of
|
|
141
|
+
# A changes, foo will still use the old value of A.
|
|
142
|
+
# It would be pretty evil if someone did `import x` and then
|
|
143
|
+
# `x = blah`.
|
|
144
|
+
if self.visiting_arg_default_value:
|
|
145
|
+
return
|
|
146
|
+
|
|
147
|
+
if var_dict is not None:
|
|
148
|
+
self.used_global_vals[(name, id(var_dict))] = (copy.deepcopy(val), var_dict)
|
|
149
|
+
return
|
|
105
150
|
|
|
106
151
|
def visit_Name(self, node):
|
|
107
152
|
if type(node.ctx) is ast.Store:
|
|
@@ -121,25 +166,10 @@ class DependenciesFinder(ast.NodeVisitor):
|
|
|
121
166
|
return None, None
|
|
122
167
|
|
|
123
168
|
val, var_dict = name_lookup(node.id)
|
|
169
|
+
if node.id in self.supported_python_builtins:
|
|
170
|
+
return val
|
|
124
171
|
|
|
125
|
-
|
|
126
|
-
# might change. Don't consider functions, modules, builtins, etc. This
|
|
127
|
-
# helps keep the list of vars we have to check small.
|
|
128
|
-
if (val is not None #
|
|
129
|
-
# Python default arguments are resolved only once, when the
|
|
130
|
-
# function is defined. So if you do `foo(a=A)` and the value of
|
|
131
|
-
# A changes, foo will still use the old value of A.
|
|
132
|
-
and not self.visiting_arg_default_value
|
|
133
|
-
# It would be pretty evil if someone did `import x` and then
|
|
134
|
-
# `x = blah`.
|
|
135
|
-
and type(val) is not ModuleType
|
|
136
|
-
# It would be pretty evil if we used function `foo` inside of
|
|
137
|
-
# `bar` and then someone did `foo = baz`.
|
|
138
|
-
and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) #
|
|
139
|
-
and node.id not in self.supported_python_builtins):
|
|
140
|
-
self.used_global_vals[(node.id, id(var_dict))] = (copy.copy(val), var_dict)
|
|
141
|
-
|
|
142
|
-
self._update_hash(val)
|
|
172
|
+
self.record_reference(val, var_dict, node.id)
|
|
143
173
|
return val
|
|
144
174
|
|
|
145
175
|
def visit_Tuple(self, node):
|
|
@@ -151,10 +181,11 @@ class DependenciesFinder(ast.NodeVisitor):
|
|
|
151
181
|
lhs = self.visit(node.value)
|
|
152
182
|
while isinstance(lhs, ast.Attribute):
|
|
153
183
|
lhs = self.visit(lhs.value)
|
|
154
|
-
|
|
184
|
+
lhs_name = getattr(lhs, "__name__", "")
|
|
185
|
+
if lhs is None or lhs_name in self.supported_modules:
|
|
155
186
|
return None
|
|
156
187
|
ret = getattr(lhs, node.attr)
|
|
157
|
-
self.
|
|
188
|
+
self.record_reference(ret)
|
|
158
189
|
return ret
|
|
159
190
|
|
|
160
191
|
def visit_FunctionDef(self, node):
|
|
@@ -345,12 +376,10 @@ def create_specialize_impl(specialize_extra):
|
|
|
345
376
|
dtype2str[dsk] = res
|
|
346
377
|
key = specialize_extra(arg, "tensor", align=align) if specialize_value else None
|
|
347
378
|
return (res, key)
|
|
348
|
-
elif isinstance(arg,
|
|
379
|
+
elif isinstance(arg, JITCallable):
|
|
349
380
|
return ("constexpr", arg.cache_key)
|
|
350
381
|
elif isinstance(arg, constexpr):
|
|
351
382
|
return ("constexpr", arg)
|
|
352
|
-
elif hasattr(arg, "tma_desc_cpu_ptr"):
|
|
353
|
-
return ("nvTmaDesc", None)
|
|
354
383
|
elif isinstance(arg, tuple):
|
|
355
384
|
spec = [specialize_impl(x) for x in arg]
|
|
356
385
|
make_tuple = lambda vals: type(arg)(*vals) if hasattr(arg, "_fields") else tuple(vals)
|
|
@@ -451,7 +480,7 @@ def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options
|
|
|
451
480
|
if param.default is not inspect.Parameter.empty
|
|
452
481
|
}
|
|
453
482
|
|
|
454
|
-
func_namespace["
|
|
483
|
+
func_namespace["JITCallable"] = JITCallable
|
|
455
484
|
func_namespace["specialize_impl"] = create_specialize_impl(backend.get_arg_specialization)
|
|
456
485
|
|
|
457
486
|
# Execute the function string in func_namespace to create the function
|
|
@@ -465,6 +494,104 @@ def get_full_name(fn):
|
|
|
465
494
|
return f"{fn.__module__}.{fn.__qualname__}"
|
|
466
495
|
|
|
467
496
|
|
|
497
|
+
class JITCallable:
|
|
498
|
+
|
|
499
|
+
def __init__(self, fn):
|
|
500
|
+
self.fn = fn
|
|
501
|
+
self.signature = inspect.signature(fn)
|
|
502
|
+
try:
|
|
503
|
+
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
|
|
504
|
+
except OSError as e:
|
|
505
|
+
raise ValueError("@jit functions should be defined in a Python file") from e
|
|
506
|
+
self._fn_name = get_full_name(fn)
|
|
507
|
+
self._hash_lock = threading.RLock()
|
|
508
|
+
|
|
509
|
+
# function source code (without decorators)
|
|
510
|
+
src = textwrap.dedent("".join(self.raw_src))
|
|
511
|
+
src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():]
|
|
512
|
+
self._src = src
|
|
513
|
+
self.hash = None
|
|
514
|
+
|
|
515
|
+
# Map of global variables used by the function and any functions it
|
|
516
|
+
# transitively calls, plus their values. The values are collected when
|
|
517
|
+
# the function is first compiled. Then every time we run the function,
|
|
518
|
+
# we check that the values of the globals match what's expected,
|
|
519
|
+
# otherwise we raise an error.
|
|
520
|
+
#
|
|
521
|
+
# Different functions can have different __globals__ maps, so the map
|
|
522
|
+
# key is actually (var name, id(__globals__)), and the map value is
|
|
523
|
+
# (value, __globals__).
|
|
524
|
+
self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}
|
|
525
|
+
|
|
526
|
+
# reuse docs of wrapped function
|
|
527
|
+
self.__doc__ = fn.__doc__
|
|
528
|
+
self.__name__ = fn.__name__
|
|
529
|
+
self.__qualname__ = fn.__qualname__
|
|
530
|
+
self.__globals__ = fn.__globals__
|
|
531
|
+
self.__module__ = fn.__module__
|
|
532
|
+
|
|
533
|
+
def get_capture_scope(self):
|
|
534
|
+
return self.__globals__ | inspect.getclosurevars(self.fn).nonlocals
|
|
535
|
+
|
|
536
|
+
@property
|
|
537
|
+
def cache_key(self):
|
|
538
|
+
# TODO : hash should be attribute of `self`
|
|
539
|
+
with self._hash_lock:
|
|
540
|
+
if self.hash is not None:
|
|
541
|
+
return self.hash
|
|
542
|
+
# Set a placeholder hash to break recursion in case the function
|
|
543
|
+
# transitively calls itself. The full hash is set after.
|
|
544
|
+
self.hash = f"recursion:{self._fn_name}"
|
|
545
|
+
nonlocals = inspect.getclosurevars(self.fn).nonlocals
|
|
546
|
+
dependencies_finder = DependenciesFinder(name=self._fn_name, globals=self.__globals__, nonlocals=nonlocals,
|
|
547
|
+
src=self.src)
|
|
548
|
+
dependencies_finder.visit(self.parse())
|
|
549
|
+
self.hash = dependencies_finder.ret + str(self.starting_line_number)
|
|
550
|
+
self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items()))
|
|
551
|
+
|
|
552
|
+
from triton.language.core import constexpr
|
|
553
|
+
self.hash += str([(name, val)
|
|
554
|
+
for (name, _), (val, _) in self.used_global_vals.items()
|
|
555
|
+
if isinstance(val, constexpr)])
|
|
556
|
+
self.hash = hashlib.sha256(self.hash.encode("utf-8")).hexdigest()
|
|
557
|
+
return self.hash
|
|
558
|
+
|
|
559
|
+
# we do not parse `src` in the constructor because
|
|
560
|
+
# the user might want to monkey-patch self.src dynamically.
|
|
561
|
+
# Our unit tests do this, for example.
|
|
562
|
+
def parse(self):
|
|
563
|
+
tree = ast.parse(self._src)
|
|
564
|
+
assert isinstance(tree, ast.Module)
|
|
565
|
+
assert len(tree.body) == 1
|
|
566
|
+
assert isinstance(tree.body[0], ast.FunctionDef)
|
|
567
|
+
return tree
|
|
568
|
+
|
|
569
|
+
@property
|
|
570
|
+
def type(self):
|
|
571
|
+
from triton.language.core import constexpr_type
|
|
572
|
+
return constexpr_type(self)
|
|
573
|
+
|
|
574
|
+
def _unsafe_update_src(self, new_src):
|
|
575
|
+
"""
|
|
576
|
+
The only method allowed to modify src.
|
|
577
|
+
Bypasses the __setattr__ restriction by calling super().__setattr__ directly.
|
|
578
|
+
|
|
579
|
+
Note that it is the callers responsibility to make sure any triton functions that call this function have the `.hash` value reset to None.
|
|
580
|
+
"""
|
|
581
|
+
self.hash = None
|
|
582
|
+
self._src = new_src
|
|
583
|
+
|
|
584
|
+
def _set_src(self):
|
|
585
|
+
raise AttributeError("Cannot set attribute 'src' directly. "
|
|
586
|
+
"Use '_unsafe_update_src()' and manually clear `.hash` of all callers"
|
|
587
|
+
"instead.")
|
|
588
|
+
|
|
589
|
+
def _get_src(self):
|
|
590
|
+
return self._src
|
|
591
|
+
|
|
592
|
+
src = property(fget=_get_src, fset=_set_src)
|
|
593
|
+
|
|
594
|
+
|
|
468
595
|
@dataclass
|
|
469
596
|
class JitFunctionInfo:
|
|
470
597
|
module: ModuleType
|
|
@@ -472,7 +599,18 @@ class JitFunctionInfo:
|
|
|
472
599
|
jit_function: JITFunction
|
|
473
600
|
|
|
474
601
|
|
|
475
|
-
|
|
602
|
+
def compute_cache_key(kernel_key_cache, specialization, options):
|
|
603
|
+
key = (tuple(specialization), str(options))
|
|
604
|
+
cache_key = kernel_key_cache.get(key, None)
|
|
605
|
+
if cache_key is not None:
|
|
606
|
+
return cache_key
|
|
607
|
+
|
|
608
|
+
cache_key = str(specialization) + str(options)
|
|
609
|
+
kernel_key_cache[key] = cache_key
|
|
610
|
+
return cache_key
|
|
611
|
+
|
|
612
|
+
|
|
613
|
+
class JITFunction(JITCallable, KernelInterface[T]):
|
|
476
614
|
|
|
477
615
|
def is_gluon(self):
|
|
478
616
|
return False
|
|
@@ -542,7 +680,31 @@ class JITFunction(KernelInterface[T]):
|
|
|
542
680
|
self.compile = compile
|
|
543
681
|
self.ASTSource = ASTSource
|
|
544
682
|
binder = create_function_from_signature(self.signature, self.params, backend)
|
|
545
|
-
return {}, target, backend, binder
|
|
683
|
+
return {}, {}, target, backend, binder
|
|
684
|
+
|
|
685
|
+
def _pack_args(self, backend, kwargs, bound_args, specialization, options):
|
|
686
|
+
# options
|
|
687
|
+
options = backend.parse_options(kwargs)
|
|
688
|
+
# signature
|
|
689
|
+
sigkeys = [x.name for x in self.params]
|
|
690
|
+
sigvals = [x[0] for x in specialization]
|
|
691
|
+
signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
|
|
692
|
+
# check arguments
|
|
693
|
+
assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used"
|
|
694
|
+
assert "device" not in kwargs, "device option is deprecated; current device will be used"
|
|
695
|
+
assert "stream" not in kwargs, "stream option is deprecated; current stream will be used"
|
|
696
|
+
for k in kwargs:
|
|
697
|
+
if k not in options.__dict__ and k not in sigkeys:
|
|
698
|
+
raise KeyError("Keyword argument %s was specified but unrecognised" % k)
|
|
699
|
+
# constexprs
|
|
700
|
+
constexprs = find_paths_if(sigvals, lambda _, val: val == "constexpr")
|
|
701
|
+
constexprs = {path: get_iterable_path(list(bound_args.values()), path) for path in constexprs}
|
|
702
|
+
# attributes
|
|
703
|
+
attrvals = [x[1] for x in specialization]
|
|
704
|
+
attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
|
|
705
|
+
attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs}
|
|
706
|
+
|
|
707
|
+
return options, signature, constexprs, attrs
|
|
546
708
|
|
|
547
709
|
def run(self, *args, grid, warmup, **kwargs):
|
|
548
710
|
kwargs["debug"] = kwargs.get("debug", self.debug) or knobs.runtime.debug
|
|
@@ -555,46 +717,22 @@ class JITFunction(KernelInterface[T]):
|
|
|
555
717
|
for hook in self.pre_run_hooks:
|
|
556
718
|
hook(*args, **kwargs)
|
|
557
719
|
|
|
558
|
-
kernel_cache, target, backend, binder = self.device_caches[device]
|
|
720
|
+
kernel_cache, kernel_key_cache, target, backend, binder = self.device_caches[device]
|
|
559
721
|
# specialization is list[tuple[str, Any]], where first element of tuple is
|
|
560
722
|
# the type and the second parameter is the 'specialization' value.
|
|
561
723
|
bound_args, specialization, options = binder(*args, **kwargs)
|
|
562
724
|
|
|
563
|
-
|
|
564
|
-
key = str(specialization) + str(options)
|
|
725
|
+
key = compute_cache_key(kernel_key_cache, specialization, options)
|
|
565
726
|
kernel = kernel_cache.get(key, None)
|
|
566
727
|
|
|
567
728
|
# Kernel is not cached; we have to compile.
|
|
568
729
|
if kernel is None:
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
|
|
575
|
-
# check arguments
|
|
576
|
-
assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used"
|
|
577
|
-
assert "device" not in kwargs, "device option is deprecated; current device will be used"
|
|
578
|
-
assert "stream" not in kwargs, "stream option is deprecated; current stream will be used"
|
|
579
|
-
for k in kwargs:
|
|
580
|
-
if k not in options.__dict__ and k not in sigkeys:
|
|
581
|
-
raise KeyError("Keyword argument %s was specified but unrecognised" % k)
|
|
582
|
-
# constexprs
|
|
583
|
-
constexprs = find_paths_if(sigvals, lambda _, val: val == "constexpr")
|
|
584
|
-
constexprs = {path: get_iterable_path(list(bound_args.values()), path) for path in constexprs}
|
|
585
|
-
# attributes
|
|
586
|
-
attrvals = [x[1] for x in specialization]
|
|
587
|
-
attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
|
|
588
|
-
attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs}
|
|
589
|
-
if self._call_hook(knobs.runtime.jit_cache_hook, key, signature, device, constexprs, options, [attrs],
|
|
590
|
-
warmup):
|
|
730
|
+
options, signature, constexprs, attrs = self._pack_args(backend, kwargs, bound_args, specialization,
|
|
731
|
+
options)
|
|
732
|
+
|
|
733
|
+
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
|
|
734
|
+
if kernel is None:
|
|
591
735
|
return None
|
|
592
|
-
# compile the kernel
|
|
593
|
-
src = self.ASTSource(self, signature, constexprs, attrs)
|
|
594
|
-
kernel = self.compile(src, target=target, options=options.__dict__)
|
|
595
|
-
kernel_cache[key] = kernel
|
|
596
|
-
self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, device, constexprs, options, [attrs],
|
|
597
|
-
warmup)
|
|
598
736
|
|
|
599
737
|
# Check that used global values have not changed.
|
|
600
738
|
not_present = object()
|
|
@@ -612,6 +750,8 @@ class JITFunction(KernelInterface[T]):
|
|
|
612
750
|
grid_0 = grid[0]
|
|
613
751
|
grid_1 = grid[1] if grid_size > 1 else 1
|
|
614
752
|
grid_2 = grid[2] if grid_size > 2 else 1
|
|
753
|
+
if hasattr(kernel, "result"):
|
|
754
|
+
kernel = kernel.result()
|
|
615
755
|
# launch kernel
|
|
616
756
|
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
|
|
617
757
|
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
|
|
@@ -626,15 +766,12 @@ class JITFunction(KernelInterface[T]):
|
|
|
626
766
|
do_not_specialize = do_not_specialize if do_not_specialize else []
|
|
627
767
|
do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else []
|
|
628
768
|
|
|
629
|
-
|
|
769
|
+
super().__init__(fn)
|
|
630
770
|
self.module = fn.__module__
|
|
631
771
|
self.version = version
|
|
632
|
-
self.signature = inspect.signature(fn)
|
|
633
772
|
self.do_not_specialize = do_not_specialize
|
|
634
773
|
self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
|
|
635
|
-
self.starting_line_number = inspect.getsourcelines(fn)[1]
|
|
636
774
|
self._repr = repr
|
|
637
|
-
self._fn_name = get_full_name(fn)
|
|
638
775
|
self.launch_metadata = launch_metadata
|
|
639
776
|
|
|
640
777
|
self.params = []
|
|
@@ -643,24 +780,8 @@ class JITFunction(KernelInterface[T]):
|
|
|
643
780
|
dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment
|
|
644
781
|
self.params.append(KernelParam(i, param, dns, dns_oa))
|
|
645
782
|
|
|
646
|
-
# function source code (without decorators)
|
|
647
|
-
src = textwrap.dedent(inspect.getsource(fn))
|
|
648
|
-
src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():]
|
|
649
|
-
self._unsafe_update_src(src)
|
|
650
783
|
# cache of just-in-time compiled kernels
|
|
651
784
|
self.device_caches = defaultdict(self.create_binder)
|
|
652
|
-
self.hash = None
|
|
653
|
-
|
|
654
|
-
# Map of global variables used by the function and any functions it
|
|
655
|
-
# transitively calls, plus their values. The values are collected when
|
|
656
|
-
# the function is first compiled. Then every time we run the function,
|
|
657
|
-
# we check that the values of the globals match what's expected,
|
|
658
|
-
# otherwise we raise an error.
|
|
659
|
-
#
|
|
660
|
-
# Different functions can have different __globals__ maps, so the map
|
|
661
|
-
# key is actually (var name, id(__globals__)), and the map value is
|
|
662
|
-
# (value, __globals__).
|
|
663
|
-
self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}
|
|
664
785
|
|
|
665
786
|
# JITFunction can be instantiated as kernel
|
|
666
787
|
# when called with a grid using __getitem__
|
|
@@ -676,38 +797,10 @@ class JITFunction(KernelInterface[T]):
|
|
|
676
797
|
# Hooks that will be called prior to executing "run"
|
|
677
798
|
self.pre_run_hooks = []
|
|
678
799
|
|
|
679
|
-
# reuse docs of wrapped function
|
|
680
|
-
self.__doc__ = fn.__doc__
|
|
681
|
-
self.__name__ = fn.__name__
|
|
682
|
-
self.__qualname__ = fn.__qualname__
|
|
683
|
-
self.__globals__ = fn.__globals__
|
|
684
|
-
self.__module__ = fn.__module__
|
|
685
|
-
|
|
686
|
-
def get_capture_scope(self):
|
|
687
|
-
return self.__globals__ | inspect.getclosurevars(self.fn).nonlocals
|
|
688
|
-
|
|
689
|
-
@property
|
|
690
|
-
def cache_key(self):
|
|
691
|
-
# TODO : hash should be attribute of `self`
|
|
692
|
-
if self.hash is None:
|
|
693
|
-
nonlocals = inspect.getclosurevars(self.fn).nonlocals
|
|
694
|
-
dependencies_finder = DependenciesFinder(name=self._fn_name, globals=self.__globals__, nonlocals=nonlocals,
|
|
695
|
-
src=self.src)
|
|
696
|
-
dependencies_finder.visit(self.parse())
|
|
697
|
-
self.hash = dependencies_finder.ret + str(self.starting_line_number)
|
|
698
|
-
self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items()))
|
|
699
|
-
return self.hash
|
|
700
|
-
|
|
701
|
-
@property
|
|
702
|
-
def type(self):
|
|
703
|
-
from triton.language.core import constexpr
|
|
704
|
-
return constexpr
|
|
705
|
-
|
|
706
800
|
def warmup(self, *args, grid, **kwargs):
|
|
707
801
|
return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs)
|
|
708
802
|
|
|
709
803
|
def preload(self, specialization_data):
|
|
710
|
-
from ..compiler import compile, ASTSource
|
|
711
804
|
import json
|
|
712
805
|
import triton.language as tl
|
|
713
806
|
device = driver.active.get_current_device()
|
|
@@ -717,7 +810,7 @@ class JITFunction(KernelInterface[T]):
|
|
|
717
810
|
f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self._fn_name}")
|
|
718
811
|
constant_keys = map(tuple, deserialized_obj['constant_keys'])
|
|
719
812
|
constant_vals = deserialized_obj['constant_vals']
|
|
720
|
-
|
|
813
|
+
constexprs = {
|
|
721
814
|
key: tl.dtype(value) if tl.dtype.is_dtype(value) else value
|
|
722
815
|
for key, value in zip(constant_keys, constant_vals)
|
|
723
816
|
}
|
|
@@ -725,44 +818,54 @@ class JITFunction(KernelInterface[T]):
|
|
|
725
818
|
attrs_vals = deserialized_obj['attrs_vals']
|
|
726
819
|
attrs = dict(zip(attrs_keys, attrs_vals))
|
|
727
820
|
signature = dict(deserialized_obj['signature'].items())
|
|
728
|
-
src = ASTSource(self, signature, constants, attrs)
|
|
729
821
|
options = {
|
|
730
822
|
key: tuple(value) if isinstance(value, list) else value
|
|
731
823
|
for key, value in deserialized_obj['options'].items()
|
|
732
824
|
}
|
|
733
825
|
key = deserialized_obj['key']
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
return
|
|
826
|
+
_, _, _, backend, _ = self.device_caches[device]
|
|
827
|
+
options = backend.parse_options(options)
|
|
828
|
+
return self._do_compile(
|
|
829
|
+
key,
|
|
830
|
+
signature,
|
|
831
|
+
device,
|
|
832
|
+
constexprs,
|
|
833
|
+
options,
|
|
834
|
+
attrs,
|
|
835
|
+
warmup=True,
|
|
836
|
+
)
|
|
737
837
|
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
# Our unit tests do this, for example.
|
|
741
|
-
def parse(self):
|
|
742
|
-
tree = ast.parse(self.src)
|
|
743
|
-
assert isinstance(tree, ast.Module)
|
|
744
|
-
assert len(tree.body) == 1
|
|
745
|
-
assert isinstance(tree.body[0], ast.FunctionDef)
|
|
746
|
-
return tree
|
|
838
|
+
def _do_compile(self, key, signature, device, constexprs, options, attrs, warmup):
|
|
839
|
+
kernel_cache, _, target, backend, _ = self.device_caches[device]
|
|
747
840
|
|
|
748
|
-
|
|
749
|
-
|
|
841
|
+
if self._call_hook(knobs.runtime.jit_cache_hook, key, signature, device, constexprs, options, [attrs], warmup):
|
|
842
|
+
return None
|
|
843
|
+
src = self.ASTSource(self, signature, constexprs, attrs)
|
|
750
844
|
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
if name == "src":
|
|
754
|
-
raise AttributeError(f"Cannot set attribute '{name}' directly. "
|
|
755
|
-
f"Use '_unsafe_update_src()' and manually clear `.hash` of all callers"
|
|
756
|
-
f"instead.")
|
|
757
|
-
super(JITFunction, self).__setattr__(name, value)
|
|
845
|
+
async_mode = _async_compile.active_mode.get()
|
|
846
|
+
if async_mode is not None:
|
|
758
847
|
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
848
|
+
env_vars = get_cache_invalidating_env_vars()
|
|
849
|
+
cache_key = get_cache_key(src, backend, options, env_vars)
|
|
850
|
+
|
|
851
|
+
def async_compile():
|
|
852
|
+
return self.compile(src, target=target, options=options.__dict__, _env_vars=env_vars)
|
|
853
|
+
|
|
854
|
+
def finalize_compile(kernel):
|
|
855
|
+
kernel_cache[key] = kernel
|
|
856
|
+
self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, device, constexprs, options,
|
|
857
|
+
[attrs], warmup)
|
|
858
|
+
|
|
859
|
+
kernel = async_mode.submit(cache_key, async_compile, finalize_compile)
|
|
860
|
+
else:
|
|
861
|
+
kernel = self.compile(src, target=target, options=options.__dict__)
|
|
862
|
+
kernel_cache[key] = kernel
|
|
863
|
+
self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, device, constexprs, options, [attrs],
|
|
864
|
+
warmup)
|
|
865
|
+
return kernel
|
|
866
|
+
|
|
867
|
+
def __call__(self, *args, **kwargs):
|
|
868
|
+
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
|
|
766
869
|
|
|
767
870
|
def __repr__(self):
|
|
768
871
|
return f"JITFunction({self.module}:{self.fn.__qualname__})"
|
|
@@ -864,8 +967,17 @@ class MockTensor:
|
|
|
864
967
|
return MockTensor(arg)
|
|
865
968
|
return arg
|
|
866
969
|
|
|
867
|
-
def __init__(self, dtype):
|
|
970
|
+
def __init__(self, dtype, shape=None):
|
|
971
|
+
if shape is None:
|
|
972
|
+
shape = [1]
|
|
868
973
|
self.dtype = dtype
|
|
974
|
+
self.shape = shape
|
|
975
|
+
|
|
976
|
+
def stride(self):
|
|
977
|
+
strides = [1]
|
|
978
|
+
for size in self.shape[1:]:
|
|
979
|
+
strides.append(strides[-1] * size)
|
|
980
|
+
return tuple(reversed(strides))
|
|
869
981
|
|
|
870
982
|
@staticmethod
|
|
871
983
|
def data_ptr():
|
|
@@ -930,17 +1042,66 @@ def reinterpret(tensor, dtype):
|
|
|
930
1042
|
|
|
931
1043
|
def get_jit_fn_file_line(fn):
|
|
932
1044
|
base_fn = fn
|
|
933
|
-
while not isinstance(base_fn,
|
|
1045
|
+
while not isinstance(base_fn, JITCallable):
|
|
934
1046
|
base_fn = base_fn.fn
|
|
935
1047
|
file_name = base_fn.fn.__code__.co_filename
|
|
936
|
-
|
|
1048
|
+
begin_line = base_fn.starting_line_number
|
|
937
1049
|
# Match the following pattern:
|
|
938
1050
|
# @triton.autotune(...) <- foo.__code__.co_firstlineno
|
|
939
1051
|
# @triton.heuristics(...)
|
|
940
1052
|
# @triton.jit
|
|
941
1053
|
# def foo(...): <- this line is the first line
|
|
942
|
-
for idx, line in enumerate(
|
|
1054
|
+
for idx, line in enumerate(base_fn.raw_src):
|
|
943
1055
|
if line.strip().startswith("def "):
|
|
944
1056
|
begin_line += idx
|
|
945
1057
|
break
|
|
946
1058
|
return file_name, begin_line
|
|
1059
|
+
|
|
1060
|
+
|
|
1061
|
+
class BoundConstexprFunction(JITCallable):
|
|
1062
|
+
|
|
1063
|
+
def __init__(self, instance, fn):
|
|
1064
|
+
self.__self__ = instance
|
|
1065
|
+
self.__func__ = fn
|
|
1066
|
+
|
|
1067
|
+
def __call__(self, *args, **kwargs):
|
|
1068
|
+
return self.__func__(self.__self__, *args, **kwargs)
|
|
1069
|
+
|
|
1070
|
+
|
|
1071
|
+
class ConstexprFunction(JITCallable):
|
|
1072
|
+
|
|
1073
|
+
def __init__(self, fn):
|
|
1074
|
+
super().__init__(fn)
|
|
1075
|
+
|
|
1076
|
+
def __get__(self, obj, objclass):
|
|
1077
|
+
# Create a bound function to support constexpr_function methods
|
|
1078
|
+
if obj is not None:
|
|
1079
|
+
return BoundConstexprFunction(obj, self)
|
|
1080
|
+
return self
|
|
1081
|
+
|
|
1082
|
+
def __call__(self, *args, _semantic=None, **kwargs):
|
|
1083
|
+
from triton.language.core import _unwrap_if_constexpr, constexpr
|
|
1084
|
+
# de-constexpr arguments and discard the _semantic keyword argument:
|
|
1085
|
+
args = [_unwrap_if_constexpr(x) for x in args]
|
|
1086
|
+
kwargs = {k: _unwrap_if_constexpr(v) for (k, v) in kwargs.items()}
|
|
1087
|
+
|
|
1088
|
+
# call the raw Python function f:
|
|
1089
|
+
res = self.fn(*args, **kwargs)
|
|
1090
|
+
|
|
1091
|
+
if _semantic is None:
|
|
1092
|
+
# Not called by triton code generator, e.g. in host code, another constexpr function, or even an aggreate's __init__ function
|
|
1093
|
+
return res
|
|
1094
|
+
|
|
1095
|
+
# convert result back to a Triton constexpr:
|
|
1096
|
+
if knobs.runtime.interpret:
|
|
1097
|
+
return res # No constexpr in interpreter
|
|
1098
|
+
return constexpr(res)
|
|
1099
|
+
|
|
1100
|
+
|
|
1101
|
+
def constexpr_function(fn):
|
|
1102
|
+
"""
|
|
1103
|
+
Wraps an arbitrary Python function so that it can be called at
|
|
1104
|
+
compile-time on constexpr arguments in a Triton function and
|
|
1105
|
+
returns a constexpr result.
|
|
1106
|
+
"""
|
|
1107
|
+
return ConstexprFunction(fn)
|