triton-windows 3.2.0.post11__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic; see the registry's advisory page for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +85 -0
- triton/_internal_testing.py +123 -0
- triton/backends/__init__.py +50 -0
- triton/backends/amd/compiler.py +368 -0
- triton/backends/amd/driver.c +211 -0
- triton/backends/amd/driver.py +512 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
- triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
- triton/backends/amd/include/hip/channel_descriptor.h +39 -0
- triton/backends/amd/include/hip/device_functions.h +38 -0
- triton/backends/amd/include/hip/driver_types.h +468 -0
- triton/backends/amd/include/hip/hip_bf16.h +36 -0
- triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
- triton/backends/amd/include/hip/hip_common.h +100 -0
- triton/backends/amd/include/hip/hip_complex.h +38 -0
- triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
- triton/backends/amd/include/hip/hip_deprecated.h +95 -0
- triton/backends/amd/include/hip/hip_ext.h +159 -0
- triton/backends/amd/include/hip/hip_fp16.h +36 -0
- triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
- triton/backends/amd/include/hip/hip_hcc.h +24 -0
- triton/backends/amd/include/hip/hip_math_constants.h +36 -0
- triton/backends/amd/include/hip/hip_profile.h +27 -0
- triton/backends/amd/include/hip/hip_runtime.h +75 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
- triton/backends/amd/include/hip/hip_texture_types.h +29 -0
- triton/backends/amd/include/hip/hip_vector_types.h +41 -0
- triton/backends/amd/include/hip/hip_version.h +17 -0
- triton/backends/amd/include/hip/hiprtc.h +421 -0
- triton/backends/amd/include/hip/library_types.h +78 -0
- triton/backends/amd/include/hip/math_functions.h +42 -0
- triton/backends/amd/include/hip/surface_types.h +63 -0
- triton/backends/amd/include/hip/texture_types.h +194 -0
- triton/backends/amd/include/hsa/Brig.h +1131 -0
- triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
- triton/backends/amd/include/hsa/amd_hsa_elf.h +436 -0
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
- triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
- triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
- triton/backends/amd/include/hsa/hsa.h +5729 -0
- triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
- triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
- triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
- triton/backends/amd/include/roctracer/roctracer.h +779 -0
- triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
- triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
- triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
- triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
- triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
- triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
- triton/backends/amd/include/roctracer/roctx.h +229 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +304 -0
- triton/backends/driver.py +48 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +410 -0
- triton/backends/nvidia/driver.c +451 -0
- triton/backends/nvidia/driver.py +524 -0
- triton/backends/nvidia/include/cuda.h +24359 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +4 -0
- triton/compiler/code_generator.py +1303 -0
- triton/compiler/compiler.py +430 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/language/__init__.py +294 -0
- triton/language/_utils.py +21 -0
- triton/language/core.py +2694 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +13 -0
- triton/language/extra/cuda/_experimental_tma.py +108 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +3 -0
- triton/language/extra/hip/libdevice.py +475 -0
- triton/language/extra/libdevice.py +786 -0
- triton/language/math.py +250 -0
- triton/language/random.py +207 -0
- triton/language/semantic.py +1796 -0
- triton/language/standard.py +452 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/autotuner.py +408 -0
- triton/runtime/build.py +111 -0
- triton/runtime/cache.py +295 -0
- triton/runtime/driver.py +60 -0
- triton/runtime/errors.py +26 -0
- triton/runtime/interpreter.py +1235 -0
- triton/runtime/jit.py +951 -0
- triton/testing.py +511 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.c +67 -0
- triton/tools/compile.h +14 -0
- triton/tools/compile.py +155 -0
- triton/tools/disasm.py +144 -0
- triton/tools/experimental_descriptor.py +32 -0
- triton/tools/link.py +322 -0
- triton/windows_utils.py +375 -0
- triton_windows-3.2.0.post11.dist-info/METADATA +39 -0
- triton_windows-3.2.0.post11.dist-info/RECORD +154 -0
- triton_windows-3.2.0.post11.dist-info/WHEEL +5 -0
- triton_windows-3.2.0.post11.dist-info/top_level.txt +12 -0
triton/runtime/jit.py
ADDED
|
@@ -0,0 +1,951 @@
|
|
|
1
|
+
from __future__ import annotations, division
|
|
2
|
+
import ast
|
|
3
|
+
import hashlib
|
|
4
|
+
import inspect
|
|
5
|
+
import itertools
|
|
6
|
+
import os
|
|
7
|
+
import re
|
|
8
|
+
import textwrap
|
|
9
|
+
from collections import defaultdict
|
|
10
|
+
from functools import cached_property
|
|
11
|
+
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple
|
|
12
|
+
from ..runtime.driver import driver
|
|
13
|
+
from types import ModuleType
|
|
14
|
+
|
|
15
|
+
TRITON_MODULE = __name__[:-len(".runtime.jit")]
|
|
16
|
+
|
|
17
|
+
T = TypeVar("T")
|
|
18
|
+
|
|
19
|
+
# -----------------------------------------------------------------------------
|
|
20
|
+
# Dependencies Finder
|
|
21
|
+
# -----------------------------------------------------------------------------
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class DependenciesFinder(ast.NodeVisitor):
    """
    This AST visitor is used to find dependencies of a JITFunction. This can
    be used to invalidate a JITFunction's hash when its source code -- or
    that of its dependencies -- changes.

    This visitor also keeps track of the global variables touched by the
    JITFunction. When we launch the kernel, we check that these have the same
    values as they did when we ran this visitor. If not, we raise an error (or
    otherwise we could recompile).
    """

    def __init__(self, name, globals, src) -> None:
        super().__init__()
        # Name of the function being hashed; only used in error messages.
        self.name = name
        # Running hash, seeded with this function's own source text; every
        # dependency discovered during the walk is folded into it.
        self.hasher = hashlib.sha256(src.encode("utf-8"))

        # This function's __globals__ dict.
        self.globals = globals

        # Python builtins that can be accessed from Triton kernels.
        self.supported_python_builtins = {
            'float',
            'getattr',
            'int',
            'isinstance',
            'len',
            'list',
            'max',
            'min',
            'print',
            'range',
        }

        # used_global_vals tells us which global variables are used by this
        # function and all those it transitively calls, plus the values of those
        # variables when each function was initially run. (That is, if A calls
        # C, and B calls C, then the values for C in used_global_vals will be
        # from the first time C was run, either by A or B.)
        #
        # Each function may have a different __globals__ dict, so the global
        # variable `foo` may actually have a different value in the different
        # functions. Thus this map is actually
        # (var_name, id(__globals__)) -> (var_value, __globals__).
        self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}

        # True only while visiting a default-value expression of a function
        # argument; visit_Name skips recording globals in that state (see
        # visit_arguments for why defaults are special).
        self.visiting_arg_default_value = False

    @property
    def ret(self):
        """Hex digest of everything hashed so far (used as the cache key)."""
        return self.hasher.hexdigest()

    def _is_triton_builtin(self, node, func):
        """Return True if `func` is a Python builtin or is defined inside the triton package."""
        if inspect.isbuiltin(node.func):
            return True
        module = getattr(func, "__module__", "")
        return module.startswith(TRITON_MODULE)

    def _update_hash(self, func):
        """If `func` is a JITFunction, merge its tracked globals into ours and fold its cache key into the hash."""
        if isinstance(func, JITFunction):
            # Merge our used_global_vals with those of the called function,
            # after checking that all overlapping values are consistent.
            for k in self.used_global_vals.keys() & func.used_global_vals.keys():
                var_name, _ = k
                v1, _ = self.used_global_vals[k]
                v2, _ = func.used_global_vals[k]
                if v1 != v2:
                    raise RuntimeError(
                        f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed."
                    )
            self.used_global_vals.update(func.used_global_vals)
            # update hash
            func_key = func.cache_key
            func_key += str(getattr(func, "noinline", False))
            self.hasher.update(func_key.encode("utf-8"))

    def visit_Name(self, node):
        """Resolve a bare name against __globals__, recording it as a tracked global when appropriate."""
        if type(node.ctx) is ast.Store:
            return node.id

        if node.id in self.local_names:
            # The global name is hidden by the local name.
            return None

        val = self.globals.get(node.id, None)

        # Only keep track of "interesting" global variables, that non-evil users
        # might change. Don't consider functions, modules, builtins, etc. This
        # helps keep the list of vars we have to check small.
        if (val is not None  #
                # Python default arguments are resolved only once, when the
                # function is defined. So if you do `foo(a=A)` and the value of
                # A changes, foo will still use the old value of A.
                and not self.visiting_arg_default_value
                # It would be pretty evil if someone did `import x` and then
                # `x = blah`.
                and type(val) is not ModuleType
                # It would be pretty evil if we used function `foo` inside of
                # `bar` and then someone did `foo = baz`.
                and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False)  #
                and node.id not in self.supported_python_builtins):
            self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals)

        self._update_hash(val)
        return val

    def visit_Tuple(self, node):
        # We need to explicitly return the tuple values so that visit_Assign can
        # access them in the case of `a, b = ...`.
        return [self.visit(elt) for elt in node.elts]

    def visit_Attribute(self, node):
        """Resolve a dotted attribute chain (e.g. `mod.sub.fn`) and hash the resolved object."""
        lhs = self.visit(node.value)
        while isinstance(lhs, ast.Attribute):
            lhs = self.visit(lhs.value)
        # Attributes reached through the triton module itself are builtins and
        # need no tracking.
        if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE):
            return None
        ret = getattr(lhs, node.attr)
        self._update_hash(ret)
        return ret

    def visit_FunctionDef(self, node):
        # Save the local name, which may hide the global name.
        self.local_names = {arg.arg for arg in node.args.args}
        self.generic_visit(node)

    def visit_arguments(self, node):
        # The purpose of this function is to visit everything in `arguments`
        # just like `generic_visit`, except when we're visiting default values
        # (i.e. the `foo` part of `def fn(x = foo)`), we set
        # self.visiting_arg_default_value = True. This allows visit_Name to be
        # aware that we're inside function default values, which have special
        # semantics.

        # According to the AST docs, the arguments node has the following structure.
        #
        # arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs,
        #              expr* kw_defaults, arg? kwarg, expr* defaults)
        def visit_defaults(defaults):
            # Flip the flag only around default-value expressions; `finally`
            # guarantees it is reset even if visiting raises.
            try:
                assert not self.visiting_arg_default_value
                self.visiting_arg_default_value = True
                for expr in defaults:
                    if expr is not None:
                        self.visit(expr)
            finally:
                self.visiting_arg_default_value = False

        for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs):
            self.visit(arg)

        visit_defaults(node.kw_defaults)

        if node.kwarg is not None:
            self.visit(node.kwarg)

        visit_defaults(node.defaults)

    def visitAssnTarget(self, node):
        # Target is either a single string, or a list of strings (if the assn
        # target is a tuple).
        target = self.visit(node)
        if isinstance(target, list):
            self.local_names |= set(target)
        else:
            self.local_names.add(target)

    def visit_Assign(self, node):
        if len(node.targets) != 1:
            # TODO(jlebar): I don't actually know how to hit this. You don't
            # get it from `a, b = ...` -- in that case, node.targets is a single
            # Tuple, and in fact we *do* need to handle that case if we want
            # existing code to work.
            raise TypeError("Simultaneous multiple assignment is not supported.")

        self.visitAssnTarget(node.targets[0])

        # This will re-visit the target, but that's OK.
        self.generic_visit(node)

    def visit_AnnAssign(self, node):
        self.visitAssnTarget(node.target)

        # This will re-visit the target, but that's OK.
        self.generic_visit(node)

    def visit_For(self, node):
        self.visitAssnTarget(node.target)

        # This will re-visit the target, but that's fine.
        self.generic_visit(node)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
# -----------------------------------------------------------------------------
|
|
218
|
+
# JITFunction
|
|
219
|
+
# -----------------------------------------------------------------------------
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def _normalize_ty(ty) -> str:
|
|
223
|
+
if isinstance(ty, type):
|
|
224
|
+
return ty.__name__
|
|
225
|
+
elif isinstance(ty, str):
|
|
226
|
+
return ty
|
|
227
|
+
return repr(ty)
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class KernelParam:
    """Represents a parameter (name plus metadata) to a @jit'ed function."""

    def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool,
                 do_not_specialize_on_alignment: bool):
        # Positional index of this parameter in the kernel signature.
        self.num = num
        # Underlying inspect.Parameter; all derived properties read from it.
        self._param = param
        self.do_not_specialize = do_not_specialize
        self.do_not_specialize_on_alignment = do_not_specialize_on_alignment

    @cached_property
    def name(self):
        """The parameter's name as declared in the kernel signature."""
        return self._param.name

    @cached_property
    def annotation(self):
        """Normalized string form of the annotation, or "" when absent."""
        raw = self._param.annotation
        if not raw or raw == inspect.Parameter.empty:
            return ""
        return _normalize_ty(raw)

    @cached_property
    def annotation_type(self):
        """Triton signature type ("i32", "u64", "u1", ...) derived from the
        annotation text, or "" when the annotation carries no width info."""
        ann = self.annotation
        for spelled, short in (("uint", 'u'), ("int", 'i')):
            # Take everything after the (first) occurrence of the type word as
            # the bit width, e.g. "uint32" -> "32".
            suffix = ann[ann.find(spelled) + len(spelled):]
            if suffix and spelled in ann:
                return f"{short}{suffix}"
        return "u1" if ann == "bool" else ""

    @cached_property
    def is_constexpr(self):
        """True when the annotation marks this parameter as a compile-time constant."""
        return "constexpr" in self.annotation

    @cached_property
    def is_const(self):
        """True for pointer-to-const parameters (const, but not constexpr)."""
        return "const" in self.annotation and not self.is_constexpr

    @property
    def default(self):
        """The declared default value (inspect.Parameter.empty when absent)."""
        return self._param.default

    @property
    def has_default(self):
        """Whether the parameter declares a default value."""
        return self._param.default != inspect.Parameter.empty
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def compute_spec_key(v, align):
    """Return the specialization key for one kernel argument.

    "D" marks a 16-byte-divisible value (pointer or int) when alignment
    specialization is enabled, "1" marks the literal integer 1, and "N"
    means no specialization applies.
    """
    divisible = False
    if align:
        if hasattr(v, "data_ptr"):
            divisible = v.data_ptr() % 16 == 0
        elif isinstance(v, int):
            divisible = v % 16 == 0
    if divisible:
        return "D"
    # bool is a subclass of int, so this also covers True/False.
    if isinstance(v, int) and v == 1:
        return "1"
    return "N"
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
# Memoized (dtype, is_const) -> signature-string mapping; dtypes are hashable
# so this is a safe cache shared by all kernels.
dtype2str = {}


def mangle_type(arg, is_const=False):
    """Return the Triton signature string for a runtime kernel argument.

    Scalars map to "i1"/"i32"/"u64"/"i64"/"fp32", None to "none", TMA
    descriptors to "nvTmaDesc", and anything with a `.dtype` to a pointer
    type ("*ty", or "*kty" when const).
    """
    if arg is None:
        return "none"
    if isinstance(arg, bool):
        return "i1"
    if isinstance(arg, int):
        if -(2**31) <= arg <= 2**31 - 1:
            return "i32"
        if 2**63 <= arg <= 2**64 - 1:
            return "u64"
        return "i64"
    if isinstance(arg, float):
        return "fp32"
    if hasattr(arg, "tma_desc_cpu_ptr"):
        return "nvTmaDesc"
    # dtypes are hashable so we can memoize this mapping:
    dsk = (arg.dtype, is_const)
    cached = dtype2str.get(dsk, None)
    if cached is None:
        prefix = "*k" if dsk[1] else "*"
        cached = prefix + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]]
        dtype2str[dsk] = cached
    return cached
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
class KernelInterface(Generic[T]):
    """Mixin giving kernels the `fn[grid](*args)` launch syntax."""

    run: T

    def __getitem__(self, grid) -> T:
        """
        A JIT function is launched with: fn[grid](*args, **kwargs).
        Hence JITFunction.__getitem__ returns a callable proxy that
        memorizes the grid.
        """

        def launcher(*args, **kwargs):
            return self.run(grid=grid, warmup=False, *args, **kwargs)

        return launcher
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def serialize_specialization_data(name, signature, constants, attrs, options, key):
    """Serialize one kernel specialization to a JSON string.

    dtype-valued constants are stringified (dtype objects are not JSON
    serializable); everything else is passed through as-is.
    """
    import json

    normalized_constants = {}
    for ckey, cval in constants.items():
        normalized_constants[ckey] = str(cval) if cval.__class__.__name__ == "dtype" else cval

    payload = {
        'name': name,
        'signature': signature,
        'constants': normalized_constants,
        'attrs': attrs.to_dict(),
        'options': options.__dict__,
        'key': key,
    }
    return json.dumps(payload)
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def create_function_from_signature(sig, kparams, backend):
    """
    Equivalent to sig.bind followed by apply_defaults. This generates a
    native Python function (using exec) which can be memoized on a per-kernel
    basis to avoid having to run these expensive functions -- which constitute
    much of the kernel launch overhead -- every time we run the kernel.

    The generated function accepts the kernel's own arguments and returns a
    5-tuple: (bound-args dict, signature/specialization key strings,
    constexpr values, non-constexpr values, excess kwargs).
    """

    # One KernelParam per signature parameter is required for the zip below.
    assert len(sig.parameters) == len(kparams)

    # Create the function argument list and the dict entries for the return statement
    func_args = []
    dict_entries = []
    constexpr_vals = []
    non_constexpr_vals = []
    signature_types = []
    specialisations = []

    for ((name, sp), kp) in zip(sig.parameters.items(), kparams):
        if sp.default is inspect.Parameter.empty:
            func_args.append(name)
            dict_entries.append(f"'{name}': {name}")
        else:
            # Defaults are baked into the generated function as
            # `name=default_name`, resolved from func_namespace below.
            func_args.append(f"{name}=default_{name}")
            dict_entries.append(f"'{name}': {name}")
        if kp.is_constexpr:
            constexpr_vals.append(name)
        else:
            non_constexpr_vals.append(name)
        if not kp.do_not_specialize:
            # Emit a runtime call computing this argument's specialization key.
            if not kp.do_not_specialize_on_alignment:
                specialisations.append('compute_spec_key(%s, align=True)' % name)
            else:
                specialisations.append('compute_spec_key(%s, align=False)' % name)
        if kp.annotation_type:
            # Type known statically from the annotation: emit it as a literal.
            signature_types.append('"%s"' % kp.annotation_type)
        else:
            # Type must be derived from the runtime value via mangle_type.
            signature_types.append('mangle_type(%s, %s)' % (name, 'True' if kp.is_const else 'False'))

    cache_key = ''.join([x + ', ' for x in signature_types + specialisations])
    constexpr_vals = ''.join([x + ', ' for x in constexpr_vals])
    non_constexpr_vals = ''.join([x + ', ' for x in non_constexpr_vals])

    # Swallow unrecognized keyword args (compiler options etc.) so the
    # generated function never raises TypeError on them.
    func_args.append('**excess_kwargs')

    # Join all arguments into a function definition string
    args_str = ', '.join(func_args)
    dict_str = ', '.join(dict_entries)
    func_body = "def dynamic_func(%s):\n return {%s}, (%s), (%s), (%s), excess_kwargs" % (
        args_str, dict_str, cache_key, constexpr_vals, non_constexpr_vals)

    # Prepare defaults to be inserted into function namespace
    func_namespace = {
        f"default_{name}": param.default
        for name, param in sig.parameters.items()
        if param.default is not inspect.Parameter.empty
    }

    # The generated code references these two helpers at call time.
    func_namespace['mangle_type'] = mangle_type
    func_namespace['compute_spec_key'] = backend.compute_spec_key

    # Execute the function string in func_namespace to create the function
    exec(func_body, func_namespace)

    # Extract the newly created function from the namespace
    return func_namespace['dynamic_func']
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
# Maps dtype spellings (framework-style names such as "float8_e4m3fn" and
# Triton-style names such as "float8e4nv") to the canonical short signature
# form used in kernel cache keys and type mangling.
type_canonicalisation_dict = {
    "bool": "i1",
    "float8e4nv": "fp8e4nv",
    "float8e5": "fp8e5",
    "float8e4b15": "fp8e4b15",
    "float8_e4m3fn": "fp8e4nv",
    "float8e4b8": "fp8e4b8",
    "float8_e4m3fnuz": "fp8e4b8",
    "float8_e5m2": "fp8e5",
    "float8e5b16": "fp8e5b16",
    "float8_e5m2fnuz": "fp8e5b16",
    "float16": "fp16",
    "bfloat16": "bf16",
    "float32": "fp32",
    "float64": "fp64",
    "int8": "i8",
    "int16": "i16",
    "int32": "i32",
    "int64": "i64",
    "uint8": "u8",
    "uint16": "u16",
    "uint32": "u32",
    "uint64": "u64",
}

# Make every canonical short form map to itself so already-canonical names
# pass through unchanged.  list() snapshots the values, since we mutate the
# dict while iterating.
for v in list(type_canonicalisation_dict.values()):
    type_canonicalisation_dict[v] = v
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
class JITFunction(KernelInterface[T]):
|
|
443
|
+
# Hook for inspecting compiled functions and modules
|
|
444
|
+
cache_hook = None
|
|
445
|
+
# Hook to signal that a kernel is done compiling and inspect compiled function.
|
|
446
|
+
# cache_hook will always be called before compilation and compiled_hook after.
|
|
447
|
+
compiled_hook = None
|
|
448
|
+
|
|
449
|
+
@staticmethod
|
|
450
|
+
def _key_of(arg):
|
|
451
|
+
if hasattr(arg, "dtype"):
|
|
452
|
+
return arg.dtype
|
|
453
|
+
elif isinstance(arg, bool):
|
|
454
|
+
return "i1"
|
|
455
|
+
elif isinstance(arg, int):
|
|
456
|
+
if -(2**31) <= arg and arg <= 2**31 - 1:
|
|
457
|
+
return "i32"
|
|
458
|
+
elif 2**63 <= arg and arg <= 2**64 - 1:
|
|
459
|
+
return "u64"
|
|
460
|
+
else:
|
|
461
|
+
return "i64"
|
|
462
|
+
elif isinstance(arg, float):
|
|
463
|
+
return "fp32"
|
|
464
|
+
elif arg is None:
|
|
465
|
+
return None
|
|
466
|
+
else:
|
|
467
|
+
raise TypeError(f"Unsupported type {type(arg)} for {arg}")
|
|
468
|
+
|
|
469
|
+
@staticmethod
|
|
470
|
+
def _type_of(key, is_const=False):
|
|
471
|
+
# `None` is nullptr. Implicitly convert to *i8.
|
|
472
|
+
if key is None:
|
|
473
|
+
return "*i8"
|
|
474
|
+
elif isinstance(key, str):
|
|
475
|
+
return key
|
|
476
|
+
|
|
477
|
+
dtype_str = str(key).split(".")[-1]
|
|
478
|
+
dtype_str = type_canonicalisation_dict[dtype_str]
|
|
479
|
+
const_str = "*k" if is_const else "*"
|
|
480
|
+
return const_str + dtype_str
|
|
481
|
+
|
|
482
|
+
def _make_constants(self, constexpr_key):
|
|
483
|
+
constants = dict(zip(self.constexprs, constexpr_key))
|
|
484
|
+
return constants
|
|
485
|
+
|
|
486
|
+
    def _call_hook(
        self,
        key,
        signature,
        device,
        constants,
        options,
        configs,
        is_warmup,
        before,
    ):
        """Invoke the class-level compilation hook, if one is installed.

        `before` selects the hook: True -> JITFunction.cache_hook (called
        before compilation), False -> JITFunction.compiled_hook (called
        after).  Returns the hook's return value, or False when no hook is
        installed.
        """
        hook = JITFunction.cache_hook if before else JITFunction.compiled_hook
        if hook is None:
            return False

        name = self.fn.__name__
        module = self.fn.__module__
        # Human-readable description of this specialization, e.g.
        # "kernel[num_warps=4, ...](x_ptr: *fp32, n: i32)".
        arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
        repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})"

        # Lightweight record handed to the hook so it can identify the kernel
        # without reaching into JITFunction internals.
        class JitFunctionInfo:

            def __init__(self, module, name, jit_function):
                self.module = module
                self.name = name
                self.jit_function = jit_function
                pass

        specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key)

        kwargs = {
            'signature': signature,
            'device': device,
            'constants': constants,
            'num_warps': options.num_warps,
            'num_ctas': options.num_ctas,
            'num_stages': options.num_stages,
            'enable_fp_fusion': options.enable_fp_fusion,
            'extern_libs': options.extern_libs,
            'configs': configs,
            'specialization_data': specialization_data,
            'is_warmup': is_warmup,
        }

        return hook(
            key=key,
            repr=repr,
            fn=JitFunctionInfo(module, name, self),
            compile={"key": key, **kwargs},
            is_manual_warmup=is_warmup,
            already_compiled=False,
        )
|
|
538
|
+
|
|
539
|
+
def add_pre_run_hook(self, hook):
|
|
540
|
+
'''
|
|
541
|
+
Add a hook that will be executed prior to the execution of run
|
|
542
|
+
function with args and kwargs passed into the kernel
|
|
543
|
+
'''
|
|
544
|
+
assert callable(hook)
|
|
545
|
+
self.pre_run_hooks.append(hook)
|
|
546
|
+
|
|
547
|
+
    def create_binder(self, backend):
        """
        Precompute as much as possible.

        Caches compiler entry points and builds the per-kernel argument
        binder (see create_function_from_signature) plus parameter index
        lists, so `run` pays none of this cost on subsequent launches.
        """
        from ..compiler import CompiledKernel, compile, ASTSource, make_backend
        self.CompiledKernel = CompiledKernel
        self.compile = compile
        self.ASTSource = ASTSource
        self.make_backend = make_backend
        # Generated function mapping (*args, **kwargs) -> bound args, cache-key
        # fragments, constexpr values, non-constexpr values, excess kwargs.
        self.binder = create_function_from_signature(self.signature, self.params, backend)
        self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr]
        self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr]
        self.specialised_indices = [
            i for (i, p) in enumerate(self.params) if (not p.do_not_specialize) and (not p.is_constexpr)
        ]
|
|
562
|
+
|
|
563
|
+
def run(self, *args, grid, warmup, **kwargs):
    """
    Bind arguments, compile on cache miss, and (unless warming up) launch.

    :param grid: launch grid: a tuple of up to three ints, or a callable that
        maps the dict of bound arguments to such a tuple.
    :param warmup: when True, stop after compiling/caching; do not launch.
    :return: the compiled kernel, or None when a registered compilation hook
        requests a skip (`_call_hook(..., before=True)` returns truthy).
    """
    # TRITON_DEBUG=1 in the environment force-enables debug for every launch.
    kwargs["debug"] = kwargs.get("debug", False) or os.environ.get("TRITON_DEBUG", "0") == "1"

    # parse options
    from ..compiler import make_backend
    device = driver.active.get_current_device()
    stream = driver.active.get_current_stream(device)
    target = driver.active.get_current_target()
    backend = make_backend(target)

    # Execute pre run hooks with args and kwargs
    for hook in self.pre_run_hooks:
        hook(*args, **kwargs)

    # The binder is built lazily on first use (see create_binder).
    if self.binder is None:
        self.create_binder(backend)

    bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs)

    # compute cache key
    key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs))
    kernel = self.cache[device].get(key, None)

    if kernel is None:
        # Kernel is not cached; we have to compile.
        options = backend.parse_options(kwargs)

        # deprecated arguments
        assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used"
        assert "device" not in kwargs, "device option is deprecated; current device will be used"
        assert "stream" not in kwargs, "stream option is deprecated; current stream will be used"
        # Any leftover kwarg must at least name a known backend option.
        for k in excess_kwargs:
            if k not in options.__dict__:
                raise KeyError("Keyword argument %s was specified but unrecognised" % k)

        bound_vals = tuple(bound_args.values())

        # `None` is nullptr. Implicitly convert to *i8. This needs to be
        # done here rather than when we build the signature as otherwise
        # the kernel cache key could not distinguish between byte pointers
        # and None arguments, resulting in a downstream mismatch:
        sigkeys = [self.params[i].name for i in self.non_constexpr_indices]
        sigvals = sig_and_spec[:len(sigkeys)]
        signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)}

        configs = (backend.get_attrs_descriptor(self.params, bound_vals), )
        constant_params = configs[0].get_constants()
        constants = {
            p.name: v
            for (v, p) in zip(bound_vals, self.params)
            if p.is_constexpr or (p.num in constant_params) or v is None
        }
        # NOTE(review): `constants` is keyed by parameter *name*, so `i` below
        # is a name and the "at index" wording in the message is misleading.
        for i, arg in constants.items():
            if callable(arg):
                raise TypeError(f"Callable constexpr at index {i} is not supported")

        # Give registered hooks a chance to veto compilation entirely.
        if self._call_hook(key, signature, device, constants, options, configs, warmup, before=True):
            return None
        # compile the kernel
        src = self.ASTSource(self, signature, constants, configs[0])
        kernel = self.compile(
            src,
            target=target,
            options=options.__dict__,
        )
        self.cache[device][key] = kernel
        self._call_hook(key, signature, device, constants, options, configs, warmup, before=False)

    # Check that used global values have not changed.
    not_present = object()
    for (name, _), (val, globals_dict) in self.used_global_vals.items():
        if (newVal := globals_dict.get(name, not_present)) != val:
            raise RuntimeError(
                f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}")

    if not warmup:
        # canonicalize grid
        assert grid is not None
        if callable(grid):
            # Arguments are passed as a dict to `grid`, by contract.
            # TODO(jlebar): In the new launch API, pass the compiler flags as a
            # second parameter to `grid`.
            grid = grid(bound_args)
        # Missing trailing grid dimensions default to 1.
        grid_size = len(grid)
        grid_0 = grid[0]
        grid_1 = grid[1] if grid_size > 1 else 1
        grid_2 = grid[2] if grid_size > 2 else 1

        # launch kernel
        launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals)
        kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
                   self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals)
    return kernel
|
|
656
|
+
|
|
657
|
+
def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None,
             noinline=None, repr=None, launch_metadata=None):
    """
    Wrap a plain Python function as a JIT-compilable kernel.

    :param fn: the function to wrap; its source is captured for compilation.
    :param version: opaque user-supplied version tag, stored as-is.
    :param do_not_specialize: parameter indices and/or names that must never
        be specialized on their runtime values.
    :param do_not_specialize_on_alignment: like `do_not_specialize`, but only
        for alignment-based specialization.
    :param debug: NOTE(review): accepted but never stored in this constructor —
        confirm it is consumed elsewhere (e.g. via the launch-time `debug` kwarg).
    :param noinline: stored on the instance; presumably read by the compiler.
    :param repr: optional callable overriding the kernel's display name.
    :param launch_metadata: optional callable producing launch metadata.
    """
    do_not_specialize = do_not_specialize if do_not_specialize else []
    do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else []

    self.fn = fn
    self.module = fn.__module__
    self.version = version
    self.signature = inspect.signature(fn)
    self.do_not_specialize = do_not_specialize
    self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
    self.starting_line_number = inspect.getsourcelines(fn)[1]
    # The conditional lives *inside* the lambda: default to the bare function
    # name when no custom `repr` callable was supplied.
    self.repr = lambda _: fn.__name__ if repr is None else repr(_)
    self.launch_metadata = launch_metadata

    # Built lazily by create_binder on first run.
    self.binder = None

    # A parameter may be opted out of specialization by index or by name.
    self.params = []
    for i, param in enumerate(self.signature.parameters.values()):
        dns = i in do_not_specialize or param.name in do_not_specialize
        dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment
        self.params.append(KernelParam(i, param, dns, dns_oa))

    # function source code (without decorators)
    # Trim everything before the `def` so decorator text does not affect hashing.
    self.src = textwrap.dedent(inspect.getsource(fn))
    self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():]
    # cache of just-in-time compiled kernels
    self.cache = defaultdict(dict)
    self.hash = None

    # Map of global variables used by the function and any functions it
    # transitively calls, plus their values. The values are collected when
    # the function is first compiled. Then every time we run the function,
    # we check that the values of the globals match what's expected,
    # otherwise we raise an error.
    #
    # Different functions can have different __globals__ maps, so the map
    # key is actually (var name, id(__globals__)), and the map value is
    # (value, __globals__).
    self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}

    # JITFunction can be instantiated as kernel
    # when called with a grid using __getitem__
    self.kernel = None
    self.noinline = noinline

    # TODO(jlebar): Remove uses of these fields outside this file, then
    # remove the fields here.
    self.arg_names = [p.name for p in self.params]
    self.constexprs = [p.num for p in self.params if p.is_constexpr]

    # Hooks that will be called prior to executing "run"
    self.pre_run_hooks = []

    # reuse docs of wrapped function
    self.__doc__ = fn.__doc__
    self.__name__ = fn.__name__
    self.__globals__ = fn.__globals__
    self.__module__ = fn.__module__
|
|
716
|
+
|
|
717
|
+
@property
def cache_key(self):
    """Lazily computed hash of this function's source, its dependencies, and
    its starting line number; cached on `self.hash` after the first access."""
    # TODO : hash should be attribute of `self`
    if self.hash is not None:
        return self.hash
    finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src)
    finder.visit(self.parse())
    self.hash = finder.ret + str(self.starting_line_number)
    self.used_global_vals = dict(sorted(finder.used_global_vals.items()))
    return self.hash
|
|
726
|
+
|
|
727
|
+
def warmup(self, *args, grid, **kwargs):
    """Compile the kernel for these arguments without launching it.
    Bare torch dtypes among `args` are replaced by MockTensor stand-ins."""
    mocked = [MockTensor.wrap_dtype(arg) for arg in args]
    return self.run(*mocked, grid=grid, warmup=True, **kwargs)
|
|
729
|
+
|
|
730
|
+
def preload(self, specialization_data):
    """
    Precompile a kernel from JSON `specialization_data` (as produced by
    serialize_specialization_data) and install it into this function's
    kernel cache for the current device.

    :param specialization_data: JSON string holding name, signature,
        constants, attrs, options and the original cache key.
    :return: the compiled kernel object.
    :raises RuntimeError: if the serialized kernel name does not match
        this function's name.
    """
    from ..compiler import compile, ASTSource
    from triton.backends.compiler import AttrsDescriptor
    import json
    import triton.language as tl
    device = driver.active.get_current_device()
    deserialized_obj = json.loads(specialization_data)
    if deserialized_obj['name'] != self.fn.__name__:
        raise RuntimeError(
            f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}")
    # Re-hydrate serialized dtype strings into tl.dtype objects; all other
    # constant values pass through unchanged.
    constants = {
        key: tl.dtype(value) if tl.dtype.is_dtype(value) else value
        for key, value in deserialized_obj['constants'].items()
    }
    signature = dict(deserialized_obj['signature'].items())
    src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs']))
    # JSON has no tuple type — list-valued options were presumably tuples
    # before serialization, so convert them back.
    options = {
        key: tuple(value) if isinstance(value, list) else value
        for key, value in deserialized_obj['options'].items()
    }
    key = deserialized_obj['key']
    kernel = compile(src, None, options)
    self.cache[device][key] = kernel
    return kernel
|
|
754
|
+
|
|
755
|
+
# we do not parse `src` in the constructor because
|
|
756
|
+
# the user might want to monkey-patch self.src dynamically.
|
|
757
|
+
# Our unit tests do this, for example.
|
|
758
|
+
def parse(self):
    """Parse `self.src` and return the ast.Module wrapping the single
    function definition it is expected to contain."""
    module = ast.parse(self.src)
    assert isinstance(module, ast.Module)
    assert len(module.body) == 1
    assert isinstance(module.body[0], ast.FunctionDef)
    return module
|
|
764
|
+
|
|
765
|
+
def __call__(self, *args, **kwargs):
    # A jit'd function is launched via the grid __getitem__ syntax / `run`;
    # invoking it directly from host Python is always an error.
    raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
|
|
767
|
+
|
|
768
|
+
def __setattr__(self, name, value):
    # Perform the normal attribute write first, then invalidate derived state.
    super(JITFunction, self).__setattr__(name, value)
    # - when `.src` attribute is set, cache path needs
    #   to be reinitialized (the hash is recomputed lazily by `cache_key`)
    if name == "src":
        self.hash = None
|
|
774
|
+
|
|
775
|
+
def __repr__(self):
    # Identical text to before, assembled with str.format instead of an f-string.
    return "JITFunction({}:{})".format(self.module, self.fn.__name__)
|
|
777
|
+
|
|
778
|
+
|
|
779
|
+
# -----------------------------------------------------------------------------
|
|
780
|
+
# `jit` decorator
|
|
781
|
+
# -----------------------------------------------------------------------------
|
|
782
|
+
|
|
783
|
+
|
|
784
|
+
@overload
def jit(fn: T) -> JITFunction[T]:
    # typing-only overload: bare `@triton.jit` applied directly to a function
    ...
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
@overload
def jit(
    *,
    version=None,
    repr: Optional[Callable] = None,
    launch_metadata: Optional[Callable] = None,
    do_not_specialize: Optional[Iterable[int]] = None,
    do_not_specialize_on_alignment: Optional[Iterable[int]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
) -> Callable[[T], JITFunction[T]]:
    # typing-only overload: parameterized `@triton.jit(...)` returns a decorator
    ...
|
|
801
|
+
|
|
802
|
+
|
|
803
|
+
def jit(
    fn: Optional[T] = None,
    *,
    version=None,
    repr: Optional[Callable] = None,
    launch_metadata: Optional[Callable] = None,
    do_not_specialize: Optional[Iterable[int]] = None,
    do_not_specialize_on_alignment: Optional[Iterable[int]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
    """
    Decorator for JIT-compiling a function using the Triton compiler.

    :note: When a jit'd function is called, arguments are
        implicitly converted to pointers if they have a :code:`.data_ptr()` method
        and a `.dtype` attribute.

    :note: This function will be compiled and run on the GPU. It will only have access to:

           * python primitives,
           * builtins within the triton package,
           * arguments to this function,
           * other jit'd functions

    :param fn: the function to be jit-compiled
    :type fn: Callable
    """

    def decorator(fn: T) -> JITFunction[T]:
        assert callable(fn)
        # TRITON_INTERPRET=1 swaps in the CPU interpreter instead of compiling.
        if os.getenv("TRITON_INTERPRET", "0") == "1":
            from .interpreter import InterpretedFunction
            return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize,
                                       do_not_specialize_on_alignment=do_not_specialize_on_alignment, debug=debug,
                                       noinline=noinline, repr=repr, launch_metadata=launch_metadata)
        return JITFunction(
            fn,
            version=version,
            do_not_specialize=do_not_specialize,
            do_not_specialize_on_alignment=do_not_specialize_on_alignment,
            debug=debug,
            noinline=noinline,
            repr=repr,
            launch_metadata=launch_metadata,
        )

    # Support both bare `@jit` (fn given) and parameterized `@jit(...)` usage.
    return decorator(fn) if fn is not None else decorator
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
# -----------------------------------------------------------------------------
|
|
859
|
+
# Utilities for mocking tensors
|
|
860
|
+
# -----------------------------------------------------------------------------
|
|
861
|
+
|
|
862
|
+
|
|
863
|
+
class MockTensor:
    """
    Stand-in for a real tensor when warming a kernel up:
        kernel.warmup(MockTensor(torch.float32), ...)
    """

    def __init__(self, dtype):
        self.dtype = dtype

    @staticmethod
    def wrap_dtype(arg):
        # Wrap bare torch dtypes in a MockTensor; pass everything else through.
        looks_like_torch_dtype = arg.__class__.__name__ == "dtype" and arg.__module__ == "torch"
        return MockTensor(arg) if looks_like_torch_dtype else arg

    @staticmethod
    def data_ptr():
        return 0  # optimistically assumes multiple of 16

    @staticmethod
    def ptr_range():
        return 0  # optimistically assumes 32 bit pointer range
|
|
885
|
+
|
|
886
|
+
|
|
887
|
+
class TensorWrapper:
    """Presents `base` (a tensor-like object) under a different dtype while
    delegating the small tensor API surface Triton needs to the base object."""

    def __init__(self, base, dtype):
        self.dtype = dtype
        self.base = base
        self.data = base.data
        self.device = base.device
        self.shape = self.base.shape

    # --- delegated accessors -------------------------------------------------
    def data_ptr(self):
        return self.base.data_ptr()

    def stride(self, i):
        return self.base.stride(i)

    def element_size(self):
        return self.base.element_size()

    def __str__(self) -> str:
        return f"TensorWrapper[{self.dtype}]({self.base})"

    # --- operations that re-wrap their result --------------------------------
    def cpu(self):
        return TensorWrapper(self.base.cpu(), self.dtype)

    def clone(self):
        return TensorWrapper(self.base.clone(), self.dtype)

    def to(self, device):
        return TensorWrapper(self.base.to(device), self.dtype)

    # --- in-place -------------------------------------------------------------
    def copy_(self, other):
        self.base.copy_(other.base)
|
|
919
|
+
|
|
920
|
+
|
|
921
|
+
def reinterpret(tensor, dtype):
    """Reinterpret `tensor` as `dtype`, wrapping/unwrapping TensorWrapper
    as needed; raises TypeError for objects with no `data_ptr`."""
    if isinstance(tensor, TensorWrapper):
        # Going back to the base dtype unwraps; any other dtype re-wraps the base.
        return tensor.base if dtype == tensor.base.dtype else TensorWrapper(tensor.base, dtype)
    if hasattr(tensor, "data_ptr"):
        # A new wrapper is needed around an unwrapped tensor.
        return TensorWrapper(tensor, dtype)
    raise TypeError(f"Cannot reinterpret a {type(tensor)}.")
|
|
934
|
+
|
|
935
|
+
|
|
936
|
+
def get_jit_fn_file_line(fn):
    """Return (file_name, line_number) of the `def` line of the innermost
    JITFunction wrapped by `fn`, skipping over any decorator lines."""
    inner = fn
    while not isinstance(inner, JITFunction):
        inner = inner.fn
    code_file = inner.fn.__code__.co_filename
    src_lines, first_line = inspect.getsourcelines(inner.fn)
    # Match the following pattern:
    #   @triton.autotune(...) <- foo.__code__.co_firstlineno
    #   @triton.heuristics(...)
    #   @triton.jit
    #   def foo(...):         <- this line is the first line
    for offset, text in enumerate(src_lines):
        if text.strip().startswith("def "):
            first_line += offset
            break
    return code_file, first_line
|