triton-windows 3.3.1.post19__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/runtime/jit.py
CHANGED
|
@@ -1,19 +1,28 @@
|
|
|
1
1
|
from __future__ import annotations, division
|
|
2
2
|
import ast
|
|
3
|
+
import copy
|
|
3
4
|
import hashlib
|
|
4
5
|
import inspect
|
|
5
6
|
import itertools
|
|
6
|
-
import
|
|
7
|
+
import threading
|
|
7
8
|
import re
|
|
8
9
|
import textwrap
|
|
9
10
|
from collections import defaultdict
|
|
11
|
+
from dataclasses import dataclass
|
|
10
12
|
from functools import cached_property
|
|
11
13
|
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple
|
|
12
|
-
|
|
14
|
+
|
|
15
|
+
from triton.tools.tensor_descriptor import TensorDescriptor
|
|
13
16
|
from types import ModuleType
|
|
14
|
-
from ..
|
|
17
|
+
from .. import knobs
|
|
18
|
+
from .driver import driver
|
|
19
|
+
from . import _async_compile
|
|
20
|
+
from .._utils import find_paths_if, get_iterable_path, type_canonicalisation_dict, canonicalize_dtype
|
|
21
|
+
from .cache import get_cache_key
|
|
22
|
+
from triton._C.libtriton import get_cache_invalidating_env_vars
|
|
15
23
|
|
|
16
|
-
TRITON_MODULE =
|
|
24
|
+
TRITON_MODULE = "triton.language"
|
|
25
|
+
GLUON_MODULE = "triton.experimental.gluon.language"
|
|
17
26
|
|
|
18
27
|
T = TypeVar("T")
|
|
19
28
|
|
|
@@ -34,13 +43,14 @@ class DependenciesFinder(ast.NodeVisitor):
|
|
|
34
43
|
otherwise we could recompile).
|
|
35
44
|
"""
|
|
36
45
|
|
|
37
|
-
def __init__(self, name, globals, src) -> None:
|
|
46
|
+
def __init__(self, name, globals, nonlocals, src) -> None:
|
|
38
47
|
super().__init__()
|
|
39
48
|
self.name = name
|
|
40
49
|
self.hasher = hashlib.sha256(src.encode("utf-8"))
|
|
41
50
|
|
|
42
51
|
# This function's __globals__ dict.
|
|
43
52
|
self.globals = globals
|
|
53
|
+
self.nonlocals = nonlocals
|
|
44
54
|
|
|
45
55
|
# Python builtins that can be accessed from Triton kernels.
|
|
46
56
|
self.supported_python_builtins = {
|
|
@@ -55,6 +65,12 @@ class DependenciesFinder(ast.NodeVisitor):
|
|
|
55
65
|
'print',
|
|
56
66
|
'range',
|
|
57
67
|
}
|
|
68
|
+
self.supported_modules = {
|
|
69
|
+
GLUON_MODULE,
|
|
70
|
+
TRITON_MODULE,
|
|
71
|
+
"copy",
|
|
72
|
+
"math",
|
|
73
|
+
}
|
|
58
74
|
|
|
59
75
|
# used_global_vals tells us which global variables are used by this
|
|
60
76
|
# function and all those it transitively calls, plus the values of those
|
|
@@ -81,22 +97,56 @@ class DependenciesFinder(ast.NodeVisitor):
|
|
|
81
97
|
return module.startswith(TRITON_MODULE)
|
|
82
98
|
|
|
83
99
|
def _update_hash(self, func):
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
+
assert isinstance(func, JITCallable)
|
|
101
|
+
# Merge our used_global_vals with those of the called function,
|
|
102
|
+
# after checking that all overlapping values are consistent.
|
|
103
|
+
for k in self.used_global_vals.keys() & func.used_global_vals.keys():
|
|
104
|
+
var_name, _ = k
|
|
105
|
+
v1, _ = self.used_global_vals[k]
|
|
106
|
+
v2, _ = func.used_global_vals[k]
|
|
107
|
+
if v1 != v2:
|
|
108
|
+
raise RuntimeError(
|
|
109
|
+
f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed."
|
|
110
|
+
)
|
|
111
|
+
self.used_global_vals.update(func.used_global_vals)
|
|
112
|
+
# update hash
|
|
113
|
+
func_key = func.cache_key
|
|
114
|
+
func_key += str(getattr(func, "noinline", False))
|
|
115
|
+
self.hasher.update(func_key.encode("utf-8"))
|
|
116
|
+
|
|
117
|
+
def record_reference(self, val, var_dict=None, name=None):
|
|
118
|
+
from ..language.core import constexpr
|
|
119
|
+
# Only keep track of "interesting" global variables, that non-evil users
|
|
120
|
+
# might change. Don't consider functions, modules, builtins, etc. This
|
|
121
|
+
# helps keep the list of vars we have to check small.
|
|
122
|
+
if val is None or type(val) is ModuleType:
|
|
123
|
+
return
|
|
124
|
+
|
|
125
|
+
if getattr(val, "__triton_builtin__", False):
|
|
126
|
+
return
|
|
127
|
+
|
|
128
|
+
# Stubs that aren't real functions
|
|
129
|
+
if getattr(val, "__module__", "") == "triton.language.extra.libdevice":
|
|
130
|
+
return
|
|
131
|
+
|
|
132
|
+
if isinstance(val, JITCallable):
|
|
133
|
+
self._update_hash(val)
|
|
134
|
+
return
|
|
135
|
+
|
|
136
|
+
if callable(val) and not isinstance(val, type) and not isinstance(val, constexpr):
|
|
137
|
+
raise RuntimeError(f"Unsupported function referenced: {val}")
|
|
138
|
+
|
|
139
|
+
# Python default arguments are resolved only once, when the
|
|
140
|
+
# function is defined. So if you do `foo(a=A)` and the value of
|
|
141
|
+
# A changes, foo will still use the old value of A.
|
|
142
|
+
# It would be pretty evil if someone did `import x` and then
|
|
143
|
+
# `x = blah`.
|
|
144
|
+
if self.visiting_arg_default_value:
|
|
145
|
+
return
|
|
146
|
+
|
|
147
|
+
if var_dict is not None:
|
|
148
|
+
self.used_global_vals[(name, id(var_dict))] = (copy.deepcopy(val), var_dict)
|
|
149
|
+
return
|
|
100
150
|
|
|
101
151
|
def visit_Name(self, node):
|
|
102
152
|
if type(node.ctx) is ast.Store:
|
|
@@ -106,26 +156,20 @@ class DependenciesFinder(ast.NodeVisitor):
|
|
|
106
156
|
# The global name is hidden by the local name.
|
|
107
157
|
return None
|
|
108
158
|
|
|
109
|
-
|
|
159
|
+
def name_lookup(name):
|
|
160
|
+
val = self.globals.get(name, None)
|
|
161
|
+
if val is not None:
|
|
162
|
+
return val, self.globals
|
|
163
|
+
val = self.nonlocals.get(name, None)
|
|
164
|
+
if val is not None:
|
|
165
|
+
return val, self.nonlocals
|
|
166
|
+
return None, None
|
|
110
167
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
# function is defined. So if you do `foo(a=A)` and the value of
|
|
117
|
-
# A changes, foo will still use the old value of A.
|
|
118
|
-
and not self.visiting_arg_default_value
|
|
119
|
-
# It would be pretty evil if someone did `import x` and then
|
|
120
|
-
# `x = blah`.
|
|
121
|
-
and type(val) is not ModuleType
|
|
122
|
-
# It would be pretty evil if we used function `foo` inside of
|
|
123
|
-
# `bar` and then someone did `foo = baz`.
|
|
124
|
-
and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) #
|
|
125
|
-
and node.id not in self.supported_python_builtins):
|
|
126
|
-
self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals)
|
|
127
|
-
|
|
128
|
-
self._update_hash(val)
|
|
168
|
+
val, var_dict = name_lookup(node.id)
|
|
169
|
+
if node.id in self.supported_python_builtins:
|
|
170
|
+
return val
|
|
171
|
+
|
|
172
|
+
self.record_reference(val, var_dict, node.id)
|
|
129
173
|
return val
|
|
130
174
|
|
|
131
175
|
def visit_Tuple(self, node):
|
|
@@ -137,10 +181,11 @@ class DependenciesFinder(ast.NodeVisitor):
|
|
|
137
181
|
lhs = self.visit(node.value)
|
|
138
182
|
while isinstance(lhs, ast.Attribute):
|
|
139
183
|
lhs = self.visit(lhs.value)
|
|
140
|
-
|
|
184
|
+
lhs_name = getattr(lhs, "__name__", "")
|
|
185
|
+
if lhs is None or lhs_name in self.supported_modules:
|
|
141
186
|
return None
|
|
142
187
|
ret = getattr(lhs, node.attr)
|
|
143
|
-
self.
|
|
188
|
+
self.record_reference(ret)
|
|
144
189
|
return ret
|
|
145
190
|
|
|
146
191
|
def visit_FunctionDef(self, node):
|
|
@@ -221,11 +266,29 @@ class DependenciesFinder(ast.NodeVisitor):
|
|
|
221
266
|
|
|
222
267
|
|
|
223
268
|
def _normalize_ty(ty) -> str:
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
269
|
+
import triton.language.core as core
|
|
270
|
+
if isinstance(ty, str):
|
|
271
|
+
ty = ty.strip()
|
|
272
|
+
if ty.startswith("const "):
|
|
273
|
+
ty = ty.removeprefix("const")
|
|
274
|
+
ty = _normalize_ty(ty)
|
|
275
|
+
assert ty.startswith("*")
|
|
276
|
+
return "*k" + ty[1:]
|
|
277
|
+
if ty.endswith("*"):
|
|
278
|
+
return "*" + _normalize_ty(ty[:-1])
|
|
279
|
+
if ty.startswith("*"):
|
|
280
|
+
return "*" + _normalize_ty(ty[1:])
|
|
281
|
+
if ty.startswith("tl."):
|
|
282
|
+
return _normalize_ty(ty.removeprefix("tl."))
|
|
283
|
+
elif isinstance(ty, core.pointer_type):
|
|
284
|
+
return f"*{_normalize_ty(ty.element_ty)}"
|
|
285
|
+
elif isinstance(ty, core.dtype):
|
|
286
|
+
ty = ty.name
|
|
287
|
+
elif isinstance(ty, type):
|
|
288
|
+
ty = ty.__name__
|
|
289
|
+
else:
|
|
290
|
+
ty = str(ty)
|
|
291
|
+
return type_canonicalisation_dict.get(ty.replace("_t", ""), ty)
|
|
229
292
|
|
|
230
293
|
|
|
231
294
|
class KernelParam:
|
|
@@ -243,20 +306,20 @@ class KernelParam:
|
|
|
243
306
|
return self._param.name
|
|
244
307
|
|
|
245
308
|
@cached_property
|
|
246
|
-
def annotation(self):
|
|
309
|
+
def annotation(self) -> str:
|
|
247
310
|
if not self._param.annotation or self._param.annotation == inspect.Parameter.empty:
|
|
248
311
|
return ""
|
|
249
312
|
return _normalize_ty(self._param.annotation)
|
|
250
313
|
|
|
251
314
|
@cached_property
|
|
252
|
-
def annotation_type(self):
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
if
|
|
259
|
-
return
|
|
315
|
+
def annotation_type(self) -> str:
|
|
316
|
+
a = self.annotation
|
|
317
|
+
if a.startswith("*k"):
|
|
318
|
+
a = a[2:]
|
|
319
|
+
elif a.startswith("*"):
|
|
320
|
+
a = a[1:]
|
|
321
|
+
if a in set(type_canonicalisation_dict.values()):
|
|
322
|
+
return self.annotation
|
|
260
323
|
return ""
|
|
261
324
|
|
|
262
325
|
@cached_property
|
|
@@ -265,7 +328,9 @@ class KernelParam:
|
|
|
265
328
|
|
|
266
329
|
@cached_property
|
|
267
330
|
def is_const(self):
|
|
268
|
-
|
|
331
|
+
if self.is_constexpr:
|
|
332
|
+
return False
|
|
333
|
+
return "const" in self.annotation or self.annotation.startswith("*k")
|
|
269
334
|
|
|
270
335
|
@property
|
|
271
336
|
def default(self):
|
|
@@ -280,22 +345,16 @@ dtype2str = {}
|
|
|
280
345
|
specialize_impl_cache = []
|
|
281
346
|
|
|
282
347
|
|
|
283
|
-
def create_specialize_impl():
|
|
284
|
-
if specialize_impl_cache:
|
|
285
|
-
return specialize_impl_cache[-1]
|
|
348
|
+
def create_specialize_impl(specialize_extra):
|
|
286
349
|
|
|
287
350
|
from ..language import constexpr
|
|
351
|
+
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor as GluonTensorDescriptor
|
|
288
352
|
|
|
289
|
-
def specialize_impl(arg,
|
|
290
|
-
|
|
353
|
+
def specialize_impl(arg, is_const=False, specialize_value=True, align=True):
|
|
291
354
|
if arg is None:
|
|
292
355
|
return ("constexpr", None)
|
|
293
|
-
elif isinstance(arg, JITFunction):
|
|
294
|
-
return ("constexpr", arg.cache_key)
|
|
295
|
-
elif isinstance(arg, constexpr):
|
|
296
|
-
return ("constexpr", arg)
|
|
297
356
|
elif isinstance(arg, bool):
|
|
298
|
-
return ("
|
|
357
|
+
return ("u1", None)
|
|
299
358
|
elif isinstance(arg, int):
|
|
300
359
|
key = specialize_extra(arg, "int", align=align) if specialize_value else None
|
|
301
360
|
if arg == 1 and specialize_value:
|
|
@@ -308,31 +367,44 @@ def create_specialize_impl():
|
|
|
308
367
|
return ("i64", key)
|
|
309
368
|
elif isinstance(arg, float):
|
|
310
369
|
return ("fp32", None)
|
|
311
|
-
elif hasattr(arg, "
|
|
312
|
-
return ("nvTmaDesc", None)
|
|
313
|
-
elif isinstance(arg, tuple):
|
|
314
|
-
spec = [specialize_impl(x, specialize_extra) for x in arg]
|
|
315
|
-
make_tuple = lambda vals: type(arg)(*vals) if hasattr(arg, "_fields") else tuple(vals)
|
|
316
|
-
tys = make_tuple([x[0] for x in spec])
|
|
317
|
-
keys = make_tuple([x[1] for x in spec])
|
|
318
|
-
return (tys, keys)
|
|
319
|
-
else:
|
|
370
|
+
elif hasattr(arg, "data_ptr"):
|
|
320
371
|
# dtypes are hashable so we can memoize this mapping:
|
|
321
372
|
dsk = (arg.dtype, is_const)
|
|
322
373
|
res = dtype2str.get(dsk, None)
|
|
323
374
|
if res is None:
|
|
324
|
-
res = ("*k" if dsk[1] else "*") +
|
|
375
|
+
res = ("*k" if dsk[1] else "*") + canonicalize_dtype(dsk[0])
|
|
325
376
|
dtype2str[dsk] = res
|
|
326
377
|
key = specialize_extra(arg, "tensor", align=align) if specialize_value else None
|
|
327
378
|
return (res, key)
|
|
379
|
+
elif isinstance(arg, JITCallable):
|
|
380
|
+
return ("constexpr", arg.cache_key)
|
|
381
|
+
elif isinstance(arg, constexpr):
|
|
382
|
+
return ("constexpr", arg)
|
|
383
|
+
elif isinstance(arg, tuple):
|
|
384
|
+
spec = [specialize_impl(x) for x in arg]
|
|
385
|
+
make_tuple = lambda vals: type(arg)(*vals) if hasattr(arg, "_fields") else tuple(vals)
|
|
386
|
+
tys = make_tuple([x[0] for x in spec])
|
|
387
|
+
keys = make_tuple([x[1] for x in spec])
|
|
388
|
+
return (tys, keys)
|
|
389
|
+
elif isinstance(arg, TensorDescriptor):
|
|
390
|
+
assert hasattr(arg.base, "data_ptr")
|
|
391
|
+
inner = canonicalize_dtype(arg.base.dtype)
|
|
392
|
+
return (f"tensordesc<{inner}{list(arg.block_shape)}>", None)
|
|
393
|
+
elif isinstance(arg, GluonTensorDescriptor):
|
|
394
|
+
assert hasattr(arg.base, "data_ptr")
|
|
395
|
+
inner = canonicalize_dtype(arg.base.dtype)
|
|
396
|
+
return (f"tensordesc<{inner}{list(arg.block_shape)},{arg.layout!r}>", None)
|
|
397
|
+
else:
|
|
398
|
+
raise TypeError("Unsupported type: %s" % type(arg))
|
|
328
399
|
|
|
329
|
-
specialize_impl_cache.append(specialize_impl)
|
|
330
400
|
return specialize_impl
|
|
331
401
|
|
|
332
402
|
|
|
333
403
|
def mangle_type(arg, specialize=False):
|
|
334
|
-
|
|
335
|
-
|
|
404
|
+
if len(specialize_impl_cache) == 0:
|
|
405
|
+
specialize_impl_cache.append(create_specialize_impl(lambda _, **kwargs: None))
|
|
406
|
+
specialize_impl = specialize_impl_cache[0]
|
|
407
|
+
return specialize_impl(arg, specialize_value=specialize)[0]
|
|
336
408
|
|
|
337
409
|
|
|
338
410
|
class KernelInterface(Generic[T]):
|
|
@@ -378,9 +450,17 @@ def create_function_from_signature(sig, kparams, backend):
|
|
|
378
450
|
is_const = 'True' if kp.is_const else 'False'
|
|
379
451
|
specialize = 'False' if kp.do_not_specialize else 'True'
|
|
380
452
|
align = 'False' if kp.do_not_specialize_on_alignment else 'True'
|
|
381
|
-
ret = f"specialize_impl({name},
|
|
453
|
+
ret = f"specialize_impl({name}, {is_const}, {specialize}, {align})"
|
|
382
454
|
if kp.annotation_type:
|
|
383
|
-
|
|
455
|
+
if isinstance(kp.annotation_type, str):
|
|
456
|
+
if kp.annotation_type == "u1" or kp.annotation_type[:2] in ["fp", "bf"]:
|
|
457
|
+
# we do not specialize non-constexpr floats and bools:
|
|
458
|
+
specialize = False
|
|
459
|
+
if specialize:
|
|
460
|
+
specialization.append(f'("{kp.annotation_type}",) + {ret}[1:]')
|
|
461
|
+
else:
|
|
462
|
+
# skip runtime specialization:
|
|
463
|
+
specialization.append(f'("{kp.annotation_type}", None)')
|
|
384
464
|
else:
|
|
385
465
|
specialization.append(f"{ret}")
|
|
386
466
|
|
|
@@ -400,9 +480,8 @@ def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options
|
|
|
400
480
|
if param.default is not inspect.Parameter.empty
|
|
401
481
|
}
|
|
402
482
|
|
|
403
|
-
func_namespace["
|
|
404
|
-
func_namespace["specialize_impl"] = create_specialize_impl()
|
|
405
|
-
func_namespace["specialize_extra"] = backend.get_arg_specialization
|
|
483
|
+
func_namespace["JITCallable"] = JITCallable
|
|
484
|
+
func_namespace["specialize_impl"] = create_specialize_impl(backend.get_arg_specialization)
|
|
406
485
|
|
|
407
486
|
# Execute the function string in func_namespace to create the function
|
|
408
487
|
exec(func_body, func_namespace)
|
|
@@ -411,44 +490,134 @@ def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options
|
|
|
411
490
|
return func_namespace['dynamic_func']
|
|
412
491
|
|
|
413
492
|
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
493
|
+
def get_full_name(fn):
|
|
494
|
+
return f"{fn.__module__}.{fn.__qualname__}"
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
class JITCallable:
|
|
498
|
+
|
|
499
|
+
def __init__(self, fn):
|
|
500
|
+
self.fn = fn
|
|
501
|
+
self.signature = inspect.signature(fn)
|
|
502
|
+
try:
|
|
503
|
+
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
|
|
504
|
+
except OSError as e:
|
|
505
|
+
raise ValueError("@jit functions should be defined in a Python file") from e
|
|
506
|
+
self._fn_name = get_full_name(fn)
|
|
507
|
+
self._hash_lock = threading.RLock()
|
|
508
|
+
|
|
509
|
+
# function source code (without decorators)
|
|
510
|
+
src = textwrap.dedent("".join(self.raw_src))
|
|
511
|
+
src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():]
|
|
512
|
+
self._src = src
|
|
513
|
+
self.hash = None
|
|
514
|
+
|
|
515
|
+
# Map of global variables used by the function and any functions it
|
|
516
|
+
# transitively calls, plus their values. The values are collected when
|
|
517
|
+
# the function is first compiled. Then every time we run the function,
|
|
518
|
+
# we check that the values of the globals match what's expected,
|
|
519
|
+
# otherwise we raise an error.
|
|
520
|
+
#
|
|
521
|
+
# Different functions can have different __globals__ maps, so the map
|
|
522
|
+
# key is actually (var name, id(__globals__)), and the map value is
|
|
523
|
+
# (value, __globals__).
|
|
524
|
+
self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}
|
|
525
|
+
|
|
526
|
+
# reuse docs of wrapped function
|
|
527
|
+
self.__doc__ = fn.__doc__
|
|
528
|
+
self.__name__ = fn.__name__
|
|
529
|
+
self.__qualname__ = fn.__qualname__
|
|
530
|
+
self.__globals__ = fn.__globals__
|
|
531
|
+
self.__module__ = fn.__module__
|
|
532
|
+
|
|
533
|
+
def get_capture_scope(self):
|
|
534
|
+
return self.__globals__ | inspect.getclosurevars(self.fn).nonlocals
|
|
535
|
+
|
|
536
|
+
@property
def cache_key(self):
    """
    Return the hash identifying this function's source and its dependencies.

    The value is computed on first access and stored in ``self.hash``; later
    accesses return the stored value. The first computation also fills
    ``self.used_global_vals`` from the dependency scan.

    If the computation raises, ``self.hash`` is reset to ``None`` before the
    exception propagates, so the recursion placeholder is never handed out
    later as if it were the real key.
    """
    # TODO : hash should be attribute of `self`
    with self._hash_lock:
        if self.hash is not None:
            return self.hash
        # Set a placeholder hash to break recursion in case the function
        # transitively calls itself. The full hash is set after.
        # NOTE(review): this relies on `_hash_lock` being re-entrant -- confirm.
        self.hash = f"recursion:{self._fn_name}"
        try:
            nonlocals = inspect.getclosurevars(self.fn).nonlocals
            dependencies_finder = DependenciesFinder(name=self._fn_name, globals=self.__globals__,
                                                     nonlocals=nonlocals, src=self.src)
            dependencies_finder.visit(self.parse())
            used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items()))

            from triton.language.core import constexpr

            # Key material: finder result + starting line + the constexpr-valued globals.
            key_material = dependencies_finder.ret + str(self.starting_line_number)
            key_material += str([(name, val)
                                 for (name, _), (val, _) in used_global_vals.items()
                                 if isinstance(val, constexpr)])
            digest = hashlib.sha256(key_material.encode("utf-8")).hexdigest()
        except BaseException:
            # Do not leave the "recursion:..." placeholder behind: a later
            # access would return it as a valid cache key.
            self.hash = None
            raise
        self.used_global_vals = used_global_vals
        self.hash = digest
    return self.hash
|
|
558
|
+
|
|
559
|
+
# we do not parse `src` in the constructor because
|
|
560
|
+
# the user might want to monkey-patch self.src dynamically.
|
|
561
|
+
# Our unit tests do this, for example.
|
|
562
|
+
def parse(self):
    """
    Parse the stored source text (``self._src``) and return the ``ast.Module``.

    The module must consist of exactly one statement and that statement must
    be a plain function definition; otherwise an ``AssertionError`` is raised.
    """
    module = ast.parse(self._src)
    assert isinstance(module, ast.Module)
    statements = module.body
    assert len(statements) == 1
    assert isinstance(statements[0], ast.FunctionDef)
    return module
|
|
568
|
+
|
|
569
|
+
@property
def type(self):
    """Return ``constexpr_type(self)``, i.e. this callable wrapped as a constexpr type."""
    # Imported at call time rather than at module scope
    # (NOTE(review): presumably to avoid a circular import -- confirm).
    from triton.language.core import constexpr_type

    as_constexpr_type = constexpr_type(self)
    return as_constexpr_type
|
|
573
|
+
|
|
574
|
+
def _unsafe_update_src(self, new_src):
    """
    The only method allowed to modify src.

    Writes the private ``_src`` attribute directly, because assigning to the
    public ``src`` property raises ``AttributeError``. This object's cached
    ``.hash`` is reset to None so its cache key is recomputed.

    Note that it is the caller's responsibility to make sure any triton
    functions that call this function have the `.hash` value reset to None.
    """
    self.hash = None
    self._src = new_src
|
|
583
|
+
|
|
584
|
+
def _set_src(self, _value=None):
    """
    Setter installed on the ``src`` property; always refuses the assignment.

    ``_value`` is accepted and ignored: ``property`` calls its setter with the
    value being assigned, so a setter taking only ``self`` would fail with
    ``TypeError`` before this error could be raised.

    Raises:
        AttributeError: always.
    """
    raise AttributeError("Cannot set attribute 'src' directly. "
                         "Use '_unsafe_update_src()' and manually clear `.hash` of all callers "
                         "instead.")

def _get_src(self):
    """Getter for the ``src`` property: return the stored source text."""
    return self._src

# `src` reads `_src`; assigning to it raises AttributeError (see `_set_src`).
src = property(fget=_get_src, fset=_set_src)
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
@dataclass
class JitFunctionInfo:
    # Plain record bundling a JIT function with its module and name.
    # NOTE(review): the exact consumers of this record are not visible here -- confirm.
    module: ModuleType  # module object associated with the function
    name: str  # function name as a string
    jit_function: JITFunction  # the JITFunction instance itself
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
def compute_cache_key(kernel_key_cache, specialization, options):
    """
    Return the string key for a (specialization, options) pair, memoized.

    Args:
        kernel_key_cache: dict used as the memo table; it is updated in place
            on a miss. Entries are keyed by
            ``(tuple(specialization), str(options))``.
        specialization: sequence of hashable items describing the arguments.
        options: any object; only its ``str()`` form is used.

    Returns:
        ``str(specialization) + str(options)``, taken from the memo table when
        an entry already exists.
    """
    # str(options) is needed for both the lookup key and the result: build it once.
    options_repr = str(options)
    lookup_key = (tuple(specialization), options_repr)
    cached = kernel_key_cache.get(lookup_key)
    if cached is not None:
        return cached

    cache_key = str(specialization) + options_repr
    kernel_key_cache[lookup_key] = cache_key
    return cache_key
|
|
611
|
+
|
|
612
|
+
|
|
613
|
+
class JITFunction(JITCallable, KernelInterface[T]):
|
|
614
|
+
|
|
615
|
+
def is_gluon(self):
    """Report whether this is a Gluon function; this class always answers False."""
    return False
|
|
449
617
|
|
|
450
618
|
def _call_hook(
|
|
451
619
|
self,
|
|
620
|
+
hook,
|
|
452
621
|
key,
|
|
453
622
|
signature,
|
|
454
623
|
device,
|
|
@@ -456,26 +625,17 @@ class JITFunction(KernelInterface[T]):
|
|
|
456
625
|
options,
|
|
457
626
|
configs,
|
|
458
627
|
is_warmup,
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
if hook is None:
|
|
463
|
-
return False
|
|
628
|
+
) -> bool | None:
|
|
629
|
+
if not hook:
|
|
630
|
+
return None
|
|
464
631
|
|
|
465
|
-
name = self.fn.
|
|
632
|
+
name = self.fn.__qualname__
|
|
466
633
|
module = self.fn.__module__
|
|
467
634
|
arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
|
|
468
635
|
repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}, launch_cooperative_grid={options.launch_cooperative_grid}]({arg_reprs})"
|
|
636
|
+
full_name = get_full_name(self.fn)
|
|
469
637
|
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
def __init__(self, module, name, jit_function):
|
|
473
|
-
self.module = module
|
|
474
|
-
self.name = name
|
|
475
|
-
self.jit_function = jit_function
|
|
476
|
-
pass
|
|
477
|
-
|
|
478
|
-
specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key)
|
|
638
|
+
specialization_data = serialize_specialization_data(full_name, signature, constants, configs[0], options, key)
|
|
479
639
|
|
|
480
640
|
kwargs = {
|
|
481
641
|
'signature': signature,
|
|
@@ -520,10 +680,34 @@ class JITFunction(KernelInterface[T]):
|
|
|
520
680
|
self.compile = compile
|
|
521
681
|
self.ASTSource = ASTSource
|
|
522
682
|
binder = create_function_from_signature(self.signature, self.params, backend)
|
|
523
|
-
return {}, target, backend, binder
|
|
683
|
+
return {}, {}, target, backend, binder
|
|
684
|
+
|
|
685
|
+
def _pack_args(self, backend, kwargs, bound_args, specialization, options):
    """
    Build the compile inputs for one call from its bound arguments.

    Args:
        backend: object providing ``parse_options(kwargs)`` and ``parse_attr(value)``.
        kwargs: keyword arguments of the call; each key must be either an
            attribute of the parsed options or a parameter name.
        bound_args: mapping whose values are the bound argument values.
        specialization: sequence of pairs; item ``[0]`` of each pair is used
            as the signature entry and item ``[1]`` as the attribute value.
        options: ignored -- options are re-parsed from ``kwargs``.

    Returns:
        ``(options, signature, constexprs, attrs)``.

    Raises:
        AssertionError: if a deprecated ``device_type``/``device``/``stream``
            keyword is present.
        KeyError: if a keyword is neither an option nor a parameter name.
    """
    # options (the incoming `options` argument is overwritten here)
    options = backend.parse_options(kwargs)
    # signature: parameter name -> first element of the matching specialization pair
    param_names = [param.name for param in self.params]
    sig_entries = [pair[0] for pair in specialization]
    signature = dict(zip(param_names, sig_entries))
    # check arguments
    assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used"
    assert "device" not in kwargs, "device option is deprecated; current device will be used"
    assert "stream" not in kwargs, "stream option is deprecated; current stream will be used"
    known_options = options.__dict__
    for kw in kwargs:
        if not (kw in known_options or kw in param_names):
            raise KeyError("Keyword argument %s was specified but unrecognised" % kw)
    # constexprs: path -> bound value, for every signature entry equal to "constexpr"
    constexpr_paths = find_paths_if(sig_entries, lambda _, val: val == "constexpr")
    constexprs = {path: get_iterable_path(list(bound_args.values()), path) for path in constexpr_paths}
    # attributes: path -> parsed attribute, for every string-valued specialization value
    spec_values = [pair[1] for pair in specialization]
    attr_paths = find_paths_if(spec_values, lambda _, x: isinstance(x, str))
    attrs = {path: backend.parse_attr(get_iterable_path(spec_values, path)) for path in attr_paths}

    return options, signature, constexprs, attrs
|
|
524
708
|
|
|
525
709
|
def run(self, *args, grid, warmup, **kwargs):
|
|
526
|
-
kwargs["debug"] = kwargs.get("debug", self.debug) or
|
|
710
|
+
kwargs["debug"] = kwargs.get("debug", self.debug) or knobs.runtime.debug
|
|
527
711
|
|
|
528
712
|
# parse options
|
|
529
713
|
device = driver.active.get_current_device()
|
|
@@ -533,42 +717,22 @@ class JITFunction(KernelInterface[T]):
|
|
|
533
717
|
for hook in self.pre_run_hooks:
|
|
534
718
|
hook(*args, **kwargs)
|
|
535
719
|
|
|
536
|
-
kernel_cache, target, backend, binder = self.device_caches[device]
|
|
720
|
+
kernel_cache, kernel_key_cache, target, backend, binder = self.device_caches[device]
|
|
721
|
+
# specialization is list[tuple[str, Any]], where first element of tuple is
|
|
722
|
+
# the type and the second parameter is the 'specialization' value.
|
|
537
723
|
bound_args, specialization, options = binder(*args, **kwargs)
|
|
538
724
|
|
|
539
|
-
|
|
540
|
-
key = str(specialization) + str(options)
|
|
725
|
+
key = compute_cache_key(kernel_key_cache, specialization, options)
|
|
541
726
|
kernel = kernel_cache.get(key, None)
|
|
542
727
|
|
|
543
728
|
# Kernel is not cached; we have to compile.
|
|
544
729
|
if kernel is None:
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
|
|
551
|
-
# check arguments
|
|
552
|
-
assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used"
|
|
553
|
-
assert "device" not in kwargs, "device option is deprecated; current device will be used"
|
|
554
|
-
assert "stream" not in kwargs, "stream option is deprecated; current stream will be used"
|
|
555
|
-
for k in kwargs:
|
|
556
|
-
if k not in options.__dict__ and k not in sigkeys:
|
|
557
|
-
raise KeyError("Keyword argument %s was specified but unrecognised" % k)
|
|
558
|
-
# constexprs
|
|
559
|
-
constexprs = find_paths_if(sigvals, lambda _, val: val == "constexpr")
|
|
560
|
-
constexprs = {path: get_iterable_path(list(bound_args.values()), path) for path in constexprs}
|
|
561
|
-
# attributes
|
|
562
|
-
attrvals = [x[1] for x in specialization]
|
|
563
|
-
attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
|
|
564
|
-
attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs}
|
|
565
|
-
if self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=True):
|
|
730
|
+
options, signature, constexprs, attrs = self._pack_args(backend, kwargs, bound_args, specialization,
|
|
731
|
+
options)
|
|
732
|
+
|
|
733
|
+
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
|
|
734
|
+
if kernel is None:
|
|
566
735
|
return None
|
|
567
|
-
# compile the kernel
|
|
568
|
-
src = self.ASTSource(self, signature, constexprs, attrs)
|
|
569
|
-
kernel = self.compile(src, target=target, options=options.__dict__)
|
|
570
|
-
kernel_cache[key] = kernel
|
|
571
|
-
self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=False)
|
|
572
736
|
|
|
573
737
|
# Check that used global values have not changed.
|
|
574
738
|
not_present = object()
|
|
@@ -586,11 +750,12 @@ class JITFunction(KernelInterface[T]):
|
|
|
586
750
|
grid_0 = grid[0]
|
|
587
751
|
grid_1 = grid[1] if grid_size > 1 else 1
|
|
588
752
|
grid_2 = grid[2] if grid_size > 2 else 1
|
|
753
|
+
if hasattr(kernel, "result"):
|
|
754
|
+
kernel = kernel.result()
|
|
589
755
|
# launch kernel
|
|
590
756
|
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
|
|
591
|
-
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
|
|
592
|
-
|
|
593
|
-
*bound_args.values())
|
|
757
|
+
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
|
|
758
|
+
knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *bound_args.values())
|
|
594
759
|
return kernel
|
|
595
760
|
|
|
596
761
|
def repr(self, _):
|
|
@@ -601,15 +766,12 @@ class JITFunction(KernelInterface[T]):
|
|
|
601
766
|
do_not_specialize = do_not_specialize if do_not_specialize else []
|
|
602
767
|
do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else []
|
|
603
768
|
|
|
604
|
-
|
|
769
|
+
super().__init__(fn)
|
|
605
770
|
self.module = fn.__module__
|
|
606
771
|
self.version = version
|
|
607
|
-
self.signature = inspect.signature(fn)
|
|
608
772
|
self.do_not_specialize = do_not_specialize
|
|
609
773
|
self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
|
|
610
|
-
self.starting_line_number = inspect.getsourcelines(fn)[1]
|
|
611
774
|
self._repr = repr
|
|
612
|
-
self._fn_name = fn.__name__
|
|
613
775
|
self.launch_metadata = launch_metadata
|
|
614
776
|
|
|
615
777
|
self.params = []
|
|
@@ -618,24 +780,8 @@ class JITFunction(KernelInterface[T]):
|
|
|
618
780
|
dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment
|
|
619
781
|
self.params.append(KernelParam(i, param, dns, dns_oa))
|
|
620
782
|
|
|
621
|
-
# function source code (without decorators)
|
|
622
|
-
src = textwrap.dedent(inspect.getsource(fn))
|
|
623
|
-
src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():]
|
|
624
|
-
self._unsafe_update_src(src)
|
|
625
783
|
# cache of just-in-time compiled kernels
|
|
626
784
|
self.device_caches = defaultdict(self.create_binder)
|
|
627
|
-
self.hash = None
|
|
628
|
-
|
|
629
|
-
# Map of global variables used by the function and any functions it
|
|
630
|
-
# transitively calls, plus their values. The values are collected when
|
|
631
|
-
# the function is first compiled. Then every time we run the function,
|
|
632
|
-
# we check that the values of the globals match what's expected,
|
|
633
|
-
# otherwise we raise an error.
|
|
634
|
-
#
|
|
635
|
-
# Different functions can have different __globals__ maps, so the map
|
|
636
|
-
# key is actually (var name, id(__globals__)), and the map value is
|
|
637
|
-
# (value, __globals__).
|
|
638
|
-
self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}
|
|
639
785
|
|
|
640
786
|
# JITFunction can be instantiated as kernel
|
|
641
787
|
# when called with a grid using __getitem__
|
|
@@ -651,37 +797,20 @@ class JITFunction(KernelInterface[T]):
|
|
|
651
797
|
# Hooks that will be called prior to executing "run"
|
|
652
798
|
self.pre_run_hooks = []
|
|
653
799
|
|
|
654
|
-
# reuse docs of wrapped function
|
|
655
|
-
self.__doc__ = fn.__doc__
|
|
656
|
-
self.__name__ = fn.__name__
|
|
657
|
-
self.__globals__ = fn.__globals__
|
|
658
|
-
self.__module__ = fn.__module__
|
|
659
|
-
|
|
660
|
-
@property
|
|
661
|
-
def cache_key(self):
|
|
662
|
-
# TODO : hash should be attribute of `self`
|
|
663
|
-
if self.hash is None:
|
|
664
|
-
dependencies_finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src)
|
|
665
|
-
dependencies_finder.visit(self.parse())
|
|
666
|
-
self.hash = dependencies_finder.ret + str(self.starting_line_number)
|
|
667
|
-
self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items()))
|
|
668
|
-
return self.hash
|
|
669
|
-
|
|
670
800
|
def warmup(self, *args, grid, **kwargs):
    """
    Call ``run`` with ``warmup=True``.

    Every positional argument is first passed through
    ``MockTensor.wrap_dtype``; ``grid`` and the remaining keyword arguments
    are forwarded unchanged. Returns whatever ``run`` returns.
    """
    wrapped_args = map(MockTensor.wrap_dtype, args)
    return self.run(*wrapped_args, grid=grid, warmup=True, **kwargs)
|
|
672
802
|
|
|
673
803
|
def preload(self, specialization_data):
|
|
674
|
-
from ..compiler import compile, ASTSource
|
|
675
804
|
import json
|
|
676
805
|
import triton.language as tl
|
|
677
806
|
device = driver.active.get_current_device()
|
|
678
807
|
deserialized_obj = json.loads(specialization_data)
|
|
679
|
-
if deserialized_obj['name'] != self.
|
|
808
|
+
if deserialized_obj['name'] != self._fn_name:
|
|
680
809
|
raise RuntimeError(
|
|
681
|
-
f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.
|
|
810
|
+
f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self._fn_name}")
|
|
682
811
|
constant_keys = map(tuple, deserialized_obj['constant_keys'])
|
|
683
812
|
constant_vals = deserialized_obj['constant_vals']
|
|
684
|
-
|
|
813
|
+
constexprs = {
|
|
685
814
|
key: tl.dtype(value) if tl.dtype.is_dtype(value) else value
|
|
686
815
|
for key, value in zip(constant_keys, constant_vals)
|
|
687
816
|
}
|
|
@@ -689,47 +818,57 @@ class JITFunction(KernelInterface[T]):
|
|
|
689
818
|
attrs_vals = deserialized_obj['attrs_vals']
|
|
690
819
|
attrs = dict(zip(attrs_keys, attrs_vals))
|
|
691
820
|
signature = dict(deserialized_obj['signature'].items())
|
|
692
|
-
src = ASTSource(self, signature, constants, attrs)
|
|
693
821
|
options = {
|
|
694
822
|
key: tuple(value) if isinstance(value, list) else value
|
|
695
823
|
for key, value in deserialized_obj['options'].items()
|
|
696
824
|
}
|
|
697
825
|
key = deserialized_obj['key']
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
return
|
|
826
|
+
_, _, _, backend, _ = self.device_caches[device]
|
|
827
|
+
options = backend.parse_options(options)
|
|
828
|
+
return self._do_compile(
|
|
829
|
+
key,
|
|
830
|
+
signature,
|
|
831
|
+
device,
|
|
832
|
+
constexprs,
|
|
833
|
+
options,
|
|
834
|
+
attrs,
|
|
835
|
+
warmup=True,
|
|
836
|
+
)
|
|
701
837
|
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
# Our unit tests do this, for example.
|
|
705
|
-
def parse(self):
|
|
706
|
-
tree = ast.parse(self.src)
|
|
707
|
-
assert isinstance(tree, ast.Module)
|
|
708
|
-
assert len(tree.body) == 1
|
|
709
|
-
assert isinstance(tree.body[0], ast.FunctionDef)
|
|
710
|
-
return tree
|
|
838
|
+
def _do_compile(self, key, signature, device, constexprs, options, attrs, warmup):
    """
    Compile the kernel for ``key`` on ``device`` and store it in that device's kernel cache.

    ``knobs.runtime.jit_cache_hook`` is invoked first; if it returns a truthy
    value nothing is compiled and ``None`` is returned. Otherwise the kernel
    is compiled -- directly, or through the active async-compile mode when
    one is set -- and, once available, it is stored under ``key`` and
    ``knobs.runtime.jit_post_compile_hook`` is invoked.

    Returns:
        ``None`` when the cache hook short-circuits, the compiled kernel in
        the synchronous case, or the value returned by ``async_mode.submit``
        in the asynchronous case.
    """
    kernel_cache, _, target, backend, _ = self.device_caches[device]

    if self._call_hook(knobs.runtime.jit_cache_hook, key, signature, device, constexprs, options, [attrs], warmup):
        return None
    src = self.ASTSource(self, signature, constexprs, attrs)

    def finalize_compile(kernel):
        # Shared epilogue of both paths: publish the kernel, then notify the post-compile hook.
        kernel_cache[key] = kernel
        self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, device, constexprs, options,
                        [attrs], warmup)

    async_mode = _async_compile.active_mode.get()
    if async_mode is None:
        # Synchronous path: compile now and finalize immediately.
        kernel = self.compile(src, target=target, options=options.__dict__)
        finalize_compile(kernel)
        return kernel

    # Asynchronous path: capture the environment now and hand the work to the async mode.
    env_vars = get_cache_invalidating_env_vars()
    cache_key = get_cache_key(src, backend, options, env_vars)

    def async_compile():
        return self.compile(src, target=target, options=options.__dict__, _env_vars=env_vars)

    return async_mode.submit(cache_key, async_compile, finalize_compile)
|
|
866
|
+
|
|
867
|
+
def __call__(self, *args, **kwargs):
    """Reject direct calls: always raises ``RuntimeError``, whatever the arguments."""
    message = "Cannot call @triton.jit'd outside of the scope of a kernel"
    raise RuntimeError(message)
|
|
730
869
|
|
|
731
870
|
def __repr__(self):
    """Return ``JITFunction(<module>:<qualified name of the wrapped function>)``."""
    module_name = self.module
    qualified_name = self.fn.__qualname__
    return f"JITFunction({module_name}:{qualified_name})"
|
|
733
872
|
|
|
734
873
|
|
|
735
874
|
# -----------------------------------------------------------------------------
|
|
@@ -748,8 +887,8 @@ def jit(
|
|
|
748
887
|
version=None,
|
|
749
888
|
repr: Optional[Callable] = None,
|
|
750
889
|
launch_metadata: Optional[Callable] = None,
|
|
751
|
-
do_not_specialize: Optional[Iterable[int]] = None,
|
|
752
|
-
do_not_specialize_on_alignment: Optional[Iterable[int]] = None,
|
|
890
|
+
do_not_specialize: Optional[Iterable[int | str]] = None,
|
|
891
|
+
do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
|
|
753
892
|
debug: Optional[bool] = None,
|
|
754
893
|
noinline: Optional[bool] = None,
|
|
755
894
|
) -> Callable[[T], JITFunction[T]]:
|
|
@@ -762,8 +901,8 @@ def jit(
|
|
|
762
901
|
version=None,
|
|
763
902
|
repr: Optional[Callable] = None,
|
|
764
903
|
launch_metadata: Optional[Callable] = None,
|
|
765
|
-
do_not_specialize: Optional[Iterable[int]] = None,
|
|
766
|
-
do_not_specialize_on_alignment: Optional[Iterable[int]] = None,
|
|
904
|
+
do_not_specialize: Optional[Iterable[int | str]] = None,
|
|
905
|
+
do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
|
|
767
906
|
debug: Optional[bool] = None,
|
|
768
907
|
noinline: Optional[bool] = None,
|
|
769
908
|
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
|
|
@@ -787,7 +926,7 @@ def jit(
|
|
|
787
926
|
|
|
788
927
|
def decorator(fn: T) -> JITFunction[T]:
|
|
789
928
|
assert callable(fn)
|
|
790
|
-
if
|
|
929
|
+
if knobs.runtime.interpret:
|
|
791
930
|
from .interpreter import InterpretedFunction
|
|
792
931
|
return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize,
|
|
793
932
|
do_not_specialize_on_alignment=do_not_specialize_on_alignment, debug=debug,
|
|
@@ -828,8 +967,17 @@ class MockTensor:
|
|
|
828
967
|
return MockTensor(arg)
|
|
829
968
|
return arg
|
|
830
969
|
|
|
831
|
-
def __init__(self, dtype):
|
|
970
|
+
def __init__(self, dtype, shape=None):
    """
    Store the dtype and shape of the mock tensor.

    Args:
        dtype: stored as ``self.dtype``.
        shape: stored as ``self.shape``; when ``None`` a fresh ``[1]`` is used.
    """
    self.dtype = dtype
    self.shape = [1] if shape is None else shape
|
|
975
|
+
|
|
976
|
+
def stride(self):
    """
    Return the strides (in elements) of a contiguous, row-major tensor of shape ``self.shape``.

    The last dimension has stride 1 and every other dimension's stride is the
    product of the sizes of all dimensions to its right, e.g. shape
    ``(2, 3, 4)`` gives ``(12, 4, 1)``.

    Returns:
        A tuple of ints with one entry per dimension (``(1,)`` for an empty shape).
    """
    strides = [1]
    # Accumulate from the innermost dimension outwards. Walking shape[1:] in
    # forward order would multiply by the wrong sizes for rank >= 3.
    for size in reversed(self.shape[1:]):
        strides.append(strides[-1] * size)
    return tuple(reversed(strides))
|
|
833
981
|
|
|
834
982
|
@staticmethod
|
|
835
983
|
def data_ptr():
|
|
@@ -894,17 +1042,66 @@ def reinterpret(tensor, dtype):
|
|
|
894
1042
|
|
|
895
1043
|
def get_jit_fn_file_line(fn):
    """
    Return ``(file_name, line)`` locating the ``def`` line of a JIT function.

    ``fn`` may be a wrapper around the JIT callable; ``.fn`` is followed until
    a ``JITCallable`` is reached. The file name comes from the wrapped Python
    function's code object. The line is the callable's
    ``starting_line_number`` plus the index of the first entry of ``raw_src``
    that starts with ``def `` (plus nothing when no entry does).
    """
    jit_callable = fn
    while not isinstance(jit_callable, JITCallable):
        jit_callable = jit_callable.fn
    file_name = jit_callable.fn.__code__.co_filename
    begin_line = jit_callable.starting_line_number
    # Match the following pattern:
    # @triton.autotune(...) <- foo.__code__.co_firstlineno
    # @triton.heuristics(...)
    # @triton.jit
    # def foo(...): <- this line is the first line
    def_offset = next(
        (idx for idx, line in enumerate(jit_callable.raw_src) if line.strip().startswith("def ")),
        0,
    )
    return file_name, begin_line + def_offset
|
|
1059
|
+
|
|
1060
|
+
|
|
1061
|
+
class BoundConstexprFunction(JITCallable):
    """A callable paired with an instance, so the instance is passed as the first argument."""

    def __init__(self, instance, fn):
        # Only the binding is recorded; the base-class initializer is not invoked here.
        self.__self__ = instance
        self.__func__ = fn

    def __call__(self, *args, **kwargs):
        """Call the stored callable with the stored instance prepended to the arguments."""
        target = self.__func__
        receiver = self.__self__
        return target(receiver, *args, **kwargs)
|
|
1069
|
+
|
|
1070
|
+
|
|
1071
|
+
class ConstexprFunction(JITCallable):
    """Wrapper that runs a plain Python function on unwrapped constexpr arguments."""

    def __init__(self, fn):
        super().__init__(fn)

    def __get__(self, obj, objclass):
        # Create a bound function to support constexpr_function methods
        if obj is None:
            return self
        return BoundConstexprFunction(obj, self)

    def __call__(self, *args, _semantic=None, **kwargs):
        """
        Call the wrapped function with every argument unwrapped from constexpr.

        Returns the raw result when ``_semantic`` is None or when
        ``knobs.runtime.interpret`` is set; otherwise the result wrapped in
        ``constexpr``.
        """
        from triton.language.core import _unwrap_if_constexpr, constexpr

        # de-constexpr arguments and discard the _semantic keyword argument:
        plain_args = [_unwrap_if_constexpr(arg) for arg in args]
        plain_kwargs = {name: _unwrap_if_constexpr(value) for name, value in kwargs.items()}

        # call the raw Python function f:
        result = self.fn(*plain_args, **plain_kwargs)

        # _semantic is None: not called by triton code generator, e.g. in host code,
        # another constexpr function, or even an aggregate's __init__ function.
        # interpret mode: no constexpr in interpreter.
        if _semantic is None or knobs.runtime.interpret:
            return result

        # convert result back to a Triton constexpr:
        return constexpr(result)
|
|
1099
|
+
|
|
1100
|
+
|
|
1101
|
+
def constexpr_function(fn):
    """
    Wrap an arbitrary Python function so that it can be called at
    compile time on constexpr arguments from a Triton function and
    returns a constexpr result.
    """
    wrapped = ConstexprFunction(fn)
    return wrapped
|