triton-windows 3.4.0.post20__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +8 -2
- triton/_filecheck.py +24 -14
- triton/_internal_testing.py +70 -4
- triton/_utils.py +3 -1
- triton/backends/amd/compiler.py +68 -60
- triton/backends/amd/driver.c +113 -44
- triton/backends/amd/driver.py +133 -57
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/compiler.py +80 -22
- triton/backends/nvidia/driver.c +88 -15
- triton/backends/nvidia/driver.py +130 -123
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +270 -163
- triton/compiler/compiler.py +45 -62
- triton/experimental/gluon/__init__.py +3 -2
- triton/experimental/gluon/_runtime.py +9 -6
- triton/experimental/gluon/language/__init__.py +117 -16
- triton/experimental/gluon/language/_core.py +246 -68
- triton/experimental/gluon/language/_layouts.py +398 -45
- triton/experimental/gluon/language/_math.py +17 -9
- triton/experimental/gluon/language/_semantic.py +130 -37
- triton/experimental/gluon/language/_standard.py +55 -22
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
- triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
- triton/experimental/gluon/nvidia/hopper.py +6 -1
- triton/knobs.py +132 -67
- triton/language/__init__.py +16 -10
- triton/language/core.py +163 -83
- triton/language/extra/cuda/gdc.py +6 -6
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +7 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/semantic.py +76 -23
- triton/language/standard.py +14 -14
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +4 -5
- triton/runtime/build.py +11 -9
- triton/runtime/cache.py +44 -1
- triton/runtime/driver.py +16 -41
- triton/runtime/interpreter.py +31 -23
- triton/runtime/jit.py +318 -157
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/tools/compile.py +62 -14
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +7 -9
- triton/windows_utils.py +42 -79
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/experimental/gluon/language/_semantic.py

@@ -1,7 +1,8 @@
 from typing import Sequence, List, TypeVar, Tuple, Callable
+import math
 from triton.language.semantic import TritonSemantic
 from . import _core as ttgl
-from ._layouts import SliceLayout
+from ._layouts import AutoLayout, DistributedLayout, SliceLayout
 from triton._C.libtriton.gluon_ir import GluonOpBuilder
 from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values
 
@@ -13,6 +14,18 @@ def _check(cond: bool, msg_fn: Callable[[], str], category=ValueError):
         raise category(msg_fn())
 
 
+class GluonCallerContext:
+
+    def __init__(self, num_warps: int):
+        self.num_warps = num_warps
+
+    def mangle(self):
+        return f"_NW{self.num_warps}"
+
+    def initialize_callee(self, fn, builder):
+        fn.set_attr("ttg.num-warps", builder.get_int32_attr(self.num_warps))
+
+
 class GluonSemantic(TritonSemantic[TensorTy]):
     tensor = ttgl.tensor
     lang = ttgl
@@ -22,10 +35,15 @@ class GluonSemantic(TritonSemantic[TensorTy]):
     def __init__(self, builder: GluonOpBuilder):
         self.builder = builder
 
+    def _wrap_handle_infer_layout(self, handle, scalar_ty, shape):
+        if shape == []:
+            ty = scalar_ty
+        else:
+            ty = ttgl.distributed_type(scalar_ty, shape, self.builder.get_gluon_layout_from_tensor(handle))
+        return self.tensor(handle, ty)
+
     def _wrap_tensor_infer_layout(self, tensor):
-        ty = ttgl.distributed_type(tensor.type.scalar, tensor.shape,
-                                   self.builder.get_gluon_layout_from_tensor(tensor.handle))
-        return self.tensor(tensor.handle, ty)
+        return self._wrap_handle_infer_layout(tensor.handle, tensor.type.scalar, tensor.shape)
 
     def _broadcast_shapes(self, lhs_shape: List[int], rhs_shape: List[int]):
         if len(lhs_shape) != len(rhs_shape):
@@ -53,14 +71,14 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         _check(isinstance(input.type, ttgl.distributed_type),
                lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}")
         layout = input.type.layout
-        _check(isinstance(layout, SliceLayout),
+        _check(isinstance(layout, (SliceLayout, AutoLayout)),
                lambda: f"expected expand_dims input to have a SliceLayout, but got: {layout}")
-        _check(
-
+        _check(
+            isinstance(layout, AutoLayout) or layout.dim == axis,
+            lambda: f"expected expand_dims input layout to be sliced in axis {axis} but got {layout.dim}")
 
-
-
-        return self.tensor(handle, ret_ty)
+        handle = self.builder.create_expand_dims(input.handle, axis)
+        return self._wrap_handle_infer_layout(handle, input.type.scalar, dst_shape)
 
     def join(self, a: TensorTy, b: TensorTy) -> TensorTy:
         a, b = self.broadcast_impl_value(a, b)
@@ -107,7 +125,14 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         lhs_shape = lhs_ty.get_block_shapes()
         rhs_shape = rhs_ty.get_block_shapes()
         ret_shape = self._broadcast_shapes(lhs_shape, rhs_shape)
-        if lhs_ty.layout != rhs_ty.layout:
+
+        is_lhs_auto = isinstance(lhs_ty.layout, AutoLayout)
+        is_rhs_auto = isinstance(rhs_ty.layout, AutoLayout)
+        if is_lhs_auto and not is_rhs_auto:
+            lhs = self.set_auto_layout(lhs, rhs_ty.layout)
+        elif is_rhs_auto and not is_lhs_auto:
+            rhs = self.set_auto_layout(rhs, lhs_ty.layout)
+        elif lhs_ty.layout != rhs_ty.layout:
             raise ValueError(f"Layout mismatch in broadcast: {lhs_ty.layout} vs {rhs_ty.layout}")
 
         lhs = self.broadcast_impl_shape(lhs, ret_shape)
@@ -116,6 +141,8 @@ class GluonSemantic(TritonSemantic[TensorTy]):
 
     def arange(self, start, end, layout):
         shape = [end - start]
+        if layout is None:
+            layout = AutoLayout()
         ret_ty = ttgl.distributed_type(ttgl.int32, shape, layout)
         return super().arange(start, end, ret_ty=ret_ty)
 
@@ -131,14 +158,19 @@ class GluonSemantic(TritonSemantic[TensorTy]):
 
     def full(self, shape, value, dtype, layout):
         scalar = self.make_scalar(value, dtype)
+        if layout is None:
+            layout = AutoLayout()
         return self.splat(scalar, shape, layout)
 
-    def convert_layout(self, value, layout):
+    def convert_layout(self, value, layout, assert_trivial=False):
         ty = value.type
         _check(isinstance(ty, ttgl.distributed_type),
               lambda: f"expected convert_layout input to be a distributed_type but got: {ty!r}")
         ret_ty = ttgl.distributed_type(ty.element_ty, ty.shape, layout)
-        handle = self.builder.create_convert_layout(ret_ty.to_ir(self.builder), value.handle)
+        ret_ty_ir = ret_ty.to_ir(self.builder)
+        if assert_trivial and not self.builder.is_convert_layout_trivial(ret_ty_ir, value.handle):
+            raise TypeError(f"layout conversion from {ty.layout} to {layout} is not trivial")
+        handle = self.builder.create_convert_layout(ret_ty_ir, value.handle)
         return ttgl.tensor(handle, ret_ty)
 
     def allocate_shared(self, element_ty, shape, layout, value):
@@ -155,30 +187,42 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         return ttgl.tensor(handle, ret_ty)
 
     def shared_store(self, mem_desc, value):
+        assert value.shape == mem_desc.shape, f"source shape {value.shape} and destination shape {mem_desc.shape} must match"
+        assert value.dtype == mem_desc.dtype, f"source dtype {value.dtype} and destination dtype {mem_desc.dtype} must match"
         self.builder.create_local_store(mem_desc.handle, value.handle)
 
     def shared_dealloc(self, mem_desc):
         self.builder.create_local_dealloc(mem_desc.handle)
 
-    def
-
-
-
-
-
+    def set_auto_layout(self, value, layout):
+        src_ty = value.type
+        assert isinstance(layout,
+                          DistributedLayout), f"set_auto_layout must set to a distributed layout but got {layout}"
+        assert isinstance(src_ty.layout,
+                          AutoLayout), f"set_auto_layout input must have auto layout but got {value.type.layout}"
+        handle = self.builder.create_set_auto_layout(layout._to_ir(self.builder), value.handle)
+        res_ty = ttgl.distributed_type(src_ty.element_ty, src_ty.shape, layout)
+        return self.tensor(handle, res_ty)
 
     def memdesc_slice(self, mem_desc, start, length, dim):
-        offsets = [
-        offsets[dim] =
+        offsets = [0] * mem_desc.rank
+        offsets[dim] = start
         shape = list(mem_desc.shape)
         shape[dim] = length
-
+        layout = mem_desc.layout
+        ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
+        builder = self.builder
+        handle = builder.create_memdesc_subslice(ty.to_ir(builder), mem_desc.handle, offsets)
+        return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
 
     def memdesc_index(self, mem_desc, index):
         shape = mem_desc.shape[1:]
-
-
-
+        index = self.to_tensor(index).handle
+        layout = mem_desc.layout
+        ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
+        builder = self.builder
+        handle = builder.create_memdesc_index(ty.to_ir(builder), mem_desc.handle, index)
+        return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
 
     def memdesc_trans(self, mem_desc, order):
         assert len(order) == len(
@@ -194,10 +238,26 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         return ttgl.shared_memory_descriptor(handle, element_ty=mem_desc.dtype, shape=shape,
                                               alloc_shape=new_alloc_shape, layout=layout)
 
-    def memdesc_reshape(self, mem_desc, shape
-
-
-
+    def memdesc_reshape(self, mem_desc, shape):
+        _check(
+            math.prod(shape) == math.prod(mem_desc.shape),
+            lambda: (f"memdesc_reshape total elements mismatch: "
+                     f"{mem_desc.shape} -> {shape}"),
+        )
+
+        handle = self.builder.create_memdesc_reshape(mem_desc.handle, shape)
+        layout = self.builder.get_gluon_layout_from_memdesc(handle)
+        alloc_shape = mem_desc.type.alloc_shape
+        prefix_len = len(alloc_shape) - mem_desc.rank
+        new_alloc_shape = alloc_shape[:prefix_len] + list(shape)
+
+        return ttgl.shared_memory_descriptor(
+            handle,
+            element_ty=mem_desc.dtype,
+            shape=shape,
+            alloc_shape=new_alloc_shape,
+            layout=layout,
+        )
 
     def memdesc_reinterpret(self, mem_desc, dtype, shape, layout):
         ty = ttgl.shared_memory_descriptor_type(dtype, shape, layout, shape)
@@ -220,6 +280,27 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         _check(all(l == l0 for l in layouts[1:]),
                lambda: f"Expected inputs to have matching layouts, but got: {layouts}")
 
+    def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn,
+                         reverse: bool) -> Tuple[TensorTy, ...]:
+        shape = inputs[0].type.shape
+        rank = len(shape)
+
+        assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})"
+
+        if axis < 0:
+            axis += rank
+
+        for t in inputs:
+            assert t.type.shape == shape, "all scan inputs must have the same shape"
+
+        scan_op = self.builder.create_scan([t.handle for t in inputs], axis, reverse)
+        region_builder_fn(scan_op)
+        assert scan_op.verify()
+
+        return tuple(
+            self._wrap_handle_infer_layout(scan_op.get_result(i), inputs[i].type.scalar, shape)
+            for i in range(len(inputs)))
+
     def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]:
         _check(axis is not None, lambda: "All-reduce is not yet implemented in gluon")
         # get result shape
@@ -228,7 +309,6 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         _check(0 <= axis < rank, lambda: f"expected reduction axis to be in the range [0, {rank}) but got {axis}")
         self._check_same_layout(inputs)
         ret_shape = [s for i, s in enumerate(shape) if i != axis]
-        ret_layout = SliceLayout(axis, inputs[0].type.layout)
         assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape"
 
         reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis)
@@ -236,11 +316,23 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         assert reduce_op.verify()
 
         return tuple(
-            self.
+            self._wrap_handle_infer_layout(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape)
             for i in range(len(inputs)))
 
-    def
-
+    def histogram(self, input: TensorTy, num_bins: int, mask: TensorTy, layout) -> TensorTy:
+        _check(len(input.shape) == 1, lambda: "histogram only supports 1D input")
+        _check(input.dtype.is_int(), lambda: "histogram only supports integer input")
+        _check(layout is not None, lambda: "histogram requires a destination layout")
+        if mask is not None:
+            mask, input = self.broadcast_impl_value(mask, input)
+            _check(mask.type.scalar.is_bool(), lambda: "Mask must have boolean scalar type")
+            mask = mask.handle
+        layout_attr = layout._to_ir(self.builder)
+        handle = self.builder.create_histogram(input.handle, num_bins, mask, layout_attr)
+        return self.wrap_tensor(handle, ttgl.int32, [num_bins], layout)
+
+    def warp_specialize(self, default_args, default_partition, worker_args, worker_partitions,
+                        worker_num_warps: Sequence[int], worker_num_regs: Sequence[int], generator):
         num_partitions = len(worker_partitions)
         assert num_partitions == len(
             worker_num_warps
@@ -255,7 +347,7 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         # Emit the default partition to get the result types.
         default_block = builder.new_block()
         builder.set_insertion_point_to_start(default_block)
-        default_results = generator.call_JitFunction(default_partition,
+        default_results = generator.call_JitFunction(default_partition, default_args, kwargs={})
         mlir_results = []
         if default_results is not None:
             mlir_results = flatten_values_to_ir(default_results)
@@ -264,7 +356,7 @@ class GluonSemantic(TritonSemantic[TensorTy]):
 
         # Create the warp specialize op.
         builder.restore_insertion_point(insert_pt)
-        mlir_args = flatten_values_to_ir(
+        mlir_args = flatten_values_to_ir(worker_args)
         ws_op = builder.create_warp_specialize(result_types, mlir_args, worker_num_warps)
         ws_op.get_default_region().push_back(default_block)
         ws_op.set_requested_registers(worker_num_regs)
@@ -274,10 +366,11 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         partitions_op = builder.create_warp_specialize_partitions(num_partitions)
         arg_types = [arg.get_type() for arg in mlir_args]
         for i in range(num_partitions):
+            caller_context = GluonCallerContext(num_warps=worker_num_warps[i])
             block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types)
             block_args = [block.get_argument(j) for j in range(len(mlir_args))]
-            block_args = unflatten_ir_values(block_args, [arg.type for arg in
-            generator.call_JitFunction(worker_partitions[i], block_args, kwargs={})
+            block_args = unflatten_ir_values(block_args, [arg.type for arg in worker_args])
+            generator.call_JitFunction(worker_partitions[i], block_args, kwargs={}, caller_context=caller_context)
             builder.create_warp_return()
 
         builder.set_insertion_point_after(ws_op.get_operation())
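The `_semantic.py` changes above center on the new AutoLayout: `arange` and `full` now fall back to `AutoLayout()` when no layout is given, and broadcasting resolves a lone AutoLayout operand by adopting the other operand's concrete layout via `set_auto_layout`. Below is a minimal standalone sketch of that resolution rule using plain-Python stand-in classes (the `AutoLayout`/`BlockedLayout` definitions here are illustrative only, not the real Gluon layout types):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AutoLayout:
    """Stand-in for Gluon's AutoLayout marker (illustrative only)."""


@dataclass(frozen=True)
class BlockedLayout:
    """Stand-in for any concrete DistributedLayout (illustrative only)."""
    name: str


def resolve_broadcast_layouts(lhs_layout, rhs_layout):
    # Mirrors the elif-chain added to the broadcast path in this diff:
    # a lone AutoLayout operand adopts the other side's concrete layout,
    # otherwise the two concrete layouts must already be equal.
    is_lhs_auto = isinstance(lhs_layout, AutoLayout)
    is_rhs_auto = isinstance(rhs_layout, AutoLayout)
    if is_lhs_auto and not is_rhs_auto:
        return rhs_layout, rhs_layout
    if is_rhs_auto and not is_lhs_auto:
        return lhs_layout, lhs_layout
    if lhs_layout != rhs_layout:
        raise ValueError(f"Layout mismatch in broadcast: {lhs_layout} vs {rhs_layout}")
    return lhs_layout, rhs_layout


print(resolve_broadcast_layouts(AutoLayout(), BlockedLayout("blocked")))
# (BlockedLayout(name='blocked'), BlockedLayout(name='blocked'))
```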
triton/experimental/gluon/language/_standard.py

@@ -1,39 +1,60 @@
-
-import
+from typing import TypeVar
+from triton.runtime.jit import JITFunction
 import triton.language.standard as tl_standard
-from .._runtime import jit
+from .._runtime import GluonJITFunction, jit
 from triton import knobs
 from . import _core as ttgl
 
-_IMPORT_FROM_TRITON = [
-    "sum",
-    "max",
-    "min",
-    "reduce_or",
-    "xor_sum",
-]
+T = TypeVar("T")
 
-__all__ = [
-    "full_like",
-    "zeros",
-    "zeros_like",
-    *_IMPORT_FROM_TRITON,
-]
 
-
-
-
-
-
+def _import_from_triton(fn: JITFunction[T]) -> GluonJITFunction[T]:
+    assert knobs.runtime.interpret or isinstance(fn, JITFunction)
+    # Wrap the function and preserve its original docstring
+    gluon_fn = jit(fn.fn)
+    gluon_fn.__doc__ = fn.__doc__
+    return gluon_fn
+
+
+cdiv = _import_from_triton(tl_standard.cdiv)
+sum = _import_from_triton(tl_standard.sum)
+max = _import_from_triton(tl_standard.max)
+min = _import_from_triton(tl_standard.min)
+reduce_or = _import_from_triton(tl_standard.reduce_or)
+xor_sum = _import_from_triton(tl_standard.xor_sum)
 
 
 @jit
-def zeros(shape, dtype, layout):
+def zeros(shape, dtype, layout=None):
+    """
+    Create a tensor filled with zeros.
+
+    Args:
+        shape (Sequence[int]): The shape of the tensor.
+        dtype (dtype): The data type for the tensor.
+        layout (Optional[DistributedLayout]): The distributed layout of the tensor, defaults to AutoLayout().
+
+    Returns:
+        tensor: A tensor where every element is zero.
+    """
     return ttgl.full(shape, 0, dtype, layout)
 
 
 @jit
 def full_like(input, value, shape=None, dtype=None, layout=None):
+    """
+    Create a tensor with the same properties as a given tensor, filled with a specified value.
+
+    Args:
+        input (tensor): Reference tensor to infer default shape, dtype, and layout.
+        value (int or float): The fill value.
+        shape (Sequence[int], optional): Target shape. Defaults to input.shape.
+        dtype (dtype, optional): Target data type. Defaults to input.dtype.
+        layout (DistributedLayout, optional): Target layout. Defaults to input.layout.
+
+    Returns:
+        tensor: A tensor where every element equals value.
+    """
     return ttgl.full(
         input.shape if shape is None else shape,
         value,
@@ -44,4 +65,16 @@ def full_like(input, value, shape=None, dtype=None, layout=None):
 
 @jit
 def zeros_like(input, shape=None, dtype=None, layout=None):
+    """
+    Create a tensor with the same properties as a given tensor, filled with zeros.
+
+    Args:
+        input (tensor): Reference tensor to infer default shape, dtype, and layout.
+        shape (Sequence[int], optional): Target shape. Defaults to input.shape.
+        dtype (dtype, optional): Target data type. Defaults to input.dtype.
+        layout (DistributedLayout, optional): Target layout. Defaults to input.layout.
+
+    Returns:
+        tensor: A tensor where every element is zero.
+    """
     return full_like(input, 0, shape=shape, dtype=dtype, layout=layout)
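With `layout` now optional, the standard helpers can be called without committing to a layout up front. The fragment below is a hypothetical sketch: it assumes `gluon.jit`, `gl.constexpr`, `gl.BlockedLayout` (with the parameter names used by the 3.4 Gluon API), and that `zeros`/`full_like`/`zeros_like` remain reachable from `triton.experimental.gluon.language`; none of that re-export plumbing is shown in this excerpt.

```python
from triton.experimental import gluon
from triton.experimental.gluon import language as gl


@gluon.jit
def fill_demo_kernel():
    # Concrete layout spelled out explicitly; the parameter values are illustrative.
    layout: gl.constexpr = gl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32],
                                            warps_per_cta=[4], order=[0])
    a = gl.zeros([128], gl.float32, layout)   # explicit layout, as before
    b = gl.zeros([128], gl.float32)           # new: layout defaults to AutoLayout()
    c = gl.full_like(a, 3.0)                  # shape/dtype/layout inherited from a
    d = gl.zeros_like(b, dtype=gl.int32)      # override only the dtype
```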
triton/experimental/gluon/language/amd/_layouts.py (new file)

@@ -0,0 +1,96 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List, Optional
+from triton.language.core import _unwrap_if_constexpr
+
+from triton.experimental.gluon.language._layouts import _realize_cta_layout, DistributedLayout
+from triton.experimental.gluon import language as ttgl
+
+__all__ = [
+    "AMDMFMALayout",
+]
+
+
+@dataclass(frozen=True)
+class AMDMFMALayout(DistributedLayout):
+    """
+    Represents a layout for AMD MFMA (matrix core) operations.
+
+    Args:
+        version (int): Major and minor identifier for the MFMA instruction.
+        instr_shape: (M, N) dimension for the instrinsic shape.
+        transposed (bool): indicates the result tensor is transposed so that each thread holds consecutive elements in the same row instead of column, which is good for chained dot and global write.
+        warps_per_cta (List[int]): Number of warps per CTA.
+        elem_type Optional(ttgl.dtype): Supported types are int32, fp32 and fp64. Default is fp32.
+        tiles_per_warp Optional(List[int]): Number of tiles per WARP. For mfma layout, if missing, use the default where we have unit tile size on all dimensions.
+        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
+        cta_split_num (Optional[List[int]]): Split factors for CTAs.
+        cta_order (Optional[List[int]]): CTA ordering.
+    """
+    version: int
+    instr_shape: List[int]
+    transposed: bool
+    warps_per_cta: List[int]
+    elem_type: ttgl.dtype = ttgl.float32
+    tiles_per_warp: Optional[List[int]] = None
+    ctas_per_cga: Optional[List[int]] = None
+    cta_split_num: Optional[List[int]] = None
+    cta_order: Optional[List[int]] = None
+
+    def __post_init__(self):
+        super().__setattr__("version", _unwrap_if_constexpr(self.version))
+        super().__setattr__("instr_shape", _unwrap_if_constexpr(self.instr_shape))
+        super().__setattr__("transposed", _unwrap_if_constexpr(self.transposed))
+        super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta))
+        super().__setattr__("tiles_per_warp", _unwrap_if_constexpr(self.tiles_per_warp))
+        super().__setattr__("elem_type", _unwrap_if_constexpr(self.elem_type))
+        super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga))
+        super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num))
+        super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))
+
+        if self.tiles_per_warp is None:
+            object.__setattr__(self, "tiles_per_warp", [1] * len(self.warps_per_cta))
+
+        self.verify()
+
+    def _to_ir(self, builder):
+        type = self.elem_type.to_ir(builder)
+        return builder.get_amd_mfma_layout(self.version, self.instr_shape, self.transposed, self.warps_per_cta, type,
+                                           self.tiles_per_warp, self.ctas_per_cga, self.cta_split_num, self.cta_order)
+
+    def mangle(self) -> str:
+
+        def stringify(x):
+            if x is None:
+                return ""
+            return "_".join(map(str, x))
+
+        return f"MFMA_{self.version}_{stringify(self.instr_shape)}_{self.transposed}_{stringify(self.warps_per_cta)}_{stringify(self.tiles_per_warp)}_{self.elem_type}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_MFMA"
+
+    def verify(self):
+        assert self.version >= 1 and self.version <= 4, "version must be in the [1, 4] range"
+        valid_shapes = [[32, 32], [16, 16], [64, 4], [4, 64]]
+        assert self.instr_shape in valid_shapes, "invalid intrinsic shape; accepted shapes are " + str(valid_shapes)
+
+        assert self.elem_type.is_fp32() or self.elem_type.is_fp64() \
+            or self.elem_type.is_int32() , "element type must be float32, float64, or int32"
+
+        rank = len(self.warps_per_cta)
+        _realize_cta_layout(self, rank)
+        assert len(self.ctas_per_cga) == rank
+        assert len(self.cta_split_num) == rank
+        assert len(self.cta_order) == rank
+
+    def __hash__(self):
+        return hash((
+            self.version,
+            tuple(self.instr_shape),
+            self.transposed,
+            tuple(self.warps_per_cta),
+            self.elem_type,
+            tuple(self.tiles_per_warp) if self.tiles_per_warp else None,
+            tuple(self.ctas_per_cga) if self.ctas_per_cga else None,
+            tuple(self.cta_split_num) if self.cta_split_num else None,
+            tuple(self.cta_order) if self.cta_order else None,
+        ))
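A small construction sketch for the new layout class follows. The field values are illustrative, and the import goes straight to the new module rather than assuming a re-export from `amd/__init__.py` (whose 4-line change is listed above but not shown here); `verify()` enforces the version range and intrinsic shapes documented in the docstring.

```python
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.amd._layouts import AMDMFMALayout

# 32x32 MFMA intrinsic, transposed result, 2x2 warps per CTA, fp32 element type.
mfma_layout = AMDMFMALayout(version=4, instr_shape=[32, 32], transposed=True,
                            warps_per_cta=[2, 2], elem_type=ttgl.float32)
print(mfma_layout.mangle())
```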
triton/experimental/gluon/language/amd/cdna3/__init__.py (new file)

@@ -0,0 +1,100 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+from triton import knobs
+from triton.experimental.gluon.language import _core as ttgl
+from triton._C.libtriton import ir
+from ..._core import builtin, _unwrap_if_constexpr
+
+if TYPE_CHECKING:
+    from ..._semantic import GluonSemantic
+
+__all__ = ["buffer_load", "buffer_store", "mfma"]
+
+
+def _verify_buffer_ops(ptr, offsets, mask=None, other=None):
+    assert ptr.type.is_ptr(), "ptr must be a scalar pointer type"
+
+    assert isinstance(offsets.type, ttgl.distributed_type), "expected offsets type to be a distributed_type"
+    assert offsets.dtype.is_int32() or offsets.dtype.is_uint32(), "offsets element type must be int32 or uint32"
+
+    element_type = ptr.type.scalar.element_ty
+
+    if other is not None:
+        assert mask is not None, "when other is not None, mask should not be None"
+        assert other.dtype == element_type, "other must have the same data type as ptr scalar type"
+
+
+@builtin
+def buffer_load(ptr, offsets, mask=None, other=None, cache=None, _semantic=None):
+    """
+    AMD buffer load from global memory via a scalar base pointer and a tensor of
+    offsets instead of a tensor of pointers. This operation will load data
+    directly into registers.
+
+    Args:
+        ptr (pointer to scalar): Global memory scalar base pointer to load from.
+        offsets (tensor): Offsets tensor for the load operation.
+        mask (tensor, optional): Mask tensor for predicated loads. Defaults to None.
+        other (tensor, optional): Tensor providing default values for masked elements. Defaults to None.
+        cache_modifier (str): Cache modifier specifier. Defaults to "".
+    """
+    _verify_buffer_ops(ptr, offsets, mask, other)
+
+    mask = _unwrap_if_constexpr(mask)
+    if mask is not None:
+        offsets, mask = _semantic.broadcast_impl_value(offsets, mask)
+
+    other = _unwrap_if_constexpr(other)
+    if other is not None:
+        offsets, other = _semantic.broadcast_impl_value(offsets, other)
+
+    other = other.handle if other is not None else ir.value()
+    mask = mask.handle if mask is not None else ir.value()
+    cache_modifier = _semantic._str_to_load_cache_modifier(cache) if cache is not None else ir.CACHE_MODIFIER.NONE
+
+    ret_ty = offsets.type.with_element_ty(ptr.type.scalar.element_ty)
+    builder = _semantic.builder
+    handle = builder.create_buffer_load(ret_ty.to_ir(builder), ptr.handle, offsets.handle, mask, other, cache_modifier)
+    return ttgl.tensor(handle, ret_ty)
+
+
+@builtin
+def buffer_store(stored_value, ptr, offsets, mask=None, cache=None, _semantic: GluonSemantic = None):
+    """
+    AMD buffer store a tensor directly to global memory via a scalar base pointer and a tensor of
+    offsets instead of a tensor of pointers.
+    Args:
+        stored_value (tensor to be stored): The tensor to be stored to global memory.
+        ptr (pointer to scalar): Global memory scalar base pointer to store to.
+        offsets (tensor): Offsets tensor for the store operation.
+        mask (tensor, optional): Mask tensor for predicated store. Defaults to None.
+        cache_modifier (str): Cache modifier specifier. Defaults to "".
+    """
+    _verify_buffer_ops(ptr, offsets, mask)
+
+    if mask is not None:
+        offsets, mask = _semantic.broadcast_impl_value(offsets, mask)
+
+    mask = mask.handle if mask is not None else ir.value()
+    cache_modifier = _semantic._str_to_store_cache_modifier(cache) if cache is not None else ir.CACHE_MODIFIER.NONE
+
+    _semantic.builder.create_buffer_store(stored_value.handle, ptr.handle, offsets.handle, mask, cache_modifier)
+
+
+@builtin
+def mfma(a, b, acc, _semantic: GluonSemantic = None):
+    """
+    Computes matrix-multiplication of a * b + acc using AMD native matrix core units.
+    Args:
+        a (tensor): The first operand of mfma.
+        b (tensor): The second operand of mfma.
+        acc (tensor): The accumulator tensor.
+    """
+    assert acc is not None, "acc is required"
+    ret_type = acc.type
+    acc = ttgl._unwrap_if_constexpr(acc)
+
+    handle = _semantic.dot(a, b, acc, input_precision=knobs.language.fp32_default, max_num_imprecise_acc=None,
+                           out_dtype=acc.dtype).handle
+    return ttgl.tensor(handle, ret_type)
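A hypothetical CDNA3 copy kernel using the two buffer ops above is sketched below. The `buffer_load`/`buffer_store` signatures come from this diff; `gluon.jit`, `gl.program_id`, `gl.arange`, `gl.constexpr`, and the `gl.BlockedLayout` parameter names are assumed from the existing Gluon API and are not shown in this excerpt.

```python
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.language.amd import cdna3


@gluon.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: gl.constexpr):
    # Wave64 blocked layout; the parameter values are illustrative.
    layout: gl.constexpr = gl.BlockedLayout(size_per_thread=[1], threads_per_warp=[64],
                                            warps_per_cta=[4], order=[0])
    pid = gl.program_id(0)
    offsets = pid * BLOCK + gl.arange(0, BLOCK, layout=layout)
    mask = offsets < n
    # Scalar base pointer plus an int32 offsets tensor, loaded straight into registers.
    vals = cdna3.buffer_load(ptr=src_ptr, offsets=offsets, mask=mask)
    cdna3.buffer_store(stored_value=vals, ptr=dst_ptr, offsets=offsets, mask=mask)
```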
triton/experimental/gluon/language/amd/cdna4/__init__.py (new file)

@@ -0,0 +1,48 @@
+from triton.experimental.gluon.language import _core as ttgl
+from ..._core import builtin, float32
+from ..._layouts import DotOperandLayout
+from .._layouts import AMDMFMALayout
+from ..cdna3 import * # NOQA: F403
+from ..cdna3 import __all__ as __cdna3_all
+from . import async_copy
+
+__all__ = [*__cdna3_all, "async_copy", "mfma_scaled"]
+
+
+@builtin
+def mfma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None):
+    """
+    AMD Scaled MFMA operation.
+
+    ```
+    c = a * a_scale @ b * b_scale + acc
+    ```
+
+    `a` and `b` use microscaling formats described in
+    "OCP Microscaling Formats (MX) Specification":
+    https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf.
+    Currently supported only on CDNA4 hardware.
+
+    Args:
+        a (tensor): The operand A to be multiplied.
+        a_scale (tensor): Scale factor for operand A.
+        a_format (str): Format of the operand A. Available formats: `e2m1`, `e4m3`, `e5m2`.
+        b (tensor): The operand B to be multiplied.
+        b_scale (tensor): Scale factor for operand B. Available formats: `e2m1`, `e4m3`, `e5m2`.
+        b_format (str): Format of the operand B.
+        acc (tensor): Accumulator tensor.
+    """
+    layout = acc.type.layout
+    assert isinstance(layout, AMDMFMALayout), "Expected layout to be an instance of AMDMFMALayout"
+    assert (isinstance(a.type.layout, DotOperandLayout) and a.type.layout.parent== layout), \
+        "Expected lhs layout to be a DotOperandLayout with parent matching MFMA layout"
+    assert (isinstance(b.type.layout, DotOperandLayout) and b.type.layout.parent == layout), \
+        "Expected rhs layout to be a DotOperandLayout with parent matching MFMA layout"
+
+    assert a_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported lhs_format: {a_format.value}"
+    assert b_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported rhs_format: {b_format.value}"
+
+    tensor = _semantic.dot_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, False, True, True, float32)
+
+    ret_ty = ttgl.distributed_type(tensor.dtype, tensor.shape, layout)
+    return ttgl.tensor(tensor.handle, ret_ty)