triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +82 -0
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +255 -0
- triton/_utils.py +126 -0
- triton/backends/__init__.py +47 -0
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +461 -0
- triton/backends/amd/driver.c +283 -0
- triton/backends/amd/driver.py +724 -0
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +90 -0
- triton/backends/driver.py +66 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +533 -0
- triton/backends/nvidia/driver.c +517 -0
- triton/backends/nvidia/driver.py +799 -0
- triton/backends/nvidia/include/cuda.h +26280 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +7 -0
- triton/compiler/code_generator.py +1614 -0
- triton/compiler/compiler.py +509 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +342 -0
- triton/language/core.py +3405 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +16 -0
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +5 -0
- triton/language/extra/hip/libdevice.py +491 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +790 -0
- triton/language/math.py +249 -0
- triton/language/random.py +218 -0
- triton/language/semantic.py +1939 -0
- triton/language/standard.py +534 -0
- triton/language/target_info.py +54 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/_allocation.py +44 -0
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +476 -0
- triton/runtime/build.py +168 -0
- triton/runtime/cache.py +317 -0
- triton/runtime/driver.py +38 -0
- triton/runtime/errors.py +36 -0
- triton/runtime/interpreter.py +1414 -0
- triton/runtime/jit.py +1107 -0
- triton/runtime/tcc/include/_mingw.h +168 -0
- triton/runtime/tcc/include/assert.h +62 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +75 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +116 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +497 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +14 -0
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +42 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +591 -0
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +123 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +2958 -0
- triton/runtime/tcc/include/winapi/wincon.h +309 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +5837 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +543 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.py +210 -0
- triton/tools/disasm.py +143 -0
- triton/tools/extra/cuda/compile.c +70 -0
- triton/tools/extra/cuda/compile.h +14 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/link.py +322 -0
- triton/tools/mxfp.py +301 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +405 -0
- triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1414 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import ast
|
|
3
|
+
import textwrap
|
|
4
|
+
import inspect
|
|
5
|
+
from typing import Tuple, List, Dict, Callable
|
|
6
|
+
|
|
7
|
+
import math
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
import triton
|
|
11
|
+
import triton.language as tl
|
|
12
|
+
import dataclasses
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
|
|
15
|
+
from triton.language.semantic import TritonSemantic
|
|
16
|
+
from triton.tools.tensor_descriptor import TensorDescriptor
|
|
17
|
+
from .errors import InterpreterError
|
|
18
|
+
from functools import partial
|
|
19
|
+
from .._C.libtriton import interpreter as _interpreter
|
|
20
|
+
from .._C.libtriton import ir as _ir
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class TensorHandle:
    """Interpreter-side value: a NumPy array paired with its Triton type.

    data: numpy array holding the values; for block values its shape is the
        block shape.
    dtype: triton type, either pointer_type or scalar_type. We don't store
        block_type here because the shape information is already available in
        the data field.
    attr: a dictionary of attributes (set via ``set_attr``).
    """
    data: np.ndarray
    dtype: tl.dtype
    attr: Dict = dataclasses.field(default_factory=dict)

    def __bool__(self):
        # True only if every element of ``data`` is truthy (numpy ``all``).
        return bool(self.data.all())

    def get_element_ty(self):
        """Follow ``element_ty`` links (e.g. pointer -> pointee) to the innermost type."""
        dtype = self.dtype
        while hasattr(dtype, "element_ty"):
            dtype = dtype.element_ty
        return dtype

    def clone(self):
        """Return an independent copy: fresh data buffer and a fresh attribute dict.

        The attribute dict is copied as well, so attributes recorded with
        ``set_attr`` survive cloning, and later ``set_attr`` calls on either
        handle do not leak into the other.
        """
        return TensorHandle(self.data.copy(), self.dtype, dict(self.attr))

    def set_attr(self, key, value):
        """Record an attribute ``key -> value`` on this handle."""
        self.attr[key] = value
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class BlockPointerHandle:
    """Interpreter state of a block pointer: base address plus per-dimension
    shape, strides, offsets, the block shape and the dimension order."""

    def __init__(self, base, shape, strides, offsets, block_shape, order):
        self.base, self.shape, self.strides = base, shape, strides
        self.offsets, self.block_shape, self.order = offsets, block_shape, order

    def materialize_pointers(self, boundary_check):
        """Expand the block pointer into explicit element addresses.

        Returns ``(ptrs, masks)``: a TensorHandle of uint64 byte addresses with
        shape ``block_shape``, and a boolean array that is False for elements
        outside ``[0, shape[d])`` along every dimension ``d`` listed in
        ``boundary_check`` (other dimensions are not bounds-checked).
        """
        elem_bytes = self.base.get_element_ty().primitive_bitwidth // 8
        rank = len(self.block_shape)
        addresses = np.broadcast_to(self.base.data, self.block_shape)
        in_bounds = np.ones(self.block_shape, dtype=bool)
        for axis, extent in enumerate(self.block_shape):
            # Index vector along ``axis``, shaped to broadcast against the block.
            view_shape = [1] * rank
            view_shape[axis] = extent
            idx = (self.offsets[axis].data + np.arange(extent)).reshape(view_shape)
            # Strides are in elements; scale to bytes before adding to the base.
            addresses = addresses + (elem_bytes * idx * self.strides[axis].data).astype(np.uint64)
            if axis in boundary_check:
                in_bounds = in_bounds & (idx < self.shape[axis].data) & (idx >= 0)
        return TensorHandle(addresses, self.base.dtype.scalar), in_bounds
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class TensorDescHandle:
    """Interpreter state of a tensor descriptor: base pointer, per-dimension
    shape and strides (in elements), block shape and padding mode."""

    def __init__(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
                 block_shape: List[int], padding):
        self.base = base
        self.ndim = len(shape)
        self.shape = shape
        self.strides = strides
        self.block_shape = block_shape
        self.padding = padding

    def validate(self):
        """Check alignment and contiguity requirements of the descriptor.

        Raises:
            AssertionError: if the base address is not 16-byte aligned, the
                strides/block shape rank does not match the shape rank, the
                descriptor is 0-dimensional, any non-last stride is not a
                multiple of 16 bytes, or the last dimension is not contiguous.
        """
        assert self.base.data.item() % 16 == 0, "base must be 16-byte aligned"
        assert len(self.strides) == self.ndim
        assert len(self.block_shape) == self.ndim
        assert self.ndim >= 1, "descriptor cannot be 0 dimensional"

        # Strides are element counts (materialize_pointers multiplies them by
        # the item size), so the 16-byte rule must be applied to stride * itemsize.
        itemsize = self.base.dtype.element_ty.primitive_bitwidth // 8
        for stride in self.strides[:-1]:
            assert (stride.data.item() * itemsize) % 16 == 0, "stride must be 16-byte aligned"
        assert self.strides[-1].data.item() == 1, "last dim must be contiguous"

    def materialize_pointers(self, offsets: List[TensorHandle]):
        """Expand the block at ``offsets`` into explicit element addresses.

        Returns ``(ptrs, masks)``: a TensorHandle of uint64 byte addresses with
        shape ``block_shape`` and a boolean array that is False for elements
        outside ``[0, shape[d])`` in any dimension.
        """
        assert len(offsets) == self.ndim
        scalar_ty = self.base.dtype.element_ty
        itemsize = scalar_ty.primitive_bitwidth // 8
        assert (offsets[-1].data * itemsize) % 16 == 0, "block offset start must be 16-byte aligned"

        ptrs = np.broadcast_to(self.base.data, self.block_shape)
        masks = np.ones(self.block_shape, dtype=bool)
        for dim in range(len(self.block_shape)):
            bcast_dims = [1] * len(self.block_shape)
            bcast_dims[dim] = self.block_shape[dim]
            off = (offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
            ptrs = ptrs + (itemsize * off * self.strides[dim].data).astype(np.uint64)
            masks = masks & (0 <= off) & (off < self.shape[dim].data)
        assert ptrs.dtype == np.uint64
        ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
        return ptrs, masks
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@dataclass(frozen=True)
class InterpreterOptions:
    """Immutable option set used by the interpreter backend.

    An instance is created by ``InterpreterBuilder.__init__``.
    NOTE(review): the field set looks like it mirrors the GPU backends' option
    objects so front-end code can query it unchanged -- confirm against callers.
    """
    extern_libs: dict | None = None
    debug: bool = False
    sanitize_overflow: bool = True
    arch: str | None = None
    # fp8 format names reported as supported.
    supported_fp8_dtypes: Tuple[str, ...] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15")
    deprecated_fp8_dot_operand_dtypes: Tuple[str, ...] = ()
    default_dot_input_precision: str = "tf32"
    allowed_dot_input_precisions: Tuple[str, ...] = ("tf32", "tf32x3", "ieee")
    max_num_imprecise_acc_default: int = 0
    backend_name: str = "interpreter"
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _get_signed_np_dtype(dtype):
|
|
132
|
+
if dtype == np.uint8:
|
|
133
|
+
return np.int8
|
|
134
|
+
if dtype == np.uint16:
|
|
135
|
+
return np.int16
|
|
136
|
+
if dtype == np.uint32:
|
|
137
|
+
return np.int32
|
|
138
|
+
if dtype == np.uint64:
|
|
139
|
+
return np.int64
|
|
140
|
+
return dtype
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _get_np_dtype(tt_dtype):
    """Return the NumPy dtype the interpreter uses to store values of ``tt_dtype``.

    Pointers (scalar, or as the element type of a block) are stored as uint64
    addresses. Block types map to the storage dtype of their element type.
    bfloat16 and the fp8 formats are stored as raw bits in uint16 / uint8.

    Raises:
        KeyError: if the (element) type has no entry in the table below.
    """
    if isinstance(tt_dtype, tl.pointer_type):
        return np.dtype(np.uint64)
    storage = {
        tl.int1: np.dtype(bool),
        tl.float16: np.dtype(np.float16),
        tl.float32: np.dtype(np.float32),
        tl.float64: np.dtype(np.float64),
        tl.int8: np.dtype(np.int8),
        tl.uint8: np.dtype(np.uint8),
        tl.int16: np.dtype(np.int16),
        tl.uint16: np.dtype(np.uint16),
        tl.int32: np.dtype(np.int32),
        tl.uint32: np.dtype(np.uint32),
        tl.int64: np.dtype(np.int64),
        tl.uint64: np.dtype(np.uint64),
        # bfloat16 has no NumPy type: keep the raw 16-bit pattern.
        tl.bfloat16: np.dtype(np.uint16),
        # fp8 formats have no NumPy type: keep the raw 8-bit pattern.
        tl.float8e5: np.dtype(np.uint8),
        tl.float8e5b16: np.dtype(np.uint8),
        tl.float8e4nv: np.dtype(np.uint8),
        tl.float8e4b8: np.dtype(np.uint8),
        tl.float8e4b15: np.dtype(np.uint8),
    }
    if not isinstance(tt_dtype, tl.block_type):
        return storage[tt_dtype]
    element_ty = tt_dtype.element_ty
    if isinstance(element_ty, tl.pointer_type):
        return np.dtype(np.uint64)
    return storage[element_ty]
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _convert_float(input, input_dtype, output_dtype, rounding_mode):
    """Convert floating-point values between formats by manipulating bit patterns.

    ``input`` is a NumPy array holding ``input_dtype`` values (for bfloat16 and
    the fp8 formats these are raw unsigned-integer bits). The result is an
    unsigned-integer array of the output bit width holding ``output_dtype`` bit
    patterns, reshaped to ``input.shape``; callers ``.view`` it as the storage
    dtype.

    The sign, exponent and mantissa are split out. The exponent is re-biased and
    clamped to the output range, and the mantissa is shifted to the output width.
    Subnormal inputs are normalised first, and outputs that underflow to a zero
    exponent are denormalised at the end.

    rounding_mode: only ``_ir.ROUNDING_MODE.RTNE`` adds a rounding step (on
    downcasts). Any other value, including None, truncates the dropped bits.

    NOTE(review): deviations from IEEE conversion visible in the code below:
    - The "RTNE" step adds 1 whenever the highest dropped bit is set. It ignores
      the lower dropped bits and the kept LSB, so exact ties round away from
      zero instead of to even.
    - A rounding carry out of the mantissa is merged into the result with ``|``
      (not ``+``). It therefore only reaches the exponent when the exponent's
      lowest bit is 0.
    - Exponents above the output range are clamped to the all-ones exponent
      while the mantissa is kept. For IEEE-style formats this encodes NaN rather
      than Inf unless the kept mantissa bits are zero.
    """
    input_uint_dtype = getattr(np, f"uint{input_dtype.primitive_bitwidth}")
    output_unint_dtype = getattr(np, f"uint{output_dtype.primitive_bitwidth}")
    # Reinterpret the input as raw unsigned bits (flattened; reshaped at the end).
    input_bin = np.frombuffer(input.tobytes(), dtype=input_uint_dtype)
    sign = (input_bin >> (input_dtype.primitive_bitwidth - 1)) & 0x01
    input_exponent_width = input_dtype.primitive_bitwidth - input_dtype.fp_mantissa_width - 1
    output_exponent_width = output_dtype.primitive_bitwidth - output_dtype.fp_mantissa_width - 1
    significand = input_bin & ((1 << input_dtype.fp_mantissa_width) - 1)
    bias_input = input_dtype.exponent_bias
    bias_output = output_dtype.exponent_bias
    # int32 so the exponent can go negative while subnormals are normalised.
    exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32)
    subnormal_index = exponent == 0
    if np.any(subnormal_index):
        # Credit to Phil: phil@openai.com
        # subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (2^(m0) + 2^(m1) + ... + 2^(mn))
        # where m0, m1, ..., mn are the 1-bit of the mantissa
        # convert it to normal repr: ((-1.0)**sign) * (2.0**(1 + m0 - exp_bias)) * (1 + 2^(m1 - m0) + ... + 2^(mn - m0))
        bit_pos = np.zeros_like(input_bin, dtype=np.int32)
        # Find the most significant bit of the mantissa in the significand
        for i in range(input_dtype.fp_mantissa_width):
            bit_index = ((significand >> i) & 0x01)
            # pos should be >= 1
            bit_pos[bit_index == 1] = input_dtype.fp_mantissa_width - i
        zero_significand_index = significand == 0
        exponent[subnormal_index] = 1 - bit_pos[subnormal_index]
        # 0 significand and subnormal should be treated as 0
        exponent[zero_significand_index & subnormal_index] = bias_input - bias_output
        # Shift the leading 1 out of the field (it becomes the implicit bit).
        significand[subnormal_index] = (significand[subnormal_index] << bit_pos[subnormal_index]) & (
            (1 << input_dtype.fp_mantissa_width) - 1)
    # Prevent overflow and underflow
    exponent_output = np.maximum(0, np.minimum((exponent - bias_input + bias_output), (1 << output_exponent_width) - 1))
    exponent_output = exponent_output.astype(output_unint_dtype)
    sign_output = sign.astype(output_unint_dtype)
    if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth:  # Downcast
        significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & (
            (1 << output_dtype.fp_mantissa_width) - 1)
        if rounding_mode == _ir.ROUNDING_MODE.RTNE:  # Round to nearest even (see NOTE above: ties round up)
            # find the cut-off bit
            cut_off = significand & (1 << (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width - 1))
            significand_output = significand_output + (cut_off > 0)
        significand_output = significand_output.astype(output_unint_dtype)
    else:  # Upcast
        significand_output = (significand.astype(output_unint_dtype) <<
                              (output_dtype.fp_mantissa_width - input_dtype.fp_mantissa_width)) & (
                                  (1 << output_dtype.fp_mantissa_width) - 1)
    subnormal_index = exponent_output == 0
    if np.any(subnormal_index):  # underflow
        # normal repr: ((-1.0)**sign) * (2.0**(exp - exp_bias_input)) * (1 + 2^(m0) + 2^(m1) + ... + 2^(mn))
        # where m0, m1, ..., mn are the 1-bit of the mantissa
        # shift = (1 - exp_bias_output) - (exp - exp_bias_input)
        # convert it to subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias_output)) * (2^(-shift) + 2^(m0 - shift) + 2^(m1 - shift) + ... + 2^(mn - shift))
        exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32)
        non_zero_exponent_index = exponent != 0
        # If the original exponent is not zero, we still need to shift the significand and consider the 1.0 part in mantissa
        subnormal_index = subnormal_index & non_zero_exponent_index
        shift = np.zeros_like(input_bin, dtype=np.int32)
        shift[subnormal_index] = (1 - bias_output) - (exponent[subnormal_index] - bias_input)
        # NOTE(review): if shift exceeds the output mantissa width, the shift
        # count below goes negative; confirm such inputs cannot reach here.
        significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | (
            1 << (output_dtype.fp_mantissa_width - shift[subnormal_index]))
    output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | (
        exponent_output << output_dtype.fp_mantissa_width) | significand_output
    return output.reshape(input.shape)
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def _erf(x):
|
|
240
|
+
# Numpy does not support erf
|
|
241
|
+
return math.erf(x)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _umulhi_64(a, b):
|
|
245
|
+
# Numpy does not support 128-bit multiplication
|
|
246
|
+
# So we have to implement it manually
|
|
247
|
+
return (int(a) * int(b)) >> 64
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
# Element-wise array versions of the scalar helpers _erf and _umulhi_64.
# np.vectorize calls the Python function once per element, and ``otypes`` fixes
# the result dtype.
np_erf_fp32 = np.vectorize(_erf, otypes=[np.float32])
np_erf_fp64 = np.vectorize(_erf, otypes=[np.float64])
np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64])
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class ExtraFunctions:
    """Extra codegen hooks installed into ``InterpreterBuilder.codegen_fns``."""

    @staticmethod
    def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _semantic):
        """Convert ``input`` to ``dst_ty`` through the builder's fp-to-fp conversion."""
        converted = _semantic.builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding)
        return tl.tensor(converted, dst_ty)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
class InterpreterBuilder:
|
|
263
|
+
# Translate IR memory-ordering enums into the interpreter extension's enums.
ir_sem_to_interpreter_sem = {
    _ir.MEM_SEMANTIC.ACQUIRE: _interpreter.MEM_SEMANTIC.ACQUIRE,
    _ir.MEM_SEMANTIC.RELEASE: _interpreter.MEM_SEMANTIC.RELEASE,
    _ir.MEM_SEMANTIC.RELAXED: _interpreter.MEM_SEMANTIC.RELAXED,
    _ir.MEM_SEMANTIC.ACQUIRE_RELEASE: _interpreter.MEM_SEMANTIC.ACQUIRE_RELEASE,
}

# Translate IR atomic read-modify-write opcodes into the interpreter extension's
# RMW opcodes (one-to-one by name).
ir_rmw_op_to_interpreter_rmw_op = {
    _ir.ATOMIC_OP.ADD: _interpreter.RMW_OP.ADD,
    _ir.ATOMIC_OP.FADD: _interpreter.RMW_OP.FADD,
    _ir.ATOMIC_OP.MIN: _interpreter.RMW_OP.MIN,
    _ir.ATOMIC_OP.UMIN: _interpreter.RMW_OP.UMIN,
    _ir.ATOMIC_OP.MAX: _interpreter.RMW_OP.MAX,
    _ir.ATOMIC_OP.UMAX: _interpreter.RMW_OP.UMAX,
    _ir.ATOMIC_OP.AND: _interpreter.RMW_OP.AND,
    _ir.ATOMIC_OP.OR: _interpreter.RMW_OP.OR,
    _ir.ATOMIC_OP.XOR: _interpreter.RMW_OP.XOR,
    _ir.ATOMIC_OP.XCHG: _interpreter.RMW_OP.XCHG,
}
|
|
282
|
+
|
|
283
|
+
def __init__(self) -> None:
    """Create a builder with default interpreter options and codegen hooks.

    Grid dimensions and the current grid index are not set here; they are
    assigned by ``set_grid_dim`` / ``set_grid_idx``.
    """
    self.arch = None
    self.options = InterpreterOptions()
    self.codegen_fns = {
        "convert_custom_types": ExtraFunctions._convert_custom_types,
        # Reports (1, 1, 1) as the minimum dot size for any operand types.
        "min_dot_size": lambda lhsType, rhsType: (1, 1, 1),
    }
|
|
289
|
+
|
|
290
|
+
def set_grid_idx(self, x, y, z):
    """Set the index of the program being interpreted.

    Each coordinate must be below the matching entry of ``self.grid_dim``
    (set earlier via ``set_grid_dim``).

    Raises:
        ValueError: naming the first coordinate (x, then y, then z) that is
            out of range.
    """
    bounds_checks = (
        (x, 0, "x >= grid_dim[0]"),
        (y, 1, "y >= grid_dim[1]"),
        (z, 2, "z >= grid_dim[2]"),
    )
    for coord, axis, message in bounds_checks:
        if not coord < self.grid_dim[axis]:
            raise ValueError(message)
    self.grid_idx = (x, y, z)

def set_grid_dim(self, nx, ny, nz):
    """Record the launch grid size as an ``(nx, ny, nz)`` tuple."""
    self.grid_dim = nx, ny, nz
|
|
301
|
+
|
|
302
|
+
# constants
|
|
303
|
+
|
|
304
|
+
# Type getters: the interpreter uses triton.language dtype objects directly as
# its IR types.
def get_half_ty(self):
    return tl.float16

def get_bf16_ty(self):
    return tl.bfloat16

def get_float_ty(self):
    return tl.float32

def get_double_ty(self):
    return tl.float64

def get_int1_ty(self):
    return tl.int1

def get_int8_ty(self):
    return tl.int8

def get_uint8_ty(self):
    return tl.uint8

def get_int16_ty(self):
    return tl.int16

def get_uint16_ty(self):
    return tl.uint16

def get_int32_ty(self):
    return tl.int32

def get_uint32_ty(self):
    return tl.uint32

def get_int64_ty(self):
    return tl.int64

def get_uint64_ty(self):
    return tl.uint64

def get_fp8e4nv_ty(self):
    return tl.float8e4nv

def get_fp8e4b15_ty(self):
    return tl.float8e4b15

def get_fp8e4b8_ty(self):
    return tl.float8e4b8

def get_fp8e5_ty(self):
    return tl.float8e5

def get_fp8e5b16_ty(self):
    return tl.float8e5b16

# Composite types are built with the corresponding triton.language constructors.
def get_ptr_ty(self, elt_ty, addr_space):
    return tl.pointer_type(elt_ty, addr_space)

def get_block_ty(self, dtype, shape):
    return tl.block_type(dtype, shape)
|
|
363
|
+
|
|
364
|
+
# Scalar constant builders: each wraps ``value`` in a one-element NumPy array of
# the matching NumPy dtype, paired with the corresponding Triton scalar type.
def get_int1(self, value):
    return TensorHandle(np.array([value], dtype=np.bool_), tl.int1)

def get_uint8(self, value):
    return TensorHandle(np.array([value], dtype=np.uint8), tl.uint8)

def get_int8(self, value):
    return TensorHandle(np.array([value], dtype=np.int8), tl.int8)

def get_uint16(self, value):
    return TensorHandle(np.array([value], dtype=np.uint16), tl.uint16)

def get_int16(self, value):
    return TensorHandle(np.array([value], dtype=np.int16), tl.int16)

def get_uint32(self, value):
    return TensorHandle(np.array([value], dtype=np.uint32), tl.uint32)

def get_int32(self, value):
    return TensorHandle(np.array([value], dtype=np.int32), tl.int32)

def get_uint64(self, value):
    return TensorHandle(np.array([value], dtype=np.uint64), tl.uint64)

def get_int64(self, value):
    return TensorHandle(np.array([value], dtype=np.int64), tl.int64)

def get_fp16(self, value):
    return TensorHandle(np.array([value], dtype=np.float16), tl.float16)

def get_fp32(self, value):
    return TensorHandle(np.array([value], dtype=np.float32), tl.float32)

def get_fp64(self, value):
    return TensorHandle(np.array([value], dtype=np.float64), tl.float64)

def get_null_value(self, type):
    # All-zero bits of ``type`` in its interpreter storage dtype.
    return TensorHandle(np.array([0], dtype=_get_np_dtype(type)), type)
|
|
402
|
+
|
|
403
|
+
# programming model
|
|
404
|
+
def create_get_program_id(self, axis):
    """Return this program's index along ``axis`` as an int32 scalar handle.

    Raises:
        ValueError: if no grid index has been set yet.
    """
    # grid_idx is only assigned by set_grid_idx; __init__ does not initialise
    # it. Read it defensively so a missing index raises the intended
    # ValueError instead of an AttributeError.
    grid_idx = getattr(self, "grid_idx", None)
    if grid_idx is None:
        raise ValueError("grid_idx is None")
    return TensorHandle(np.array([grid_idx[axis]], dtype=np.int32), tl.int32)
|
|
408
|
+
|
|
409
|
+
def create_get_num_programs(self, axis):
    """Return the grid size along ``axis`` as an int32 scalar handle."""
    num_programs = self.grid_dim[axis]
    as_array = np.array([num_programs], dtype=np.int32)
    return TensorHandle(as_array, tl.int32)
|
|
411
|
+
|
|
412
|
+
# memory ops
|
|
413
|
+
def create_load(self, ptr, _0, _1, is_volatile):
    """Unmasked load: every lane is enabled and no fallback value is supplied."""
    every_lane = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
    return self.create_masked_load(ptr, every_lane, None, _0, _1, is_volatile)

def create_store(self, ptr, val, _0, _1):
    """Unmasked store: every lane is enabled; cache hints are not forwarded."""
    every_lane = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
    return self.create_masked_store(ptr, val, every_lane, None, None)

def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile):
    """Load through ``ptrs`` where ``mask`` is set, using ``other`` elsewhere.

    When ``other`` is None, masked-off lanes read as zero. cache_modifier,
    eviction_policy and is_volatile are not used by this implementation.
    """
    elem_tt = ptrs.get_element_ty()
    elem_np = _get_np_dtype(elem_tt)
    fallback = np.zeros_like(ptrs.data, dtype=elem_np) if other is None else other.data
    loaded = _interpreter.load(ptrs.data, mask.data, fallback, elem_np)
    return TensorHandle(loaded, elem_tt)

def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy):
    """Store ``value`` through ``ptrs`` where ``mask`` is set (cache hints unused)."""
    return _interpreter.store(ptrs.data, value.data, mask.data)
|
|
432
|
+
|
|
433
|
+
# casting ops
|
|
434
|
+
def cast_impl(self, src, dst_type):
    """Cast ``src`` to ``dst_type``'s storage dtype, returning a scalar-typed handle.

    bfloat16 values are stored as raw uint16 bits, so a NumPy ``astype`` would
    treat them as integers. bfloat16 <-> float32 is therefore re-encoded
    bit-wise via _convert_float with rounding_mode=None, which truncates. All
    other casts use NumPy's ``astype``.
    """
    src_scalar = src.dtype.scalar
    dst_scalar = dst_type.scalar
    is_bf16_fp32_pair = ((src_scalar == tl.bfloat16 and dst_scalar == tl.float32)
                         or (src_scalar == tl.float32 and dst_scalar == tl.bfloat16))
    if not is_bf16_fp32_pair:
        return TensorHandle(src.data.astype(_get_np_dtype(dst_type)), dst_type.scalar)
    bits = _convert_float(src.data, src_scalar, dst_scalar, None)
    return TensorHandle(bits.view(_get_np_dtype(dst_type)), dst_type.scalar)

# All numeric conversions share cast_impl. create_int_cast ignores is_signed:
# NumPy's astype extends according to the source array's dtype.
create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type)
create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type)
create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type)
create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type)
create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type)
|
|
451
|
+
|
|
452
|
+
def create_fp_to_fp(self, src, dst_type, rounding_mode):
|
|
453
|
+
src_element_type = src.dtype.scalar
|
|
454
|
+
dst_element_type = dst_type.scalar
|
|
455
|
+
data = _convert_float(src.data, src_element_type, dst_element_type, rounding_mode).view(_get_np_dtype(dst_type))
|
|
456
|
+
return TensorHandle(data, dst_type.scalar)
|
|
457
|
+
|
|
458
|
+
def create_bitcast(self, src, dst_type):
    """Reinterpret the bits of ``src`` as ``dst_type`` (numpy ``view``, no value conversion)."""
    reinterpreted = src.data.view(_get_np_dtype(dst_type))
    return TensorHandle(reinterpreted, dst_type.scalar)
|
|
460
|
+
|
|
461
|
+
# binary operators
|
|
462
|
+
def binary_op(self, lhs, rhs, op):
    """Apply the numpy binary function ``op`` to both operands' data.

    The result is tagged with ``lhs``'s scalar type.
    """
    result = op(lhs.data, rhs.data)
    return TensorHandle(result, lhs.dtype.scalar)
|
|
464
|
+
|
|
465
|
+
create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod)
create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
create_sdiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs)
create_udiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs)
# LLVM integer remainders truncate (the sign follows the dividend), which matches
# np.fmod rather than np.remainder.
create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod)
create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod)
create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift)
# NOTE(review): np.right_shift is arithmetic for signed numpy dtypes; this presumably only
# sees unsigned data (create_ashr handles the signed case) -- confirm.
create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift)
create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
# `minimumf`/`maximumf` propagate NaN (np.minimum / np.maximum). `minnumf`/`maxnumf` follow
# LLVM minnum/maxnum, which return the other operand when exactly one input is NaN: that is
# np.fmin / np.fmax.
create_minimumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
create_minnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmin)
create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
create_maximumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
create_maxnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmax)
|
|
489
|
+
# Signed and unsigned integer predicates share one numpy op; presumably the numpy dtype of
# the data already encodes signedness -- confirm.
create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)


def fcmp_unordered_op(self, lhs, rhs, op):
    """Apply comparison ``op`` with LLVM *unordered* semantics.

    Unordered predicates are also true wherever either operand is NaN. numpy's
    ordered-style comparisons (np.less, np.equal, ...) return False there.
    """
    return self.binary_op(lhs, rhs, lambda a, b: op(a, b) | np.isnan(a) | np.isnan(b))


# Ordered predicates: false whenever either operand is NaN, which numpy's comparisons already do.
create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)


def create_fcmpONE(self, lhs, rhs):
    """Ordered not-equal: false where either operand is NaN.

    Plain np.not_equal would return True there.
    """
    return self.binary_op(lhs, rhs, lambda a, b: np.not_equal(a, b) & ~np.isnan(a) & ~np.isnan(b))


create_fcmpULT = lambda self, lhs, rhs: self.fcmp_unordered_op(lhs, rhs, np.less)
create_fcmpUGT = lambda self, lhs, rhs: self.fcmp_unordered_op(lhs, rhs, np.greater)
create_fcmpULE = lambda self, lhs, rhs: self.fcmp_unordered_op(lhs, rhs, np.less_equal)
create_fcmpUGE = lambda self, lhs, rhs: self.fcmp_unordered_op(lhs, rhs, np.greater_equal)
create_fcmpUEQ = lambda self, lhs, rhs: self.fcmp_unordered_op(lhs, rhs, np.equal)
# np.not_equal is already true when either operand is NaN, i.e. it is the unordered predicate.
create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and)
create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor)
create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or)
|
|
514
|
+
# Integer <-> pointer conversions are pure bit reinterpretations, so both reuse create_bitcast.
create_int_to_ptr = create_ptr_to_int = create_bitcast
|
|
516
|
+
|
|
517
|
+
def create_idiv(self, lhs, rhs):
    """Integer division that truncates toward zero.

    Triton's `%` has IEEE (truncating) rather than numpy/torch (flooring) semantics, and `//`
    must agree with it. numpy's `//` floors, so the truncated remainder (np.fmod) is removed
    first. What is left is an exact multiple of the divisor, for which `//` is exact.
    """
    dividend, divisor = lhs.data, rhs.data
    exact_multiple = dividend - np.fmod(dividend, divisor)
    return TensorHandle(exact_multiple // divisor, lhs.dtype.scalar)
|
|
522
|
+
|
|
523
|
+
def create_ashr(self, lhs, rhs):
    """Arithmetic (sign-propagating) right shift of ``lhs`` by ``rhs``.

    Triton's rshift operator depends on the signedness of the left operand. Here both
    operands' data are converted to the signed numpy dtype returned by
    ``_get_signed_np_dtype``, so ``np.right_shift`` shifts arithmetically. The input handles
    are left unmodified.
    """
    # Convert into locals rather than reassigning lhs.data / rhs.data: those handles belong
    # to the caller and may be used again after this op.
    lhs_data = lhs.data.astype(_get_signed_np_dtype(lhs.data.dtype))
    rhs_data = rhs.data.astype(_get_signed_np_dtype(rhs.data.dtype))
    return TensorHandle(np.right_shift(lhs_data, rhs_data), lhs.dtype.scalar)
|
|
530
|
+
|
|
531
|
+
def create_umulhi(self, lhs, rhs):
    """Return the high half of the *unsigned* double-width product of ``lhs`` and ``rhs``.

    64-bit operands are handled by ``np_umulhi_u64``. Narrower operands are reinterpreted
    as unsigned, widened to an unsigned integer of twice the width, multiplied, and shifted
    right by the original bit width. The result is cast back to the operands' numpy dtype.
    Assumes ``rhs`` has the same numpy dtype as ``lhs``.
    """
    dtype = lhs.data.dtype
    if dtype == np.int64 or dtype == np.uint64:
        return TensorHandle(np_umulhi_u64(lhs.data, rhs.data), lhs.dtype.scalar)
    bits = dtype.itemsize * 8
    same_width_unsigned = np.dtype(f"uint{bits}")
    double_width_unsigned = np.dtype(f"uint{bits * 2}")

    def _widen(data):
        # Reinterpret the bits as unsigned *before* widening. astype() on a negative signed
        # value sign-extends it, e.g. int32 -1 would become 2**64 - 1 instead of 2**32 - 1,
        # which corrupts the high half of the product.
        return data.view(same_width_unsigned).astype(double_width_unsigned)

    ret_data = np.multiply(_widen(lhs.data), _widen(rhs.data)) >> bits
    return TensorHandle(ret_data.astype(dtype), lhs.dtype.scalar)
|
|
541
|
+
|
|
542
|
+
# ternary functions
|
|
543
|
+
def ternary_op(self, lhs, rhs, other, op):
    """Apply the numpy ternary function ``op``.

    The result is tagged with the third operand's scalar type.
    """
    result = op(lhs.data, rhs.data, other.data)
    return TensorHandle(result, other.dtype.scalar)
|
|
545
|
+
|
|
546
|
+
def create_clampf(self, arg, lo, hi, propagate_nans):
    """Clamp ``arg`` into ``[lo, hi]`` with np.clip.

    NOTE(review): ``propagate_nans`` is ignored by the interpreter.
    """
    return self.ternary_op(arg, lo, hi, np.clip)


def create_select(self, cond, lhs, rhs):
    """Element-wise ``lhs if cond else rhs`` via np.where."""
    return self.ternary_op(cond, lhs, rhs, np.where)
|
|
548
|
+
|
|
549
|
+
def create_fma(self, x, y, z):
    """Compute ``x * y + z``.

    This is done as two separate numpy operations (multiply, then add), not as a
    single-rounding fused op. The result is tagged with ``z``'s scalar type.
    """
    product = x.data * y.data
    return TensorHandle(product + z.data, z.dtype.scalar)
|
|
551
|
+
|
|
552
|
+
# unary functions
|
|
553
|
+
def unary_op(self, arg, op):
    """Apply the numpy unary function ``op``, keeping ``arg``'s scalar type."""
    result = op(arg.data)
    return TensorHandle(result, arg.dtype.scalar)
|
|
555
|
+
|
|
556
|
+
def create_fabs(self, arg):
    """Floating-point absolute value, computed by clearing the sign (top) bit of each element.

    The data is viewed as an unsigned integer of the type's primitive bit width, masked,
    and viewed back as the element's numpy dtype.
    """
    bit_width = arg.dtype.primitive_bitwidth
    as_unsigned = arg.data.view(getattr(np, f"uint{bit_width}"))
    magnitude_mask = (1 << (bit_width - 1)) - 1
    cleared = (as_unsigned & magnitude_mask).view(_get_np_dtype(arg.dtype))
    return TensorHandle(cleared, arg.dtype.scalar)
|
|
565
|
+
|
|
566
|
+
def create_cos(self, arg):
    return self.unary_op(arg, np.cos)


def create_exp(self, arg):
    return self.unary_op(arg, np.exp)


def create_exp2(self, arg):
    return self.unary_op(arg, np.exp2)


def create_iabs(self, arg):
    return self.unary_op(arg, np.abs)


def create_floor(self, arg):
    return self.unary_op(arg, np.floor)


def create_ceil(self, arg):
    return self.unary_op(arg, np.ceil)


def create_log(self, arg):
    return self.unary_op(arg, np.log)


def create_log2(self, arg):
    return self.unary_op(arg, np.log2)


# The interpreter does not distinguish precise from approximate sqrt; both use np.sqrt.
def create_precise_sqrt(self, arg):
    return self.unary_op(arg, np.sqrt)


def create_sqrt(self, arg):
    return self.unary_op(arg, np.sqrt)


def create_sin(self, arg):
    return self.unary_op(arg, np.sin)
|
|
577
|
+
|
|
578
|
+
def create_erf(self, arg):
    """Error function.

    float32 data uses the fp32 kernel; any other dtype uses the fp64 one.
    """
    erf_impl = np_erf_fp32 if arg.data.dtype == np.float32 else np_erf_fp64
    return TensorHandle(erf_impl(arg.data), arg.dtype.scalar)
|
|
581
|
+
|
|
582
|
+
def create_rsqrt(self, arg):
    """Reciprocal square root, ``1 / sqrt(x)``."""
    root = np.sqrt(arg.data)
    return TensorHandle(1 / root, arg.dtype.scalar)
|
|
584
|
+
|
|
585
|
+
# tensor operators
|
|
586
|
+
def create_reshape(self, arg, shape, allow_reorder):
    """Reshape ``arg`` to ``shape``; ``allow_reorder`` is ignored by the interpreter."""
    reshaped = arg.data.reshape(shape)
    return TensorHandle(reshaped, arg.dtype.scalar)
|
|
587
|
+
|
|
588
|
+
def create_trans(self, arg, perm):
    """Permute the axes of ``arg`` according to ``perm``."""
    permuted = np.transpose(arg.data, perm)
    return TensorHandle(permuted, arg.dtype.scalar)
|
|
590
|
+
|
|
591
|
+
def create_dot(self, a, b, d, input_precision, max_num_imprecise_acc):
    """Compute ``a @ b + d``, with the matmul computed in ``d``'s numpy dtype.

    If either operand is an 8-bit float, both operands are first converted to float16.
    ``input_precision`` and ``max_num_imprecise_acc`` are ignored.
    """

    def _is_fp8(t):
        return t.dtype.primitive_bitwidth == 8 and t.dtype.is_floating()

    lhs, rhs = a.data, b.data
    if _is_fp8(a) or _is_fp8(b):
        lhs = _convert_float(lhs, a.dtype, tl.float16, None).view(np.float16)
        rhs = _convert_float(rhs, b.dtype, tl.float16, None).view(np.float16)
    product = np.matmul(lhs, rhs, dtype=d.data.dtype)
    return TensorHandle(product + d.data, d.dtype.scalar)
|
|
599
|
+
|
|
600
|
+
def create_make_range(self, ret_ty, start, stop):
    """Return the int32 values ``[start, stop)``; ``ret_ty`` is not consulted."""
    values = np.arange(start, stop, dtype=np.int32)
    return TensorHandle(values, tl.int32)
|
|
602
|
+
|
|
603
|
+
def create_histogram(self, data, bins, mask):
    """Histogram of ``data`` into ``bins`` unit-width bins over the range ``(0, bins)``.

    Masked-off elements are first replaced by 0 so np.histogram can run on the whole
    array. They are then subtracted back out of bin 0. A ``None`` mask keeps every element.
    """
    if mask is None:
        mask = TensorHandle(np.ones_like(data.data, dtype=bool), tl.int1)
    values = np.where(mask.data, data.data, np.zeros_like(data.data))
    counts, _edges = np.histogram(values, bins=bins, range=(0, bins))
    counts[0] -= np.logical_not(mask.data).sum()
    return TensorHandle(counts, tl.int32)
|
|
612
|
+
|
|
613
|
+
def create_gather(self, src, indices, axis):
    """Gather elements of ``src`` along ``axis`` at ``indices`` (np.take_along_axis)."""
    gathered = np.take_along_axis(src.data, indices.data, axis=axis)
    return TensorHandle(gathered, src.dtype.scalar)
|
|
615
|
+
|
|
616
|
+
# pointer arithmetic
|
|
617
|
+
|
|
618
|
+
def create_addptr(self, ptr, offset):
    """Advance pointer(s) by ``offset`` elements of the pointee type, as byte address arithmetic."""
    elem_bits = ptr.get_element_ty().primitive_bitwidth
    # int1 reports a bit width of 1, but pointer arithmetic still needs 1 byte per element.
    elem_bytes = max(1, elem_bits // 8)
    # NOTE(review): negative offsets presumably rely on astype(np.uint64) wrapping modulo 2**64,
    # so the uint64 sum still lands on the intended address -- confirm.
    byte_offset = elem_bytes * offset.data.astype(np.uint64)
    return TensorHandle(ptr.data + byte_offset, ptr.dtype)
|
|
624
|
+
|
|
625
|
+
def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy,
                               is_volatile):
    """Load through a block pointer, filling masked-off lanes according to ``padding_option``.

    ``None`` passes no fill value; PAD_ZERO fills with 0; PAD_NAN fills with NaN.

    Raises:
        ValueError: for any other padding option.
    """
    ptrs, masks = ptr.materialize_pointers(boundary_check)
    elem_ty = ptrs.get_element_ty()
    elem_np = _get_np_dtype(elem_ty)
    if padding_option is None:
        fill = None
    elif padding_option == _ir.PADDING_OPTION.PAD_ZERO:
        fill = TensorHandle(np.zeros_like(ptrs.data, dtype=elem_np), elem_ty)
    elif padding_option == _ir.PADDING_OPTION.PAD_NAN:
        fill = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=elem_np), elem_ty)
    else:
        raise ValueError(f"unsupported padding option {padding_option}")
    return self.create_masked_load(ptrs, masks, fill, cache_modifier, eviction_policy, is_volatile)
|
|
639
|
+
|
|
640
|
+
def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy):
    """Store ``value`` through a block pointer, using the mask from ``materialize_pointers``."""
    target_ptrs, lane_mask = ptr.materialize_pointers(boundary_check)
    return self.create_masked_store(target_ptrs, value, lane_mask, cache_modifier, eviction_policy)
|
|
643
|
+
|
|
644
|
+
def create_expand_dims(self, arg, axis):
    """Insert a size-1 axis at position ``axis``."""
    expanded = np.expand_dims(arg.data, axis)
    return TensorHandle(expanded, arg.dtype.scalar)
|
|
646
|
+
|
|
647
|
+
def create_broadcast(self, arg, shape):
    """Broadcast ``arg`` to ``shape`` (np.broadcast_to, which returns a read-only view)."""
    broadcast = np.broadcast_to(arg.data, shape)
    return TensorHandle(broadcast, arg.dtype.scalar)
|
|
649
|
+
|
|
650
|
+
def create_cat(self, lhs, rhs):
    """Concatenate along the first axis; the result keeps ``lhs``'s scalar type."""
    joined = np.concatenate((lhs.data, rhs.data))
    return TensorHandle(joined, lhs.dtype.scalar)
|
|
652
|
+
|
|
653
|
+
def create_join(self, lhs, rhs):
    """Join two tensors along a new trailing axis of size 2.

    Triton only supports joining along the last axis.
    """
    stacked = np.stack([lhs.data, rhs.data], axis=-1)
    return TensorHandle(stacked, lhs.dtype.scalar)
|
|
656
|
+
|
|
657
|
+
def create_split(self, val):
    """Split along the last axis into two handles, holding entries 0 and 1 of that axis.

    Triton only supports splitting into two along the last axis.
    """
    halves = (val.data[..., 0], val.data[..., 1])
    return tuple(TensorHandle(half, val.dtype.scalar) for half in halves)
|
|
660
|
+
|
|
661
|
+
def create_splat(self, ret_ty, arg):
    """Fill an array of shape ``ret_ty.shape`` with a single value.

    The value is the scalar itself, or the first element when ``arg`` is a block.
    """
    fill_value = arg.data[0] if isinstance(arg.dtype, tl.block_type) else arg.data
    filled = np.full(ret_ty.shape, fill_value, dtype=_get_np_dtype(arg.dtype))
    return TensorHandle(filled, arg.dtype.scalar)
|
|
667
|
+
|
|
668
|
+
def create_unsplat(self, arg):
    """Collapse ``arg`` to a one-element array holding its first element."""
    first = arg.data[0]
    return TensorHandle(np.full((1, ), first, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
|
|
670
|
+
|
|
671
|
+
def create_atomic_cas(self, ptr, cmp, val, sem, scope):
    """Atomic compare-and-swap through the native interpreter; ``scope`` is ignored.

    Raises:
        ValueError: if ``sem`` has no entry in ``ir_sem_to_interpreter_sem``.
    """
    sem_map = self.ir_sem_to_interpreter_sem
    if sem not in sem_map:
        raise ValueError(f"unsupported semantic {sem}")
    result = _interpreter.atomic_cas(ptr.data, cmp.data, val.data, sem_map[sem])
    return TensorHandle(result, cmp.dtype.scalar)
|
|
676
|
+
|
|
677
|
+
def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem, scope):
    """Atomic read-modify-write through the native interpreter; ``scope`` is ignored.

    Raises:
        ValueError: if ``rmwOp`` (checked first) or ``sem`` has no interpreter equivalent.
    """
    op_map = self.ir_rmw_op_to_interpreter_rmw_op
    sem_map = self.ir_sem_to_interpreter_sem
    if rmwOp not in op_map:
        raise ValueError(f"unsupported rmwOp {rmwOp}")
    if sem not in sem_map:
        raise ValueError(f"unsupported semantic {sem}")
    result = _interpreter.atomic_rmw(op_map[rmwOp], ptr.data, val.data, mask.data, sem_map[sem])
    return TensorHandle(result, val.dtype.scalar)
|
|
685
|
+
|
|
686
|
+
def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure):
    """Always raises: external elementwise library calls cannot run in the interpreter."""
    message = "extern_elementwise not supported in interpreter mode"
    raise NotImplementedError(message)
|
|
688
|
+
|
|
689
|
+
def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack):
    """Always raises: inline assembly cannot run in the interpreter."""
    message = "inline_asm not supported in interpreter mode"
    raise NotImplementedError(message)
|
|
691
|
+
|
|
692
|
+
def create_print(self, prefix, hex, values, isSigned):
    """Print each value's data on its own line, prefixed with this program's grid index.

    Args:
        prefix: Optional text printed after the grid index.
        hex: If true, every element is formatted as ``0x..`` hexadecimal.
        values: Handles whose ``.data`` is printed, one line each.
        isSigned: Unused here. Signedness is already known from the numpy data itself; the
            compiled backend (PrintOpToLLVM) uses it only to build the format specifier.
    """
    import contextlib

    # The interpreter's device_print format differs from Triton's compiled device_print.
    msg = f"({self.grid_idx[0]}, {self.grid_idx[1]}, {self.grid_idx[2]})"
    if prefix:
        msg += f" {prefix}"
    # np.printoptions restores the previous global print options on exit, even if formatting
    # raises. A manual set/reset pair would leave the hex formatter installed after an error.
    hex_options = np.printoptions(formatter={'all': lambda x: f"0x{x:02x}"}) if hex else contextlib.nullcontext()
    with hex_options:
        for value in values:
            print(msg + f" {value.data}")
|
|
706
|
+
|
|
707
|
+
def create_assert(self, condition, message):
    """Raise ``AssertionError`` with ``message`` when ``condition`` is falsy.

    The interpreter's device_assert message format differs from Triton's compiled device_assert.
    """
    assert condition, "{}".format(message)
|
|
710
|
+
|
|
711
|
+
def create_assume(self, condition):
    """Check an assumption: raise ``AssertionError("Assume failed")`` if ``condition`` is falsy."""
    failure_message = "Assume failed"
    assert condition, failure_message
|
|
713
|
+
|
|
714
|
+
def create_barrier(self):
    """No-op: Triton's barrier applies to each program in a grid, so there is nothing to do here."""
    return None
|
|
717
|
+
|
|
718
|
+
def create_make_block_ptr(self, base, shape, strides, offsets, block_shape, order):
    """Create a block-pointer handle over cloned copies of ``offsets``.

    The offsets are cloned so the caller's originals are not modified.
    """
    own_offsets = [off.clone() for off in offsets]
    return BlockPointerHandle(base, shape, strides, own_offsets, block_shape, order)
|
|
722
|
+
|
|
723
|
+
def create_advance(self, ptr, offsets):
    """Return a new block pointer whose offsets are ``ptr``'s offsets plus ``offsets``.

    ``ptr``'s offsets are cloned before being advanced, so the original handle is not
    modified (assuming ``clone`` copies the data).

    Raises:
        ValueError: if the number of offsets differs from ``ptr``'s.
    """
    if len(ptr.offsets) != len(offsets):
        raise ValueError("len(ptr.offsets) != len(offsets)")
    advanced = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, [off.clone() for off in ptr.offsets],
                                  ptr.block_shape, ptr.order)
    for current, delta in zip(advanced.offsets, offsets):
        current.data += delta.data
    return advanced
|
|
732
|
+
|
|
733
|
+
def create_make_tensor_descriptor(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
                                  tensor_shape: List[int], is_signed: bool, padding: str = "zero"):
    """Create a tensor descriptor and validate it before returning it.

    ``is_signed`` is not used by the interpreter.
    """
    descriptor = TensorDescHandle(base, shape, strides, tensor_shape, padding)
    descriptor.validate()
    return descriptor
|
|
738
|
+
|
|
739
|
+
def create_descriptor_load(self, desc: TensorDescHandle, indices: List[TensorHandle], cache_modifier,
                           eviction_policy):
    """Load a tile through a tensor descriptor, filling masked-off lanes per ``desc.padding``.

    Raises:
        ValueError: if the padding is neither PAD_ZERO nor PAD_NAN.
    """
    assert isinstance(desc, TensorDescHandle)
    ptrs, mask = desc.materialize_pointers(indices)
    elem_ty = ptrs.get_element_ty()
    elem_np = _get_np_dtype(elem_ty)
    padding = desc.padding
    if padding == _ir.PADDING_OPTION.PAD_ZERO:
        fill_data = np.zeros_like(ptrs.data, dtype=elem_np)
    elif padding == _ir.PADDING_OPTION.PAD_NAN:
        fill_data = np.full_like(ptrs.data, float('nan'), dtype=elem_np)
    else:
        raise ValueError(f"unsupported padding {padding}")
    return self.create_masked_load(ptrs, mask, TensorHandle(fill_data, elem_ty), cache_modifier=cache_modifier,
                                   eviction_policy=eviction_policy, is_volatile=False)
|
|
754
|
+
|
|
755
|
+
def create_descriptor_store(self, desc: TensorDescHandle, value: TensorHandle, indices: List[TensorHandle]):
    """Store ``value`` through a tensor descriptor, using the mask from ``materialize_pointers``."""
    target_ptrs, lane_mask = desc.materialize_pointers(indices)
    return self.create_masked_store(target_ptrs, value, lane_mask, None, None)
|
|
758
|
+
|
|
759
|
+
def create_descriptor_gather(self, desc: TensorDescHandle, x_offsets: TensorHandle, y_offset: TensorHandle, type):
    """Gather one row per entry of ``x_offsets``, each loaded at column offset ``y_offset``.

    The result has shape ``(len(x_offsets), desc.block_shape[-1])``.
    """
    elem_ty = desc.base.dtype.element_ty
    rows = np.zeros([x_offsets.data.shape[0], desc.block_shape[-1]], dtype=_get_np_dtype(elem_ty))
    for row_idx, x_offset in enumerate(x_offsets.data):
        row_indices = [TensorHandle(x_offset, tl.int32), y_offset]
        rows[row_idx, :] = self.create_descriptor_load(desc, row_indices, None, None).data
    return TensorHandle(rows, elem_ty)
|
|
769
|
+
|
|
770
|
+
def create_descriptor_scatter(self, desc: TensorDescHandle, value: TensorHandle, x_offsets: TensorHandle,
                              y_offset: TensorHandle):
    """Write row ``i`` of ``value`` at row offset ``x_offsets[i]`` and column offset ``y_offset``."""
    for row_idx, x_offset in enumerate(x_offsets.data):
        row_value = TensorHandle(value.data[row_idx], value.dtype)
        self.create_descriptor_store(desc, row_value, [TensorHandle(x_offset, tl.int32), y_offset])
|
|
776
|
+
|
|
777
|
+
def get_all_ones_value(self, type):
    """Return a one-element handle with every bit set (-1 / max unsigned / True).

    Raises:
        TypeError: if ``type`` maps to a numpy dtype that is neither integer nor bool.
    """
    np_type = _get_np_dtype(type)
    if "int" in np_type.name or np_type == np.bool_:
        # Bitwise-inverting zero yields all ones for signed and unsigned integers alike, and
        # True for bool. This avoids converting the out-of-range Python int -1 into an unsigned
        # dtype, whose behaviour depends on the NumPy version's scalar-conversion rules.
        return TensorHandle(np.invert(np.zeros(1, dtype=np_type)), type.scalar)
    raise TypeError(f"unsupported type {type}")
|
|
785
|
+
|
|
786
|
+
|
|
787
|
+
def _patch_attr(obj, name, member, builder):
    """Replace ``obj.<name>`` with a wrapper that always calls ``member`` with a fixed semantic.

    The semantic is a ``TritonSemantic`` bound to ``builder``. Any ``_semantic`` keyword the
    caller passes is discarded in its favour.
    """
    semantic = TritonSemantic(builder)

    def patched(*args, member=member, **kwargs):
        forwarded = {key: value for key, value in kwargs.items() if key != "_semantic"}
        return member(*args, **forwarded, _semantic=semantic)

    setattr(obj, name, patched)
|
|
794
|
+
|
|
795
|
+
|
|
796
|
+
def _patch_builtin(pkg, builder):
    """Patch every Triton builtin found on ``pkg`` so that it runs against ``builder``."""
    builtin_members = [(name, member) for name, member in inspect.getmembers(pkg) if tl.core.is_builtin(member)]
    for name, member in builtin_members:
        _patch_attr(pkg, name, member, builder)
|
|
800
|
+
|
|
801
|
+
|
|
802
|
+
def _patch_lang_tensor(tensor):
|
|
803
|
+
|
|
804
|
+
def _get_bool(self):
|
|
805
|
+
data = self.handle.data
|
|
806
|
+
# in triton, only scalars can be converted to booleans
|
|
807
|
+
# here we need this hack because all scalars are tensors
|
|
808
|
+
return bool(data) if data.size == 1 else True
|
|
809
|
+
|
|
810
|
+
def _get_transpose(self):
|
|
811
|
+
handle = TensorHandle(np.transpose(self.handle.data), self.handle.dtype)
|
|
812
|
+
assert self.type.is_block()
|
|
813
|
+
block_shape = list(self.type.shape)
|
|
814
|
+
block_shape[-1], block_shape[-2] = block_shape[-2], block_shape[-1]
|
|
815
|
+
res_ty = tl.core.block_type(self.dtype, block_shape)
|
|
816
|
+
return tl.core.tensor(handle, res_ty)
|
|
817
|
+
|
|
818
|
+
tensor.__index__ = lambda self: int(self.handle.data)
|
|
819
|
+
tensor.__bool__ = lambda self: _get_bool(self)
|
|
820
|
+
tensor.__repr__ = lambda self: repr(self.handle.data)
|
|
821
|
+
tensor.__str__ = lambda self: str(self.handle.data)
|
|
822
|
+
tensor.T = property(_get_transpose)
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
class ReduceScanOpInterface:
    """Shared plumbing for the interpreter's replacements of ``tl.reduce`` / ``tl.associative_scan``.

    Subclasses implement ``apply_impl(input)``. It receives a tuple of tensors and returns a
    tensor or a list/tuple of tensors.
    """

    def __init__(self, axis, combine_fn):
        self.axis = axis
        self.combine_fn = combine_fn

    def check_axis(self, shape, axis):
        """Raise ``ValueError`` unless ``axis`` is None or a valid, possibly negative, axis of ``shape``."""
        if axis is not None and not -len(shape) <= axis < len(shape):
            raise ValueError(f"axis {axis} out of bounds for shape {shape}")

    def check_tensor(self, input):
        """Check that every operand is a ``tl.core.tensor`` and that ``self.axis`` is valid for it."""
        for arg in input:
            if not isinstance(arg, tl.core.tensor):
                raise ValueError(f"input must be a tensor, got {type(arg)}")
            self.check_axis(arg.shape, self.axis)

    def to_tensor(self, ret, dtype):
        """Wrap a numpy result (array or scalar) as a ``tl.core.tensor`` with element type ``dtype``.

        Results with a non-empty shape become block tensors. Anything else is stored as a
        one-element array with the scalar type ``dtype``.
        """
        np_dtype = _get_np_dtype(dtype)
        if hasattr(ret, "shape") and ret.shape:
            ret = ret.astype(np_dtype)
            ret_type = tl.block_type(dtype, list(ret.shape))
        else:
            ret = np.array([ret], dtype=np_dtype)
            ret_type = dtype
        return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type)

    def apply(self, input):
        """Run the operation on a single tensor or on a tuple of tensors.

        A single tensor yields a single result; a tuple yields a tuple of results.
        """
        if not isinstance(input, tuple):
            return self.apply((input, ))[0]
        self.check_tensor(input)
        # Normalize a negative axis once, after validation. Subclasses use the axis to slice
        # shapes and to compare index positions, and neither works with negative values.
        if self.axis is not None and input:
            self.axis %= len(input[0].shape)
        ret = self.apply_impl(input)
        return tuple(ret) if isinstance(ret, (list, tuple)) else (ret, )
|
|
857
|
+
|
|
858
|
+
|
|
859
|
+
class ReduceOps(ReduceScanOpInterface):
    """Interpreter implementation of ``tl.reduce``.

    Known combine functions (tie-break-left argmin/argmax, elementwise max/min, sum) map to
    vectorized numpy reductions. Any other ``combine_fn`` is evaluated element by element in
    ``generic_reduce``.
    """

    def __init__(self, axis, combine_fn, keep_dims):
        super().__init__(axis, combine_fn)
        self.keep_dims = keep_dims

    def unravel(self, input, axis):
        """Prepare the operands for a reduction over ``axis`` and return ``(operands, axis)``.

        With ``axis=None`` the reduction covers every element, so *all* operands are flattened
        to 1-D and the reduction axis becomes 0.
        """
        if axis is not None:
            return tuple(input), axis
        # Flatten every operand, not just the first, so multi-operand reductions stay
        # element-aligned.
        return tuple(self.to_tensor(data.handle.data.flatten(), data.dtype) for data in input), 0

    def generic_reduce(self, input):
        """Slow path: fold ``combine_fn`` over ``axis`` one element at a time."""
        original_axis = self.axis
        # Rank before any flattening. keep_dims with axis=None must restore this many dimensions.
        original_ndim = len(input[0].handle.data.shape)
        input, axis = self.unravel(input, self.axis)
        input_data = []
        output_data = []
        input_shape = input[0].handle.data.shape
        # Slicing shapes and index tuples below requires a non-negative axis.
        if input_shape:
            axis %= len(input_shape)
        output_shape = input_shape[0:axis] + input_shape[axis + 1:]
        for arg in input:
            input_data.append(arg.handle.data)
            output_data.append(np.zeros(output_shape, dtype=arg.handle.data.dtype))
        # Reduce on axis. Elements are visited in C order, so for every output slot the element
        # at position 0 along `axis` is seen before positions 1, 2, ...
        for i in range(input_data[0].size):
            # Recover input_index from i using input_shape
            input_index = np.unravel_index(i, input_shape)
            output_index = input_index[0:axis] + input_index[axis + 1:]
            input_tuple = tuple(self.to_tensor(d[input_index], input[ii].dtype) for ii, d in enumerate(input_data))
            if input_index[axis] == 0:
                # First element seeds the accumulator
                for j in range(len(output_data)):
                    output_data[j][output_index] = input_tuple[j].handle.data.item()
            else:
                acc_tuple = tuple(self.to_tensor(o[output_index], input[oi].dtype) for oi, o in enumerate(output_data))
                combine_fn_ret = self.combine_fn.fn(*acc_tuple, *input_tuple)
                acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret
                for j in range(len(output_data)):
                    output_data[j][output_index] = acc_tuple[j].handle.data.item() if isinstance(
                        acc_tuple[j], tl.core.tensor) else acc_tuple[j]
        # Pack output
        ret = []
        for i, data in enumerate(output_data):
            if self.keep_dims:
                if original_axis is not None:
                    data = np.expand_dims(data, axis)
                else:
                    for _ in range(original_ndim):
                        data = np.expand_dims(data, 0)
            elif original_axis is None:
                # Take a scalar
                data = data.item()
            ret.append(self.to_tensor(data, input[i].dtype))
        return ret

    def min_max(self, input, val_reduce_op, idx_reduce_op=None):
        """Vectorized value and/or index reduction.

        Returns ``(val, idx)``, ``val`` or ``idx``, depending on which ops are given.
        """
        # If input is a tuple, it must be (val, index), and we only take val
        input = input[0] if isinstance(input, tuple) else input
        val = None
        idx = None
        if val_reduce_op:
            val = self.to_tensor(val_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype)
        if idx_reduce_op:
            idx = self.to_tensor(idx_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), tl.int32)
        if val is not None and idx is not None:
            return val, idx
        elif val is not None:
            return val
        elif idx is not None:
            return idx
        else:
            raise ValueError("val_reduce_op and idx_reduce_op are both None")

    def sum(self, input):
        """Vectorized sum over ``self.axis``."""
        return self.to_tensor(np.sum(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype)

    def apply_impl(self, input):
        if self.combine_fn == tl.standard._argmin_combine_tie_break_left:
            return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=np.argmin)
        elif self.combine_fn == tl.standard._argmax_combine_tie_break_left:
            return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax)
        elif self.combine_fn == tl.standard._elementwise_max:
            return self.min_max(input[0], val_reduce_op=np.nanmax, idx_reduce_op=None)
        elif self.combine_fn == tl.standard._elementwise_min:
            return self.min_max(input[0], val_reduce_op=np.nanmin, idx_reduce_op=None)
        elif self.combine_fn == tl.standard._sum_combine:
            return self.sum(input[0])
        else:
            # Fall back to the slow mode
            return self.generic_reduce(input)
|
|
953
|
+
|
|
954
|
+
|
|
955
|
+
class ScanOps(ReduceScanOpInterface):
    """Interpreter implementation of ``tl.associative_scan``.

    Sum and product combine functions map to np.cumsum / np.cumprod. Any other
    ``combine_fn`` is evaluated element by element in ``generic_scan``. With ``reverse`` the
    operands are flipped along ``axis`` before the scan and the results are flipped back.
    """

    def __init__(self, axis, combine_fn, reverse):
        super().__init__(axis, combine_fn)
        self.reverse = reverse

    def cumsum(self, input):
        # Fast path for _sum_combine; returns a one-element list to match generic_scan's shape.
        return [self.to_tensor(np.cumsum(input.handle.data, axis=self.axis), dtype=input.dtype)]

    def cumprod(self, input):
        # Fast path for _prod_combine; returns a one-element list to match generic_scan's shape.
        return [self.to_tensor(np.cumprod(input.handle.data, axis=self.axis), dtype=input.dtype)]

    def generic_scan(self, input):
        """Slow path: inclusive scan of ``combine_fn`` along ``self.axis``, one element at a time.

        NOTE(review): index positions are compared with ``self.axis`` directly, so this assumes
        a non-negative integer axis -- confirm callers never pass a negative one.
        """
        input_data = []
        output_data = []
        shape = input[0].handle.data.shape
        for arg in input:
            input_data.append(arg.handle.data)
            output_data.append(np.zeros(shape, dtype=arg.handle.data.dtype))
        # Scan on axis. Elements are visited in C order, so the predecessor along `axis`
        # (prev_index) has always been written before the current element is processed.
        for i in range(input_data[0].size):
            # Recover index from i using shape
            index = np.unravel_index(i, shape)
            data = tuple(self.to_tensor(d[index], input[ii].dtype) for ii, d in enumerate(input_data))
            if index[self.axis] == 0:
                # First element
                for j in range(len(output_data)):
                    output_data[j][index] = data[j].handle.data.item()
            else:
                # The generator's `i` is local to the generator and does not clobber the loop's `i`.
                prev_index = tuple(index[i] - 1 if i == self.axis else index[i] for i in range(len(index)))
                acc_tuple = tuple(self.to_tensor(o[prev_index], input[oi].dtype) for oi, o in enumerate(output_data))
                combine_fn_ret = self.combine_fn.fn(*acc_tuple, *data)
                acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret
                for j in range(len(output_data)):
                    output_data[j][index] = acc_tuple[j].handle.data.item() if isinstance(
                        acc_tuple[j], tl.core.tensor) else acc_tuple[j]
        # Pack output
        ret = []
        for i, data in enumerate(output_data):
            ret.append(self.to_tensor(data, input[i].dtype))
        return ret

    def apply_impl(self, input):
        # For a reverse scan, flip every operand along the axis, scan forward, then flip back.
        new_input = []
        if self.reverse:
            for arg in input:
                new_input.append(self.to_tensor(np.flip(arg.handle.data, axis=self.axis), arg.dtype))
        else:
            new_input = input
        if self.combine_fn == tl.standard._sum_combine:
            ret = self.cumsum(new_input[0])
        elif self.combine_fn == tl.standard._prod_combine:
            ret = self.cumprod(new_input[0])
        else:
            # Fall back to the slow mode
            ret = self.generic_scan(new_input)
        if self.reverse:
            # Flip the results back in place on their handles.
            for arg in ret:
                arg.handle.data = np.flip(arg.handle.data, axis=self.axis)
        return ret
|
|
1015
|
+
|
|
1016
|
+
|
|
1017
|
+
def _patch_reduce_scan():
    """Install interpreter versions of ``reduce`` and ``associative_scan`` on ``tl`` and ``tl.core``.

    The interpreter does not support ``region_builder_fn``, so the builder cannot be patched
    to emit reduce/scan regions. The language-level entry points are replaced instead.
    """

    def _new_reduce(input, axis, combine_fn, keep_dims=False, **kwargs):
        return ReduceOps(axis, combine_fn, keep_dims).apply(input)

    def _new_scan(input, axis, combine_fn, reverse=False, **kwargs):
        return ScanOps(axis, combine_fn, reverse).apply(input)

    for module in (tl, tl.core):
        module.reduce = _new_reduce
        module.associative_scan = _new_scan
|
|
1031
|
+
|
|
1032
|
+
|
|
1033
|
+
def _patch_lang_core(lang):
    """Replace compile-time helpers in ``lang`` with interpreter implementations.

    ``lang`` is a triton language module (``_patch_lang`` passes ``tl`` or
    ``tl.core``). This sets ``range``, ``static_range``, ``static_assert``,
    ``static_print``, ``dtype.to_ir`` and the ``multiple_of`` / ``max_contiguous`` /
    ``max_constancy`` hints, then re-applies the reduce/scan patch.
    """

    # dtype name -> name of the builder method returning the matching IR type.
    # Integer types carry explicit signedness because the numpy mode needs it.
    ir_type_getters = {
        'void': 'get_void_ty',
        'int1': 'get_int1_ty',
        'int8': 'get_int8_ty',
        'uint8': 'get_uint8_ty',
        'int16': 'get_int16_ty',
        'uint16': 'get_uint16_ty',
        'int32': 'get_int32_ty',
        'uint32': 'get_uint32_ty',
        'int64': 'get_int64_ty',
        'uint64': 'get_uint64_ty',
        'fp8e5': 'get_fp8e5_ty',
        'fp8e4nv': 'get_fp8e4nv_ty',
        'fp8e4b15': 'get_fp8e4b15_ty',
        'fp16': 'get_half_ty',
        'bf16': 'get_bf16_ty',
        'fp32': 'get_float_ty',
        'fp64': 'get_double_ty',
    }

    def _new_to_ir(self, builder):
        """Return the IR type for this dtype, built with ``builder``.

        Raises:
            ValueError: if the dtype name has no interpreter mapping.
        """
        getter_name = ir_type_getters.get(self.name)
        if getter_name is None:
            raise ValueError(f'fail to convert {self} to ir type')
        return getattr(builder, getter_name)()

    # can't just map lang.static_range to `range`, because `tl.static_range`
    # can get `step` passed by keyword
    def _new_range(arg1, arg2=None, step=None, **kwargs):
        """Python ``range`` with triton's call shapes; extra keywords are ignored."""
        if step is None:
            step = 1
        start, end = (0, arg1) if arg2 is None else (arg1, arg2)
        return range(start, end, step)

    def _new_static_assert(cond, msg=""):
        """Raise ``AssertionError(msg)`` when ``cond`` is falsy.

        An explicit raise is used instead of an ``assert`` statement so the check
        still fires when Python runs with ``-O``, which strips assert statements.
        """
        if not cond:
            raise AssertionError(msg)

    def _set_attr(input, values, name):
        """Record a hint (``values``) under attribute ``name`` on ``input``'s handle.

        ``values`` may be a scalar, list or tuple, and its elements may be constexprs.
        The number of values must equal the tensor's rank, or 1 for a rank-0 tensor.

        Raises:
            ValueError: if the number of values does not match the tensor rank.
        """
        # skip non tensor types. This may happen for induction variables.
        if not isinstance(input, tl.tensor):
            return input
        # Unwrap constexpr
        values = [values] if not isinstance(values, (list, tuple)) else values
        values = [v.value if isinstance(v, tl.constexpr) else v for v in values]
        if len(values) != max(1, len(input.shape)):
            raise ValueError(f"len(values) != len(input.shape) for {name}")
        input.handle.set_attr(name, values)
        return input

    lang.range = _new_range
    lang.static_range = _new_range
    lang.static_assert = _new_static_assert
    lang.static_print = print
    lang.dtype.to_ir = _new_to_ir
    lang.multiple_of = partial(_set_attr, name="tt.divisibility")
    lang.max_contiguous = partial(_set_attr, name="tt.contiguity")
    lang.max_constancy = partial(_set_attr, name="tt.constancy")

    _patch_reduce_scan()
|
|
1109
|
+
|
|
1110
|
+
|
|
1111
|
+
def _patch_lang(fn):
    """Patch every triton language module that is visible from ``fn``'s globals.

    Each global of ``fn`` that is the ``tl`` or ``tl.core`` module object is patched,
    in the order it appears in ``fn.__globals__``. ``tl.core.tensor_descriptor_base``
    is patched once afterwards.
    """
    language_modules = (tl, tl.core)
    langs = []
    for value in fn.__globals__.values():
        if inspect.ismodule(value) and value in language_modules:
            langs.append(value)
    assert len(langs) >= 1, "triton.language must be visible from within jit'd function"

    for lang in langs:
        _patch_builtin(lang, interpreter_builder)
        _patch_builtin(lang.tensor, interpreter_builder)
        if lang == tl:
            # ``math`` is patched only through the top-level ``tl`` module
            _patch_builtin(lang.math, interpreter_builder)
        _patch_lang_tensor(lang.tensor)
        _patch_lang_core(lang)
    _patch_builtin(tl.core.tensor_descriptor_base, interpreter_builder)
|
|
1122
|
+
|
|
1123
|
+
|
|
1124
|
+
def _tuple_create(arg, contents):
    """Build a new tuple of the same type as ``arg``, filled from ``contents``.

    NamedTuple constructors take one positional argument per field, while ``tuple``
    (and its non-named subclasses) take a single iterable. Both have type ``tuple``,
    so the presence of ``_fields`` is used to tell them apart; this is the
    conventional check.
    """
    tuple_type = type(arg)
    if not hasattr(arg, "_fields"):
        return tuple_type(contents)
    return tuple_type(*contents)
|
|
1131
|
+
|
|
1132
|
+
|
|
1133
|
+
# TODO: wrap everything in triton tensors
|
|
1134
|
+
def _implicit_cvt(arg):
    """Convert a host-side kernel argument into the value the interpreter expects.

    - ``int``: a 1-element interpreter tensor typed by ``mangle_type``. Its numpy
      storage is the first of int32 / uint32 / int64 / uint64 whose range holds the
      value.
    - Objects with ``data_ptr``: a 1-element uint64 tensor holding that pointer,
      typed by ``mangle_type``.
    - ``tuple``: converted element-wise, keeping the (named) tuple type.
    - ``TensorDescriptor``: rebuilt through ``make_tensor_descriptor``. The last
      stride must be 1 and is passed as a constexpr.
    - Anything else is returned unchanged.

    Raises:
        ValueError: if an ``int`` is outside every supported 64-bit range.
    """
    if isinstance(arg, int):
        ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None)
        # (inclusive low, exclusive high, storage dtype), checked in order
        int_ranges = (
            (-2**31, 2**31, np.int32),
            (2**31, 2**32, np.uint32),
            (-2**63, 2**63, np.int64),
            (2**63, 2**64, np.uint64),
        )
        for low, high, storage_dtype in int_ranges:
            if low <= arg < high:
                break
        else:
            raise ValueError(f"Unsupported integer value {arg}")
        return tl.tensor(TensorHandle(np.array([arg], dtype=storage_dtype), ty), ty)

    if hasattr(arg, "data_ptr"):
        ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None)
        pointer = np.array([arg.data_ptr()], dtype=np.uint64)
        return tl.tensor(TensorHandle(pointer, ty), ty)

    if isinstance(arg, tuple):
        return _tuple_create(arg, map(_implicit_cvt, arg))

    if isinstance(arg, TensorDescriptor):
        converted_strides = [_implicit_cvt(s) for s in arg.strides]
        assert arg.strides[-1] == 1
        converted_strides[-1] = tl.constexpr(1)
        semantic = TritonSemantic(InterpreterBuilder())
        return semantic.make_tensor_descriptor(
            base=_implicit_cvt(arg.base),
            shape=[_implicit_cvt(s) for s in arg.shape],
            strides=converted_strides,
            block_shape=[tl.constexpr(b) for b in arg.block_shape],
            padding_option=arg.padding,
        )

    return arg
|
|
1166
|
+
|
|
1167
|
+
|
|
1168
|
+
# Module-wide builder used by the patched language modules; GridExecutor sets the
# grid dimensions and the current program index on it before each kernel run.
interpreter_builder = InterpreterBuilder()
# Semantic layer bound to that same builder. Code rewritten by ASTTransformer
# refers to it by this exact name.
interpreter_semantic = TritonSemantic(interpreter_builder)
|
|
1170
|
+
|
|
1171
|
+
|
|
1172
|
+
def _unwrap_tensor(t):
    """Return the tensor inside a ``TensorWrapper``; any other value is returned unchanged."""
    wrapper_type = triton.runtime.jit.TensorWrapper
    return t.base if isinstance(t, wrapper_type) else t
|
|
1176
|
+
|
|
1177
|
+
|
|
1178
|
+
def _rewrap_tensor(t, original_tensor):
    """Give ``t`` the same wrapping that ``original_tensor`` had.

    If ``original_tensor`` is a ``TensorWrapper``, return a new wrapper around ``t``
    with the original's ``dtype``. Otherwise return ``t`` unchanged.
    """
    wrapper_type = triton.runtime.jit.TensorWrapper
    if not isinstance(original_tensor, wrapper_type):
        return t
    return wrapper_type(t, original_tensor.dtype)
|
|
1182
|
+
|
|
1183
|
+
|
|
1184
|
+
class GridExecutor:
    """Run an interpreted kernel on the host once per program index of a grid.

    Tensor arguments are copied to CPU before execution. Afterwards the CPU
    storages are copied back into the original storages, so the kernel's writes
    become visible to the caller.
    """

    def __init__(self, fn, arg_names, grid):
        from .jit import _normalize_ty  # TODO: modularize

        self.fn = fn
        self.arg_names = arg_names
        self.grid = grid
        normalized_annotations = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
        # Parameters annotated as constexpr are passed to the kernel unconverted
        self.constexprs = [name for name in arg_names if normalized_annotations.get(name) == "constexpr"]

    @staticmethod
    def _normalize_grid(grid):
        """Return ``grid`` as a 3-tuple, padding missing trailing dimensions with 1.

        Accepts any sequence (tuple or list) of at most three dimensions. Converting
        to a tuple first means a list grid no longer fails on ``list + tuple``.
        """
        grid = tuple(grid)
        assert len(grid) <= 3, "grid must have at most 3 dimensions"
        return grid + (1, ) * (3 - len(grid))

    def _init_args_hst(self, args_dev, kwargs):
        """Return CPU copies of the positional and keyword arguments.

        Each distinct underlying storage is copied to CPU once, keyed by its
        ``data_ptr``. Every argument that views that storage becomes a CPU tensor
        with the same offset, size and stride into the shared copy, so aliasing
        between arguments is preserved. Tuples and ``TensorDescriptor`` bases are
        handled recursively. Values without ``data_ptr`` pass through unchanged.
        """
        storages = {}

        def _to_cpu(arg):
            if isinstance(arg, tuple):
                return _tuple_create(arg, map(_to_cpu, arg))
            elif isinstance(arg, TensorDescriptor):
                return TensorDescriptor(
                    _to_cpu(arg.base),
                    arg.shape,
                    arg.strides,
                    arg.block_shape,
                    arg.padding,
                )
            elif not hasattr(arg, "data_ptr"):
                return arg

            unwrapped_arg = _unwrap_tensor(arg)
            dev_storage = unwrapped_arg.untyped_storage()
            storage_key = dev_storage.data_ptr()
            if storage_key not in storages:
                storages[storage_key] = dev_storage.cpu()

            cpu_arg = unwrapped_arg.new_empty(0, device='cpu')
            cpu_arg.set_(storages[storage_key], unwrapped_arg.storage_offset(), unwrapped_arg.size(),
                         unwrapped_arg.stride())
            return _rewrap_tensor(cpu_arg, original_tensor=arg)

        args_hst = [_to_cpu(arg) for arg in args_dev]
        # Process keyword arguments
        kwargs_hst = {key: _to_cpu(value) for key, value in kwargs.items()}
        return args_hst, kwargs_hst

    def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst):
        """Copy every host storage back into its original storage with ``copy_``.

        Storages are collected first, keyed by the original storage's ``data_ptr``,
        so each one is copied back exactly once even if several arguments alias it.
        """
        storages = {}

        def _from_cpu(arg_dev, arg_hst):
            if hasattr(arg_dev, "data_ptr"):
                # No need to rewrap because this just modifies the underlying storage
                arg_dev, arg_hst = _unwrap_tensor(arg_dev), _unwrap_tensor(arg_hst)
                storages[arg_dev.untyped_storage().data_ptr()] = (arg_dev.untyped_storage(), arg_hst.untyped_storage())
            elif isinstance(arg_dev, tuple):
                for elem_dev, elem_hst in zip(arg_dev, arg_hst):
                    _from_cpu(elem_dev, elem_hst)
            elif isinstance(arg_dev, TensorDescriptor):
                _from_cpu(arg_dev.base, arg_hst.base)

        for arg_dev, arg_hst in zip(args_dev, args_hst):
            _from_cpu(arg_dev, arg_hst)

        # Restore keyword arguments
        for key, kwarg_dev in kwargs.items():
            _from_cpu(kwarg_dev, kwargs_hst[key])

        for storage_dev, storage_hst in storages.values():
            storage_dev.copy_(storage_hst)

    def __call__(self, *args_dev, **kwargs):
        """Execute the kernel for every (x, y, z) program index in the grid.

        A truthy ``warmup`` keyword makes the call a no-op. Keyword arguments that
        are not parameters of the kernel are dropped. ``self.grid`` may be a
        sequence or a callable that receives the bound, converted argument dict.

        Raises:
            InterpreterError: wraps any exception raised while running the kernel,
                unless ``triton.knobs.compilation.front_end_debugging`` is set, in
                which case the original exception propagates.
        """
        if kwargs.pop("warmup", False):
            return
        # Removes not used reserved keywords from kwargs
        # Triton doesn't support keyword-only, variable positional or variable keyword arguments
        # It's safe to inspect only positional or keyword arguments (i.e., argspec.args)
        argspec = inspect.getfullargspec(self.fn)
        kwargs = {k: v for k, v in kwargs.items() if k in argspec.args}
        # copy arguments to the host
        args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
        # remaps core language functions to interpreted ones
        _patch_lang(self.fn)
        # we need to copy arguments to the host for the interpreter
        # implicitly convert tensor arguments to their base pointers
        args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst)
        args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()}
        # iterate through grid
        grid = self._normalize_grid(self.grid(args) if callable(self.grid) else self.grid)
        interpreter_builder.set_grid_dim(*grid)
        try:
            for x in range(grid[0]):
                for y in range(grid[1]):
                    for z in range(grid[2]):
                        interpreter_builder.set_grid_idx(x, y, z)
                        self.fn(**args)
        except Exception as e:
            if triton.knobs.compilation.front_end_debugging:
                raise
            raise InterpreterError(repr(e)) from e
        # copy arguments back to propagate side-effects
        self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)
|
|
1289
|
+
|
|
1290
|
+
|
|
1291
|
+
class ASTTransformer(ast.NodeTransformer):
    """Route the value of each single-target assignment through ``to_tensor``.

    ``x = value`` is rewritten to ``x = interpreter_semantic.to_tensor(value, False)``.
    Chained assignments such as ``a = b = value`` are rejected.
    """

    def visit_Assign(self, node):
        visited_targets = [self.visit(target) for target in node.targets]
        if len(visited_targets) > 1:
            raise ValueError("Multiple assignments are not supported")
        to_tensor_attr = ast.Attribute(
            value=ast.Name(id="interpreter_semantic", ctx=ast.Load()),
            attr="to_tensor",
            ctx=ast.Load(),
        )
        wrapped_value = ast.Call(func=to_tensor_attr, args=[node.value, ast.Constant(value=False)], keywords=[])
        node.value = wrapped_value
        return node
|
|
1305
|
+
|
|
1306
|
+
|
|
1307
|
+
class FunctionRewriter:
    """Recompile a kernel's source with ``ASTTransformer`` applied.

    The rewritten code is compiled under the original file name, with line numbers
    shifted to the kernel's position in that file. It is executed against the
    kernel's own globals, with this module's globals added where absent.
    """
    ast_transformer = ASTTransformer()

    def __init__(self, fn, **kwargs):
        self.fn = fn
        self.kwargs = kwargs
        self.filename: str = ""
        # Absolute line number in the file
        self.def_file_lineno: int = 0
        # 1-based index of the ``def`` line within the lines returned by
        # ``inspect.getsourcelines``; 0 until ``rewrite_ast`` finds it
        self.def_lineno: int = 0

    def rewrite_ast(self):
        """Return the rewritten function, or ``self.fn`` if it cannot be rewritten."""
        # If exception is raised, it means the function does not have source code available,
        # e.g., dynamically generated functions, we cannot rewrite it so just return the original function
        try:
            lines, _ = inspect.getsourcelines(self.fn)
        except Exception:
            return self.fn

        # truncate lines before def
        # @triton.autotune(...)
        # ...
        # @triton.jit
        # ...
        # def foo(...): <- this line is the function definition
        self.def_lineno = self._find_def(lines)
        if self.def_lineno == 0:
            # No ``def`` line (e.g. a lambda): slicing from index -1 would parse an
            # unrelated fragment, so leave the function as it is
            return self.fn
        self.filename, self.def_file_lineno = self._get_jit_fn_file_line()
        src = self._prepare_source(lines)
        transformed_ast = self._transform_ast(src)
        return self._compile_and_exec(transformed_ast)

    def _get_jit_fn_file_line(self):
        """Return ``(filename, absolute def line number)`` for ``self.fn``."""
        from .jit import get_jit_fn_file_line, JITFunction
        return get_jit_fn_file_line(JITFunction(self.fn))

    def _find_def(self, lines):
        """Return the 1-based index of the first line starting with ``def ``, or 0.

        The source may begin with decorator lines, so the first ``def`` line is the
        function's own definition. Later matches belong to nested functions or to
        body text (such as a docstring line that begins with "def ") and must not
        be used.
        """
        for lineno, line in enumerate(lines, start=1):
            if line.strip().startswith("def "):
                return lineno
        return 0

    def _prepare_source(self, lines):
        """Join the lines from the ``def`` line onward and dedent them."""
        lines = lines[self.def_lineno - 1:]
        src = ''.join(lines)
        return textwrap.dedent(src)

    def _transform_ast(self, src):
        """Parse ``src``, apply ``ASTTransformer``, and shift line numbers to file positions."""
        # src is like:
        # 1: def foo(...):
        # 2: ...
        parsed_ast = ast.parse(src)
        transformed_ast = self.ast_transformer.visit(parsed_ast)
        ast.fix_missing_locations(transformed_ast)
        inc_lineno = self.def_file_lineno - 1
        ast.increment_lineno(transformed_ast, inc_lineno)
        return transformed_ast

    def _compile_and_exec(self, transformed_ast):
        """Compile and execute the rewritten module; return the new function object.

        Side effect: every global of this module not already present in
        ``self.fn.__globals__`` is copied into it. This makes names used by the
        rewritten code (e.g. ``interpreter_semantic``) resolvable.
        """
        compiled_code = compile(transformed_ast, filename=self.filename, mode='exec')
        local_namespace = {**self.kwargs}
        fn_globals = self.fn.__globals__
        for key, value in globals().items():
            if key not in fn_globals:
                fn_globals[key] = value
        exec(compiled_code, fn_globals, local_namespace)
        return local_namespace[self.fn.__name__]
|
|
1374
|
+
|
|
1375
|
+
|
|
1376
|
+
class InterpretedFunction:
    """Host-side replacement for a JIT function when running in interpreter mode.

    ``obj[grid](...)`` and ``obj.run(..., grid=...)`` launch the rewritten function
    through ``GridExecutor``. Calling ``obj(...)`` directly is a device-function call
    made from inside another interpreted kernel.
    """
    # Cache of rewritten functions, shared by all instances and keyed by the
    # original Python function
    rewritten_fn: Dict[Callable, Callable] = {}

    def __init__(self, fn, **kwargs) -> None:
        self.fn = fn
        self.rewriter = FunctionRewriter(fn, **kwargs)
        self.kwargs = kwargs

        def run(*args, **kwargs):
            # ``grid`` is mandatory. It is also forwarded in ``kwargs``, where
            # GridExecutor drops it unless the kernel has a parameter of that name.
            launch_grid = kwargs["grid"]
            rewritten = self.rewrite()
            return GridExecutor(rewritten, self.arg_names, launch_grid)(*args, **kwargs)

        self.run = run
        parameters = inspect.signature(fn).parameters
        self.arg_names = [param.name for param in parameters.values()]

    def rewrite(self):
        """Return the AST-rewritten version of ``fn``, computing it on first use."""
        cache = self.rewritten_fn
        if self.fn not in cache:
            cache[self.fn] = self.rewriter.rewrite_ast()
        return cache[self.fn]

    @property
    def __name__(self):
        return self.fn.__name__

    def __getitem__(self, grid):
        return GridExecutor(self.rewrite(), self.arg_names, grid)

    def __call__(self, *args, **kwargs):
        # This is a device function call
        _patch_lang(self.fn)
        rewritten = self.rewrite()
        try:
            return rewritten(*args, **kwargs)
        except Exception as e:
            raise InterpreterError(repr(e)) from e
|