triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +82 -0
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +255 -0
- triton/_utils.py +126 -0
- triton/backends/__init__.py +47 -0
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +461 -0
- triton/backends/amd/driver.c +283 -0
- triton/backends/amd/driver.py +724 -0
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +90 -0
- triton/backends/driver.py +66 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +533 -0
- triton/backends/nvidia/driver.c +517 -0
- triton/backends/nvidia/driver.py +799 -0
- triton/backends/nvidia/include/cuda.h +26280 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +7 -0
- triton/compiler/code_generator.py +1614 -0
- triton/compiler/compiler.py +509 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +342 -0
- triton/language/core.py +3405 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +16 -0
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +5 -0
- triton/language/extra/hip/libdevice.py +491 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +790 -0
- triton/language/math.py +249 -0
- triton/language/random.py +218 -0
- triton/language/semantic.py +1939 -0
- triton/language/standard.py +534 -0
- triton/language/target_info.py +54 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/_allocation.py +44 -0
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +476 -0
- triton/runtime/build.py +168 -0
- triton/runtime/cache.py +317 -0
- triton/runtime/driver.py +38 -0
- triton/runtime/errors.py +36 -0
- triton/runtime/interpreter.py +1414 -0
- triton/runtime/jit.py +1107 -0
- triton/runtime/tcc/include/_mingw.h +168 -0
- triton/runtime/tcc/include/assert.h +62 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +75 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +116 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +497 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +14 -0
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +42 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +591 -0
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +123 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +2958 -0
- triton/runtime/tcc/include/winapi/wincon.h +309 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +5837 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +543 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.py +210 -0
- triton/tools/disasm.py +143 -0
- triton/tools/extra/cuda/compile.c +70 -0
- triton/tools/extra/cuda/compile.h +14 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/link.py +322 -0
- triton/tools/mxfp.py +301 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +405 -0
- triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
- triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
- triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0

triton/experimental/gluon/language/_semantic.py
@@ -0,0 +1,380 @@
from typing import Sequence, List, TypeVar, Tuple, Callable
import math
from triton.language.semantic import TritonSemantic
from . import _core as ttgl
from ._layouts import AutoLayout, DistributedLayout, SliceLayout
from triton._C.libtriton.gluon_ir import GluonOpBuilder
from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values

TensorTy = TypeVar("TensorTy")


def _check(cond: bool, msg_fn: Callable[[], str], category=ValueError):
    if not cond:
        raise category(msg_fn())


class GluonCallerContext:

    def __init__(self, num_warps: int):
        self.num_warps = num_warps

    def mangle(self):
        return f"_NW{self.num_warps}"

    def initialize_callee(self, fn, builder):
        fn.set_attr("ttg.num-warps", builder.get_int32_attr(self.num_warps))


class GluonSemantic(TritonSemantic[TensorTy]):
    tensor = ttgl.tensor
    lang = ttgl

    builder: GluonOpBuilder

    def __init__(self, builder: GluonOpBuilder):
        self.builder = builder

    def _wrap_handle_infer_layout(self, handle, scalar_ty, shape):
        if shape == []:
            ty = scalar_ty
        else:
            ty = ttgl.distributed_type(scalar_ty, shape, self.builder.get_gluon_layout_from_tensor(handle))
        return self.tensor(handle, ty)

    def _wrap_tensor_infer_layout(self, tensor):
        return self._wrap_handle_infer_layout(tensor.handle, tensor.type.scalar, tensor.shape)

    def _broadcast_shapes(self, lhs_shape: List[int], rhs_shape: List[int]):
        if len(lhs_shape) != len(rhs_shape):
            raise ValueError(f"Cannot broadcast, rank mismatch: {lhs_shape}, {rhs_shape}")

        ret_shape = []
        for i, left in enumerate(lhs_shape):
            right = rhs_shape[i]
            if left == 1:
                ret_shape.append(right)
            elif (right == 1) or (right == left):
                ret_shape.append(left)
            else:
                raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
                                 "at index " + str(i) + ": " + str(left) + " and " + str(right))
        return ret_shape

    def expand_dims(self, input: TensorTy, axis: int) -> TensorTy:
        dst_shape = [ttgl._unwrap_if_constexpr(x) for x in input.shape]
        dst_shape.insert(axis, 1)

        if axis < 0:
            axis += len(input.shape)

        _check(isinstance(input.type, ttgl.distributed_type),
               lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}")
        layout = input.type.layout
        _check(isinstance(layout, (SliceLayout, AutoLayout)),
               lambda: f"expected expand_dims input to have a SliceLayout, but got: {layout}")
        _check(
            isinstance(layout, AutoLayout) or layout.dim == axis,
            lambda: f"expected expand_dims input layout to be sliced in axis {axis} but got {layout.dim}")

        handle = self.builder.create_expand_dims(input.handle, axis)
        return self._wrap_handle_infer_layout(handle, input.type.scalar, dst_shape)

    def join(self, a: TensorTy, b: TensorTy) -> TensorTy:
        a, b = self.broadcast_impl_value(a, b)
        _check(a.shape != [], "Cannot join scalars in gluon")
        value = super().join(a, b)
        return self._wrap_tensor_infer_layout(value)

    def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]:
        lhs, rhs = super().split(a)
        return self._wrap_tensor_infer_layout(lhs), self._wrap_tensor_infer_layout(rhs)

    def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy:
        value = super().permute(input, dims)
        return self._wrap_tensor_infer_layout(value)

    def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy:
        _check(isinstance(input.type, ttgl.distributed_type),
               lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}")
        src_shape = input.type.get_block_shapes()
        _check(len(src_shape) == len(shape), lambda: f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
        if shape == src_shape:
            return input
        for i, item in enumerate(src_shape):
            if shape[i] != item and item != 1:
                raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
                                 f" must match the existing size ({item}) at non-singleton dimension"
                                 f" {i}: {src_shape}, {shape}")
        ret_ty = ttgl.distributed_type(input.type.scalar, shape, input.type.layout)
        handle = self.builder.create_broadcast(input.handle, ret_ty.to_ir(self.builder))
        return self.tensor(handle, ret_ty)

    def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy:
        lhs_ty = lhs.type
        rhs_ty = rhs.type

        if not lhs_ty.is_block() or not rhs_ty.is_block():
            return super().broadcast_impl_value(lhs, rhs)

        _check(isinstance(lhs_ty, ttgl.distributed_type),
               lambda: f"expected broadcast left input to be a distributed_type but got: {lhs_ty!r}")
        _check(isinstance(rhs_ty, ttgl.distributed_type),
               lambda: f"expected broadcast right input to be a distributed_type but got: {rhs_ty!r}")

        lhs_shape = lhs_ty.get_block_shapes()
        rhs_shape = rhs_ty.get_block_shapes()
        ret_shape = self._broadcast_shapes(lhs_shape, rhs_shape)

        is_lhs_auto = isinstance(lhs_ty.layout, AutoLayout)
        is_rhs_auto = isinstance(rhs_ty.layout, AutoLayout)
        if is_lhs_auto and not is_rhs_auto:
            lhs = self.set_auto_layout(lhs, rhs_ty.layout)
        elif is_rhs_auto and not is_lhs_auto:
            rhs = self.set_auto_layout(rhs, lhs_ty.layout)
        elif lhs_ty.layout != rhs_ty.layout:
            raise ValueError(f"Layout mismatch in broadcast: {lhs_ty.layout} vs {rhs_ty.layout}")

        lhs = self.broadcast_impl_shape(lhs, ret_shape)
        rhs = self.broadcast_impl_shape(rhs, ret_shape)
        return lhs, rhs

    def arange(self, start, end, layout):
        shape = [end - start]
        if layout is None:
            layout = AutoLayout()
        ret_ty = ttgl.distributed_type(ttgl.int32, shape, layout)
        return super().arange(start, end, ret_ty=ret_ty)

    def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool):
        _check(not can_reorder, "can_reorder is not supported in gluon")
        value = super().reshape(input, dst_shape, can_reorder)
        return self._wrap_tensor_infer_layout(value)

    def splat(self, value, shape, layout):
        ret_ty = ttgl.distributed_type(value.dtype, shape, layout)
        handle = self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle)
        return ttgl.tensor(handle, ret_ty)

    def full(self, shape, value, dtype, layout):
        scalar = self.make_scalar(value, dtype)
        if layout is None:
            layout = AutoLayout()
        return self.splat(scalar, shape, layout)

    def convert_layout(self, value, layout, assert_trivial=False):
        ty = value.type
        _check(isinstance(ty, ttgl.distributed_type),
               lambda: f"expected convert_layout input to be a distributed_type but got: {ty!r}")
        ret_ty = ttgl.distributed_type(ty.element_ty, ty.shape, layout)
        ret_ty_ir = ret_ty.to_ir(self.builder)
        if assert_trivial and not self.builder.is_convert_layout_trivial(ret_ty_ir, value.handle):
            raise TypeError(f"layout conversion from {ty.layout} to {layout} is not trivial")
        handle = self.builder.create_convert_layout(ret_ty_ir, value.handle)
        return ttgl.tensor(handle, ret_ty)

    def allocate_shared(self, element_ty, shape, layout, value):
        ty = ttgl.shared_memory_descriptor_type(element_ty, shape, layout, shape)
        if value is not None:
            handle = self.builder.create_local_alloc(ty.to_ir(self.builder), value.handle)
        else:
            handle = self.builder.create_local_alloc(ty.to_ir(self.builder))
        return ttgl.shared_memory_descriptor(handle, element_ty, shape, layout, shape)

    def shared_load(self, mem_desc, layout):
        ret_ty = ttgl.distributed_type(mem_desc.dtype, mem_desc.shape, layout)
        handle = self.builder.create_local_load(ret_ty.to_ir(self.builder), mem_desc.handle)
        return ttgl.tensor(handle, ret_ty)

    def shared_store(self, mem_desc, value):
        assert value.shape == mem_desc.shape, f"source shape {value.shape} and destination shape {mem_desc.shape} must match"
        assert value.dtype == mem_desc.dtype, f"source dtype {value.dtype} and destination dtype {mem_desc.dtype} must match"
        self.builder.create_local_store(mem_desc.handle, value.handle)

    def shared_dealloc(self, mem_desc):
        self.builder.create_local_dealloc(mem_desc.handle)

    def set_auto_layout(self, value, layout):
        src_ty = value.type
        assert isinstance(layout,
                          DistributedLayout), f"set_auto_layout must set to a distributed layout but got {layout}"
        assert isinstance(src_ty.layout,
                          AutoLayout), f"set_auto_layout input must have auto layout but got {value.type.layout}"
        handle = self.builder.create_set_auto_layout(layout._to_ir(self.builder), value.handle)
        res_ty = ttgl.distributed_type(src_ty.element_ty, src_ty.shape, layout)
        return self.tensor(handle, res_ty)

    def memdesc_slice(self, mem_desc, start, length, dim):
        offsets = [0] * mem_desc.rank
        offsets[dim] = start
        shape = list(mem_desc.shape)
        shape[dim] = length
        layout = mem_desc.layout
        ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
        builder = self.builder
        handle = builder.create_memdesc_subslice(ty.to_ir(builder), mem_desc.handle, offsets)
        return ttgl.shared_memory_descriptor(handle, **ty.__dict__)

    def memdesc_index(self, mem_desc, index):
        shape = mem_desc.shape[1:]
        index = self.to_tensor(index).handle
        layout = mem_desc.layout
        ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
        builder = self.builder
        handle = builder.create_memdesc_index(ty.to_ir(builder), mem_desc.handle, index)
        return ttgl.shared_memory_descriptor(handle, **ty.__dict__)

    def memdesc_trans(self, mem_desc, order):
        assert len(order) == len(
            mem_desc.shape), f"source rank ({mem_desc.rank}) and order length ({len(order)}) must match"

        shape = [mem_desc.shape[i] for i in order]
        alloc_shape = mem_desc.type.alloc_shape
        new_alloc_shape = alloc_shape[:len(alloc_shape) - mem_desc.rank]
        new_alloc_shape += [alloc_shape[len(alloc_shape) - mem_desc.rank:][i] for i in order]

        handle = self.builder.create_memdesc_trans(mem_desc.handle, order)
        layout = self.builder.get_gluon_layout_from_memdesc(handle)
        return ttgl.shared_memory_descriptor(handle, element_ty=mem_desc.dtype, shape=shape,
                                             alloc_shape=new_alloc_shape, layout=layout)

    def memdesc_reshape(self, mem_desc, shape):
        _check(
            math.prod(shape) == math.prod(mem_desc.shape),
            lambda: (f"memdesc_reshape total elements mismatch: "
                     f"{mem_desc.shape} -> {shape}"),
        )

        handle = self.builder.create_memdesc_reshape(mem_desc.handle, shape)
        layout = self.builder.get_gluon_layout_from_memdesc(handle)
        alloc_shape = mem_desc.type.alloc_shape
        prefix_len = len(alloc_shape) - mem_desc.rank
        new_alloc_shape = alloc_shape[:prefix_len] + list(shape)

        return ttgl.shared_memory_descriptor(
            handle,
            element_ty=mem_desc.dtype,
            shape=shape,
            alloc_shape=new_alloc_shape,
            layout=layout,
        )

    def memdesc_reinterpret(self, mem_desc, dtype, shape, layout):
        ty = ttgl.shared_memory_descriptor_type(dtype, shape, layout, shape)
        handle = self.builder.create_memdesc_reinterpret(ty.to_ir(self.builder), mem_desc.handle)
        return ttgl.shared_memory_descriptor(handle, **ty.__dict__)

    def wrap_tensor(self, x, scalar_ty, ret_shape, layout):
        if ret_shape:
            res_ty = ttgl.distributed_type(scalar_ty, ret_shape, layout)
        else:
            res_ty = scalar_ty
        return self.tensor(x, res_ty)

    @staticmethod
    def _check_same_layout(xs):
        for x in xs:
            _check(isinstance(x.type, ttgl.distributed_type), lambda: f"expected distributed_type but got: {x.type!r}")
        layouts = [x.type.layout for x in xs]
        l0 = layouts[0]
        _check(all(l == l0 for l in layouts[1:]),
               lambda: f"Expected inputs to have matching layouts, but got: {layouts}")

    def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn,
                         reverse: bool) -> Tuple[TensorTy, ...]:
        shape = inputs[0].type.shape
        rank = len(shape)

        assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})"

        if axis < 0:
            axis += rank

        for t in inputs:
            assert t.type.shape == shape, "all scan inputs must have the same shape"

        scan_op = self.builder.create_scan([t.handle for t in inputs], axis, reverse)
        region_builder_fn(scan_op)
        assert scan_op.verify()

        return tuple(
            self._wrap_handle_infer_layout(scan_op.get_result(i), inputs[i].type.scalar, shape)
            for i in range(len(inputs)))

    def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]:
        _check(axis is not None, lambda: "All-reduce is not yet implemented in gluon")
        # get result shape
        shape = inputs[0].type.shape
        rank = len(shape)
        _check(0 <= axis < rank, lambda: f"expected reduction axis to be in the range [0, {rank}) but got {axis}")
        self._check_same_layout(inputs)
        ret_shape = [s for i, s in enumerate(shape) if i != axis]
        assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape"

        reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis)
        region_builder_fn(reduce_op)
        assert reduce_op.verify()

        return tuple(
            self._wrap_handle_infer_layout(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape)
            for i in range(len(inputs)))

    def histogram(self, input: TensorTy, num_bins: int, mask: TensorTy, layout) -> TensorTy:
        _check(len(input.shape) == 1, lambda: "histogram only supports 1D input")
        _check(input.dtype.is_int(), lambda: "histogram only supports integer input")
        _check(layout is not None, lambda: "histogram requires a destination layout")
        if mask is not None:
            mask, input = self.broadcast_impl_value(mask, input)
            _check(mask.type.scalar.is_bool(), lambda: "Mask must have boolean scalar type")
            mask = mask.handle
        layout_attr = layout._to_ir(self.builder)
        handle = self.builder.create_histogram(input.handle, num_bins, mask, layout_attr)
        return self.wrap_tensor(handle, ttgl.int32, [num_bins], layout)

    def warp_specialize(self, default_args, default_partition, worker_args, worker_partitions,
                        worker_num_warps: Sequence[int], worker_num_regs: Sequence[int], generator):
        num_partitions = len(worker_partitions)
        assert num_partitions == len(
            worker_num_warps
        ), f"warp specialize got {num_partitions} partitions but {len(worker_num_warps)} warp counts"
        assert num_partitions == len(
            worker_num_regs
        ), f"warp specialize got {num_partitions} partitions but {len(worker_num_regs)} register counts"

        builder = self.builder
        insert_pt = builder.get_insertion_point()

        # Emit the default partition to get the result types.
        default_block = builder.new_block()
        builder.set_insertion_point_to_start(default_block)
        default_results = generator.call_JitFunction(default_partition, default_args, kwargs={})
        mlir_results = []
        if default_results is not None:
            mlir_results = flatten_values_to_ir(default_results)
        builder.create_warp_yield(mlir_results)
        result_types = [r.get_type() for r in mlir_results]

        # Create the warp specialize op.
        builder.restore_insertion_point(insert_pt)
        mlir_args = flatten_values_to_ir(worker_args)
        ws_op = builder.create_warp_specialize(result_types, mlir_args, worker_num_warps)
        ws_op.get_default_region().push_back(default_block)
        ws_op.set_requested_registers(worker_num_regs)

        # Emit the partition regions.
        builder.create_block_with_parent(ws_op.get_partition_op_holder(), [])
        partitions_op = builder.create_warp_specialize_partitions(num_partitions)
        arg_types = [arg.get_type() for arg in mlir_args]
        for i in range(num_partitions):
            caller_context = GluonCallerContext(num_warps=worker_num_warps[i])
            block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types)
            block_args = [block.get_argument(j) for j in range(len(mlir_args))]
            block_args = unflatten_ir_values(block_args, [arg.type for arg in worker_args])
            generator.call_JitFunction(worker_partitions[i], block_args, kwargs={}, caller_context=caller_context)
            builder.create_warp_return()

        builder.set_insertion_point_after(ws_op.get_operation())
        mlir_results = [ws_op.get_result(i) for i in range(len(result_types))]
        if default_results is None:
            return
        return tuple(unflatten_ir_values(mlir_results, [r.type for r in default_results]))
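
A note on the hunk above: GluonSemantic._broadcast_shapes is stricter than NumPy-style broadcasting in that the two shapes must already have the same rank; it never left-pads with 1s, and only size-1 dimensions are expanded. A minimal standalone sketch of the same rule (broadcast_shapes here is a hypothetical helper for illustration, not part of the package):

from typing import List

def broadcast_shapes(lhs_shape: List[int], rhs_shape: List[int]) -> List[int]:
    # Same rank-matching rule as GluonSemantic._broadcast_shapes above.
    if len(lhs_shape) != len(rhs_shape):
        raise ValueError(f"Cannot broadcast, rank mismatch: {lhs_shape}, {rhs_shape}")
    out = []
    for i, (left, right) in enumerate(zip(lhs_shape, rhs_shape)):
        if left == 1:
            out.append(right)              # size-1 dimension on the left expands
        elif right == 1 or right == left:
            out.append(left)               # size-1 on the right expands; equal sizes pass through
        else:
            raise ValueError(f"incompatible dimensions at index {i}: {left} and {right}")
    return out

print(broadcast_shapes([1, 128], [64, 128]))   # [64, 128]
print(broadcast_shapes([64, 1], [64, 128]))    # [64, 128]
# broadcast_shapes([64], [64, 128]) raises ValueError: ranks must match, unlike NumPy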

triton/experimental/gluon/language/_standard.py
@@ -0,0 +1,80 @@
from typing import TypeVar
from triton.runtime.jit import JITFunction
import triton.language.standard as tl_standard
from .._runtime import GluonJITFunction, jit
from triton import knobs
from . import _core as ttgl

T = TypeVar("T")


def _import_from_triton(fn: JITFunction[T]) -> GluonJITFunction[T]:
    assert knobs.runtime.interpret or isinstance(fn, JITFunction)
    # Wrap the function and preserve its original docstring
    gluon_fn = jit(fn.fn)
    gluon_fn.__doc__ = fn.__doc__
    return gluon_fn


cdiv = _import_from_triton(tl_standard.cdiv)
sum = _import_from_triton(tl_standard.sum)
max = _import_from_triton(tl_standard.max)
min = _import_from_triton(tl_standard.min)
reduce_or = _import_from_triton(tl_standard.reduce_or)
xor_sum = _import_from_triton(tl_standard.xor_sum)


@jit
def zeros(shape, dtype, layout=None):
    """
    Create a tensor filled with zeros.

    Args:
        shape (Sequence[int]): The shape of the tensor.
        dtype (dtype): The data type for the tensor.
        layout (Optional[DistributedLayout]): The distributed layout of the tensor, defaults to AutoLayout().

    Returns:
        tensor: A tensor where every element is zero.
    """
    return ttgl.full(shape, 0, dtype, layout)


@jit
def full_like(input, value, shape=None, dtype=None, layout=None):
    """
    Create a tensor with the same properties as a given tensor, filled with a specified value.

    Args:
        input (tensor): Reference tensor to infer default shape, dtype, and layout.
        value (int or float): The fill value.
        shape (Sequence[int], optional): Target shape. Defaults to input.shape.
        dtype (dtype, optional): Target data type. Defaults to input.dtype.
        layout (DistributedLayout, optional): Target layout. Defaults to input.layout.

    Returns:
        tensor: A tensor where every element equals value.
    """
    return ttgl.full(
        input.shape if shape is None else shape,
        value,
        input.dtype if dtype is None else dtype,
        input.type.layout if layout is None else layout,
    )


@jit
def zeros_like(input, shape=None, dtype=None, layout=None):
    """
    Create a tensor with the same properties as a given tensor, filled with zeros.

    Args:
        input (tensor): Reference tensor to infer default shape, dtype, and layout.
        shape (Sequence[int], optional): Target shape. Defaults to input.shape.
        dtype (dtype, optional): Target data type. Defaults to input.dtype.
        layout (DistributedLayout, optional): Target layout. Defaults to input.layout.

    Returns:
        tensor: A tensor where every element is zero.
    """
    return full_like(input, 0, shape=shape, dtype=dtype, layout=layout)
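
The helpers in this hunk are meant to be called from inside Gluon kernels. A hypothetical usage sketch, not part of the wheel: it assumes that triton/experimental/gluon/__init__.py re-exports the jit decorator from ._runtime, and that language/__init__.py re-exports the _standard helpers plus constexpr and float32 under the ttgl namespace.

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl

@gluon.jit
def init_kernel(BLOCK: ttgl.constexpr):
    # layout is omitted, so per zeros()/full_like() above both tensors get
    # AutoLayout() and a concrete distributed layout is resolved later.
    acc = ttgl.zeros([BLOCK, BLOCK], ttgl.float32)
    ones = ttgl.full_like(acc, 1.0)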

triton/experimental/gluon/language/amd/_layouts.py
@@ -0,0 +1,96 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional
from triton.language.core import _unwrap_if_constexpr

from triton.experimental.gluon.language._layouts import _realize_cta_layout, DistributedLayout
from triton.experimental.gluon import language as ttgl

__all__ = [
    "AMDMFMALayout",
]


@dataclass(frozen=True)
class AMDMFMALayout(DistributedLayout):
    """
    Represents a layout for AMD MFMA (matrix core) operations.

    Args:
        version (int): Major and minor identifier for the MFMA instruction.
        instr_shape: (M, N) dimension for the intrinsic shape.
        transposed (bool): indicates the result tensor is transposed so that each thread holds consecutive elements in the same row instead of column, which is good for chained dot and global write.
        warps_per_cta (List[int]): Number of warps per CTA.
        elem_type Optional(ttgl.dtype): Supported types are int32, fp32 and fp64. Default is fp32.
        tiles_per_warp Optional(List[int]): Number of tiles per WARP. For mfma layout, if missing, use the default where we have unit tile size on all dimensions.
        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
        cta_split_num (Optional[List[int]]): Split factors for CTAs.
        cta_order (Optional[List[int]]): CTA ordering.
    """
    version: int
    instr_shape: List[int]
    transposed: bool
    warps_per_cta: List[int]
    elem_type: ttgl.dtype = ttgl.float32
    tiles_per_warp: Optional[List[int]] = None
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        super().__setattr__("version", _unwrap_if_constexpr(self.version))
        super().__setattr__("instr_shape", _unwrap_if_constexpr(self.instr_shape))
        super().__setattr__("transposed", _unwrap_if_constexpr(self.transposed))
        super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta))
        super().__setattr__("tiles_per_warp", _unwrap_if_constexpr(self.tiles_per_warp))
        super().__setattr__("elem_type", _unwrap_if_constexpr(self.elem_type))
        super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga))
        super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num))
        super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))

        if self.tiles_per_warp is None:
            object.__setattr__(self, "tiles_per_warp", [1] * len(self.warps_per_cta))

        self.verify()

    def _to_ir(self, builder):
        type = self.elem_type.to_ir(builder)
        return builder.get_amd_mfma_layout(self.version, self.instr_shape, self.transposed, self.warps_per_cta, type,
                                           self.tiles_per_warp, self.ctas_per_cga, self.cta_split_num, self.cta_order)

    def mangle(self) -> str:

        def stringify(x):
            if x is None:
                return ""
            return "_".join(map(str, x))

        return f"MFMA_{self.version}_{stringify(self.instr_shape)}_{self.transposed}_{stringify(self.warps_per_cta)}_{stringify(self.tiles_per_warp)}_{self.elem_type}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_MFMA"

    def verify(self):
        assert self.version >= 1 and self.version <= 4, "version must be in the [1, 4] range"
        valid_shapes = [[32, 32], [16, 16], [64, 4], [4, 64]]
        assert self.instr_shape in valid_shapes, "invalid intrinsic shape; accepted shapes are " + str(valid_shapes)

        assert self.elem_type.is_fp32() or self.elem_type.is_fp64() \
            or self.elem_type.is_int32(), "element type must be float32, float64, or int32"

        rank = len(self.warps_per_cta)
        _realize_cta_layout(self, rank)
        assert len(self.ctas_per_cga) == rank
        assert len(self.cta_split_num) == rank
        assert len(self.cta_order) == rank

    def __hash__(self):
        return hash((
            self.version,
            tuple(self.instr_shape),
            self.transposed,
            tuple(self.warps_per_cta),
            self.elem_type,
            tuple(self.tiles_per_warp) if self.tiles_per_warp else None,
            tuple(self.ctas_per_cga) if self.ctas_per_cga else None,
            tuple(self.cta_split_num) if self.cta_split_num else None,
            tuple(self.cta_order) if self.cta_order else None,
        ))
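
For orientation, an illustrative construction (not taken from the package) whose field values satisfy AMDMFMALayout.verify() above: version in [1, 4], instr_shape one of [32, 32], [16, 16], [64, 4], [4, 64], and an fp32/fp64/int32 element type.

from triton.experimental.gluon.language.amd._layouts import AMDMFMALayout

# A CDNA3-style 32x32 MFMA tile with fp32 elements (the default) and 4 warps
# along M. tiles_per_warp defaults to [1, 1]; the CTA fields are filled in by
# _realize_cta_layout during __post_init__/verify().
mfma_layout = AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
                            warps_per_cta=[4, 1])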

triton/experimental/gluon/language/amd/cdna3/__init__.py
@@ -0,0 +1,100 @@
from __future__ import annotations
from typing import TYPE_CHECKING

from triton import knobs
from triton.experimental.gluon.language import _core as ttgl
from triton._C.libtriton import ir
from ..._core import builtin, _unwrap_if_constexpr

if TYPE_CHECKING:
    from ..._semantic import GluonSemantic

__all__ = ["buffer_load", "buffer_store", "mfma"]


def _verify_buffer_ops(ptr, offsets, mask=None, other=None):
    assert ptr.type.is_ptr(), "ptr must be a scalar pointer type"

    assert isinstance(offsets.type, ttgl.distributed_type), "expected offsets type to be a distributed_type"
    assert offsets.dtype.is_int32() or offsets.dtype.is_uint32(), "offsets element type must be int32 or uint32"

    element_type = ptr.type.scalar.element_ty

    if other is not None:
        assert mask is not None, "when other is not None, mask should not be None"
        assert other.dtype == element_type, "other must have the same data type as ptr scalar type"


@builtin
def buffer_load(ptr, offsets, mask=None, other=None, cache=None, _semantic=None):
    """
    AMD buffer load from global memory via a scalar base pointer and a tensor of
    offsets instead of a tensor of pointers. This operation will load data
    directly into registers.

    Args:
        ptr (pointer to scalar): Global memory scalar base pointer to load from.
        offsets (tensor): Offsets tensor for the load operation.
        mask (tensor, optional): Mask tensor for predicated loads. Defaults to None.
        other (tensor, optional): Tensor providing default values for masked elements. Defaults to None.
        cache_modifier (str): Cache modifier specifier. Defaults to "".
    """
    _verify_buffer_ops(ptr, offsets, mask, other)

    mask = _unwrap_if_constexpr(mask)
    if mask is not None:
        offsets, mask = _semantic.broadcast_impl_value(offsets, mask)

    other = _unwrap_if_constexpr(other)
    if other is not None:
        offsets, other = _semantic.broadcast_impl_value(offsets, other)

    other = other.handle if other is not None else ir.value()
    mask = mask.handle if mask is not None else ir.value()
    cache_modifier = _semantic._str_to_load_cache_modifier(cache) if cache is not None else ir.CACHE_MODIFIER.NONE

    ret_ty = offsets.type.with_element_ty(ptr.type.scalar.element_ty)
    builder = _semantic.builder
    handle = builder.create_buffer_load(ret_ty.to_ir(builder), ptr.handle, offsets.handle, mask, other, cache_modifier)
    return ttgl.tensor(handle, ret_ty)


@builtin
def buffer_store(stored_value, ptr, offsets, mask=None, cache=None, _semantic: GluonSemantic = None):
    """
    AMD buffer store a tensor directly to global memory via a scalar base pointer and a tensor of
    offsets instead of a tensor of pointers.
    Args:
        stored_value (tensor to be stored): The tensor to be stored to global memory.
        ptr (pointer to scalar): Global memory scalar base pointer to store to.
        offsets (tensor): Offsets tensor for the store operation.
        mask (tensor, optional): Mask tensor for predicated store. Defaults to None.
        cache_modifier (str): Cache modifier specifier. Defaults to "".
    """
    _verify_buffer_ops(ptr, offsets, mask)

    if mask is not None:
        offsets, mask = _semantic.broadcast_impl_value(offsets, mask)

    mask = mask.handle if mask is not None else ir.value()
    cache_modifier = _semantic._str_to_store_cache_modifier(cache) if cache is not None else ir.CACHE_MODIFIER.NONE

    _semantic.builder.create_buffer_store(stored_value.handle, ptr.handle, offsets.handle, mask, cache_modifier)


@builtin
def mfma(a, b, acc, _semantic: GluonSemantic = None):
    """
    Computes matrix-multiplication of a * b + acc using AMD native matrix core units.
    Args:
        a (tensor): The first operand of mfma.
        b (tensor): The second operand of mfma.
        acc (tensor): The accumulator tensor.
    """
    assert acc is not None, "acc is required"
    ret_type = acc.type
    acc = ttgl._unwrap_if_constexpr(acc)

    handle = _semantic.dot(a, b, acc, input_precision=knobs.language.fp32_default, max_num_imprecise_acc=None,
                           out_dtype=acc.dtype).handle
    return ttgl.tensor(handle, ret_type)
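
A hypothetical CDNA3 kernel sketch tying the buffer ops above together; it is not part of the wheel and assumes the top-level gluon jit decorator plus ttgl.arange, ttgl.constexpr, and ttgl.BlockedLayout re-exports. Only buffer_load and buffer_store are taken from this hunk.

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.amd import cdna3

@gluon.jit
def scaled_copy(src_ptr, dst_ptr, BLOCK: ttgl.constexpr):
    # A scalar base pointer plus an int32 offsets tensor, as required by
    # _verify_buffer_ops; the blocked layout spreads BLOCK elements across
    # 4 warps of 64 lanes.
    layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[64],
                                                warps_per_cta=[4], order=[0])
    offsets = ttgl.arange(0, BLOCK, layout=layout)
    x = cdna3.buffer_load(src_ptr, offsets)        # loads straight into registers
    cdna3.buffer_store(x * 2.0, dst_ptr, offsets)  # mask= is available for predication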