triton-windows 3.4.0.post20__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.
Files changed (107)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +8 -2
  3. triton/_filecheck.py +24 -14
  4. triton/_internal_testing.py +70 -4
  5. triton/_utils.py +3 -1
  6. triton/backends/amd/compiler.py +68 -60
  7. triton/backends/amd/driver.c +113 -44
  8. triton/backends/amd/driver.py +133 -57
  9. triton/backends/driver.py +13 -0
  10. triton/backends/nvidia/compiler.py +80 -22
  11. triton/backends/nvidia/driver.c +88 -15
  12. triton/backends/nvidia/driver.py +130 -123
  13. triton/compiler/__init__.py +5 -2
  14. triton/compiler/code_generator.py +270 -163
  15. triton/compiler/compiler.py +45 -62
  16. triton/experimental/gluon/__init__.py +3 -2
  17. triton/experimental/gluon/_runtime.py +9 -6
  18. triton/experimental/gluon/language/__init__.py +117 -16
  19. triton/experimental/gluon/language/_core.py +246 -68
  20. triton/experimental/gluon/language/_layouts.py +398 -45
  21. triton/experimental/gluon/language/_math.py +17 -9
  22. triton/experimental/gluon/language/_semantic.py +130 -37
  23. triton/experimental/gluon/language/_standard.py +55 -22
  24. triton/experimental/gluon/language/amd/__init__.py +4 -0
  25. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  26. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  27. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  28. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  29. triton/experimental/gluon/language/extra/__init__.py +3 -0
  30. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  31. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  32. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  33. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
  34. triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
  35. triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
  36. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
  37. triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
  38. triton/experimental/gluon/nvidia/hopper.py +6 -1
  39. triton/knobs.py +132 -67
  40. triton/language/__init__.py +16 -10
  41. triton/language/core.py +163 -83
  42. triton/language/extra/cuda/gdc.py +6 -6
  43. triton/language/extra/hip/__init__.py +3 -1
  44. triton/language/extra/hip/libdevice.py +7 -0
  45. triton/language/extra/hip/utils.py +35 -0
  46. triton/language/extra/libdevice.py +4 -0
  47. triton/language/semantic.py +76 -23
  48. triton/language/standard.py +14 -14
  49. triton/language/target_info.py +54 -0
  50. triton/runtime/_allocation.py +15 -3
  51. triton/runtime/_async_compile.py +55 -0
  52. triton/runtime/autotuner.py +4 -5
  53. triton/runtime/build.py +11 -9
  54. triton/runtime/cache.py +44 -1
  55. triton/runtime/driver.py +16 -41
  56. triton/runtime/interpreter.py +31 -23
  57. triton/runtime/jit.py +318 -157
  58. triton/runtime/tcc/include/_mingw.h +8 -10
  59. triton/runtime/tcc/include/assert.h +5 -0
  60. triton/runtime/tcc/include/errno.h +1 -1
  61. triton/runtime/tcc/include/float.h +21 -3
  62. triton/runtime/tcc/include/iso646.h +36 -0
  63. triton/runtime/tcc/include/limits.h +5 -0
  64. triton/runtime/tcc/include/malloc.h +2 -2
  65. triton/runtime/tcc/include/math.h +21 -261
  66. triton/runtime/tcc/include/stdalign.h +16 -0
  67. triton/runtime/tcc/include/stdarg.h +5 -70
  68. triton/runtime/tcc/include/stdatomic.h +171 -0
  69. triton/runtime/tcc/include/stddef.h +7 -19
  70. triton/runtime/tcc/include/stdlib.h +15 -4
  71. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  72. triton/runtime/tcc/include/sys/stat.h +2 -2
  73. triton/runtime/tcc/include/sys/types.h +5 -0
  74. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  75. triton/runtime/tcc/include/tccdefs.h +342 -0
  76. triton/runtime/tcc/include/tgmath.h +89 -0
  77. triton/runtime/tcc/include/uchar.h +33 -0
  78. triton/runtime/tcc/include/unistd.h +1 -0
  79. triton/runtime/tcc/include/winapi/qos.h +72 -0
  80. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  81. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  82. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  83. triton/runtime/tcc/include/winapi/windows.h +1 -1
  84. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  85. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  86. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  87. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  88. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  89. triton/runtime/tcc/lib/libtcc1.a +0 -0
  90. triton/runtime/tcc/lib/python314.def +1800 -0
  91. triton/runtime/tcc/lib/python314t.def +1809 -0
  92. triton/runtime/tcc/libtcc.dll +0 -0
  93. triton/runtime/tcc/tcc.exe +0 -0
  94. triton/tools/compile.py +62 -14
  95. triton/tools/extra/cuda/compile.c +1 -0
  96. triton/tools/extra/hip/compile.cpp +66 -0
  97. triton/tools/extra/hip/compile.h +13 -0
  98. triton/tools/ragged_tma.py +92 -0
  99. triton/tools/tensor_descriptor.py +7 -9
  100. triton/windows_utils.py +42 -79
  101. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
  102. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
  103. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  104. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
  105. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
  106. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
  107. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0

triton/experimental/gluon/language/_semantic.py
@@ -1,7 +1,8 @@
 from typing import Sequence, List, TypeVar, Tuple, Callable
+import math
 from triton.language.semantic import TritonSemantic
 from . import _core as ttgl
-from ._layouts import SliceLayout
+from ._layouts import AutoLayout, DistributedLayout, SliceLayout
 from triton._C.libtriton.gluon_ir import GluonOpBuilder
 from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values
 
@@ -13,6 +14,18 @@ def _check(cond: bool, msg_fn: Callable[[], str], category=ValueError):
         raise category(msg_fn())
 
 
+class GluonCallerContext:
+
+    def __init__(self, num_warps: int):
+        self.num_warps = num_warps
+
+    def mangle(self):
+        return f"_NW{self.num_warps}"
+
+    def initialize_callee(self, fn, builder):
+        fn.set_attr("ttg.num-warps", builder.get_int32_attr(self.num_warps))
+
+
 class GluonSemantic(TritonSemantic[TensorTy]):
     tensor = ttgl.tensor
     lang = ttgl
@@ -22,10 +35,15 @@ class GluonSemantic(TritonSemantic[TensorTy]):
     def __init__(self, builder: GluonOpBuilder):
         self.builder = builder
 
+    def _wrap_handle_infer_layout(self, handle, scalar_ty, shape):
+        if shape == []:
+            ty = scalar_ty
+        else:
+            ty = ttgl.distributed_type(scalar_ty, shape, self.builder.get_gluon_layout_from_tensor(handle))
+        return self.tensor(handle, ty)
+
     def _wrap_tensor_infer_layout(self, tensor):
-        ty = ttgl.distributed_type(tensor.type.scalar, tensor.shape,
-                                   self.builder.get_gluon_layout_from_tensor(tensor.handle))
-        return self.tensor(tensor.handle, ty)
+        return self._wrap_handle_infer_layout(tensor.handle, tensor.type.scalar, tensor.shape)
 
     def _broadcast_shapes(self, lhs_shape: List[int], rhs_shape: List[int]):
         if len(lhs_shape) != len(rhs_shape):
@@ -53,14 +71,14 @@ class GluonSemantic(TritonSemantic[TensorTy]):
        _check(isinstance(input.type, ttgl.distributed_type),
               lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}")
        layout = input.type.layout
-       _check(isinstance(layout, SliceLayout),
+       _check(isinstance(layout, (SliceLayout, AutoLayout)),
              lambda: f"expected expand_dims input to have a SliceLayout, but got: {layout}")
-       _check(layout.dim == axis,
-             lambda: f"expected expand_dims input layout to be sliced in axis {axis} but got {layout.dim}")
+       _check(
+           isinstance(layout, AutoLayout) or layout.dim == axis,
+           lambda: f"expected expand_dims input layout to be sliced in axis {axis} but got {layout.dim}")
 
-       ret_ty = ttgl.distributed_type(input.type.scalar, dst_shape, layout.parent)
-       handle = self.builder.create_expand_dims(input.handle, axis, ret_ty.to_ir(self.builder))
-       return self.tensor(handle, ret_ty)
+       handle = self.builder.create_expand_dims(input.handle, axis)
+       return self._wrap_handle_infer_layout(handle, input.type.scalar, dst_shape)
 
     def join(self, a: TensorTy, b: TensorTy) -> TensorTy:
         a, b = self.broadcast_impl_value(a, b)
@@ -107,7 +125,14 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         lhs_shape = lhs_ty.get_block_shapes()
         rhs_shape = rhs_ty.get_block_shapes()
         ret_shape = self._broadcast_shapes(lhs_shape, rhs_shape)
-        if lhs_ty.layout != rhs_ty.layout:
+
+        is_lhs_auto = isinstance(lhs_ty.layout, AutoLayout)
+        is_rhs_auto = isinstance(rhs_ty.layout, AutoLayout)
+        if is_lhs_auto and not is_rhs_auto:
+            lhs = self.set_auto_layout(lhs, rhs_ty.layout)
+        elif is_rhs_auto and not is_lhs_auto:
+            rhs = self.set_auto_layout(rhs, lhs_ty.layout)
+        elif lhs_ty.layout != rhs_ty.layout:
             raise ValueError(f"Layout mismatch in broadcast: {lhs_ty.layout} vs {rhs_ty.layout}")
 
         lhs = self.broadcast_impl_shape(lhs, ret_shape)
@@ -116,6 +141,8 @@ class GluonSemantic(TritonSemantic[TensorTy]):
 
     def arange(self, start, end, layout):
         shape = [end - start]
+        if layout is None:
+            layout = AutoLayout()
         ret_ty = ttgl.distributed_type(ttgl.int32, shape, layout)
         return super().arange(start, end, ret_ty=ret_ty)
 
@@ -131,14 +158,19 @@ class GluonSemantic(TritonSemantic[TensorTy]):
 
     def full(self, shape, value, dtype, layout):
         scalar = self.make_scalar(value, dtype)
+        if layout is None:
+            layout = AutoLayout()
         return self.splat(scalar, shape, layout)
 
-    def convert_layout(self, value, layout):
+    def convert_layout(self, value, layout, assert_trivial=False):
         ty = value.type
         _check(isinstance(ty, ttgl.distributed_type),
                lambda: f"expected convert_layout input to be a distributed_type but got: {ty!r}")
         ret_ty = ttgl.distributed_type(ty.element_ty, ty.shape, layout)
-        handle = self.builder.create_convert_layout(ret_ty.to_ir(self.builder), value.handle)
+        ret_ty_ir = ret_ty.to_ir(self.builder)
+        if assert_trivial and not self.builder.is_convert_layout_trivial(ret_ty_ir, value.handle):
+            raise TypeError(f"layout conversion from {ty.layout} to {layout} is not trivial")
+        handle = self.builder.create_convert_layout(ret_ty_ir, value.handle)
         return ttgl.tensor(handle, ret_ty)
 
     def allocate_shared(self, element_ty, shape, layout, value):
@@ -155,30 +187,42 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         return ttgl.tensor(handle, ret_ty)
 
     def shared_store(self, mem_desc, value):
+        assert value.shape == mem_desc.shape, f"source shape {value.shape} and destination shape {mem_desc.shape} must match"
+        assert value.dtype == mem_desc.dtype, f"source dtype {value.dtype} and destination dtype {mem_desc.dtype} must match"
        self.builder.create_local_store(mem_desc.handle, value.handle)
 
     def shared_dealloc(self, mem_desc):
         self.builder.create_local_dealloc(mem_desc.handle)
 
-    def _memdesc_subview(self, mem_desc, offsets, shape):
-        layout = mem_desc.layout
-        ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
-        builder = self.builder
-        handle = builder.create_memdesc_subview(ty.to_ir(builder), mem_desc.handle, offsets)
-        return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
+    def set_auto_layout(self, value, layout):
+        src_ty = value.type
+        assert isinstance(layout,
+                          DistributedLayout), f"set_auto_layout must set to a distributed layout but got {layout}"
+        assert isinstance(src_ty.layout,
+                          AutoLayout), f"set_auto_layout input must have auto layout but got {value.type.layout}"
+        handle = self.builder.create_set_auto_layout(layout._to_ir(self.builder), value.handle)
+        res_ty = ttgl.distributed_type(src_ty.element_ty, src_ty.shape, layout)
+        return self.tensor(handle, res_ty)
 
     def memdesc_slice(self, mem_desc, start, length, dim):
-        offsets = [self.builder.get_int32(0)] * mem_desc.rank
-        offsets[dim] = self.to_tensor(start).handle
+        offsets = [0] * mem_desc.rank
+        offsets[dim] = start
         shape = list(mem_desc.shape)
         shape[dim] = length
-        return self._memdesc_subview(mem_desc, offsets, shape)
+        layout = mem_desc.layout
+        ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
+        builder = self.builder
+        handle = builder.create_memdesc_subslice(ty.to_ir(builder), mem_desc.handle, offsets)
+        return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
 
     def memdesc_index(self, mem_desc, index):
         shape = mem_desc.shape[1:]
-        offsets = [self.builder.get_int32(0)] * mem_desc.rank
-        offsets[0] = self.to_tensor(index).handle
-        return self._memdesc_subview(mem_desc, offsets, shape)
+        index = self.to_tensor(index).handle
+        layout = mem_desc.layout
+        ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
+        builder = self.builder
+        handle = builder.create_memdesc_index(ty.to_ir(builder), mem_desc.handle, index)
+        return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
 
     def memdesc_trans(self, mem_desc, order):
         assert len(order) == len(
@@ -194,10 +238,26 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         return ttgl.shared_memory_descriptor(handle, element_ty=mem_desc.dtype, shape=shape,
                                              alloc_shape=new_alloc_shape, layout=layout)
 
-    def memdesc_reshape(self, mem_desc, shape, layout):
-        ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
-        handle = self.builder.create_memdesc_reshape(ty.to_ir(self.builder), mem_desc.handle)
-        return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
+    def memdesc_reshape(self, mem_desc, shape):
+        _check(
+            math.prod(shape) == math.prod(mem_desc.shape),
+            lambda: (f"memdesc_reshape total elements mismatch: "
+                     f"{mem_desc.shape} -> {shape}"),
+        )
+
+        handle = self.builder.create_memdesc_reshape(mem_desc.handle, shape)
+        layout = self.builder.get_gluon_layout_from_memdesc(handle)
+        alloc_shape = mem_desc.type.alloc_shape
+        prefix_len = len(alloc_shape) - mem_desc.rank
+        new_alloc_shape = alloc_shape[:prefix_len] + list(shape)
+
+        return ttgl.shared_memory_descriptor(
+            handle,
+            element_ty=mem_desc.dtype,
+            shape=shape,
+            alloc_shape=new_alloc_shape,
+            layout=layout,
+        )
 
     def memdesc_reinterpret(self, mem_desc, dtype, shape, layout):
         ty = ttgl.shared_memory_descriptor_type(dtype, shape, layout, shape)
@@ -220,6 +280,27 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         _check(all(l == l0 for l in layouts[1:]),
                lambda: f"Expected inputs to have matching layouts, but got: {layouts}")
 
+    def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn,
+                         reverse: bool) -> Tuple[TensorTy, ...]:
+        shape = inputs[0].type.shape
+        rank = len(shape)
+
+        assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})"
+
+        if axis < 0:
+            axis += rank
+
+        for t in inputs:
+            assert t.type.shape == shape, "all scan inputs must have the same shape"
+
+        scan_op = self.builder.create_scan([t.handle for t in inputs], axis, reverse)
+        region_builder_fn(scan_op)
+        assert scan_op.verify()
+
+        return tuple(
+            self._wrap_handle_infer_layout(scan_op.get_result(i), inputs[i].type.scalar, shape)
+            for i in range(len(inputs)))
+
     def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]:
         _check(axis is not None, lambda: "All-reduce is not yet implemented in gluon")
         # get result shape
@@ -228,7 +309,6 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         _check(0 <= axis < rank, lambda: f"expected reduction axis to be in the range [0, {rank}) but got {axis}")
         self._check_same_layout(inputs)
         ret_shape = [s for i, s in enumerate(shape) if i != axis]
-        ret_layout = SliceLayout(axis, inputs[0].type.layout)
         assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape"
 
         reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis)
@@ -236,11 +316,23 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         assert reduce_op.verify()
 
         return tuple(
-            self.wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape, ret_layout)
+            self._wrap_handle_infer_layout(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape)
             for i in range(len(inputs)))
 
-    def warp_specialize(self, args, default_partition, worker_partitions, worker_num_warps: Sequence[int],
-                        worker_num_regs: Sequence[int], generator):
+    def histogram(self, input: TensorTy, num_bins: int, mask: TensorTy, layout) -> TensorTy:
+        _check(len(input.shape) == 1, lambda: "histogram only supports 1D input")
+        _check(input.dtype.is_int(), lambda: "histogram only supports integer input")
+        _check(layout is not None, lambda: "histogram requires a destination layout")
+        if mask is not None:
+            mask, input = self.broadcast_impl_value(mask, input)
+            _check(mask.type.scalar.is_bool(), lambda: "Mask must have boolean scalar type")
+            mask = mask.handle
+        layout_attr = layout._to_ir(self.builder)
+        handle = self.builder.create_histogram(input.handle, num_bins, mask, layout_attr)
+        return self.wrap_tensor(handle, ttgl.int32, [num_bins], layout)
+
+    def warp_specialize(self, default_args, default_partition, worker_args, worker_partitions,
+                        worker_num_warps: Sequence[int], worker_num_regs: Sequence[int], generator):
         num_partitions = len(worker_partitions)
         assert num_partitions == len(
             worker_num_warps
@@ -255,7 +347,7 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         # Emit the default partition to get the result types.
         default_block = builder.new_block()
         builder.set_insertion_point_to_start(default_block)
-        default_results = generator.call_JitFunction(default_partition, args, kwargs={})
+        default_results = generator.call_JitFunction(default_partition, default_args, kwargs={})
         mlir_results = []
         if default_results is not None:
             mlir_results = flatten_values_to_ir(default_results)
@@ -264,7 +356,7 @@ class GluonSemantic(TritonSemantic[TensorTy]):
 
         # Create the warp specialize op.
         builder.restore_insertion_point(insert_pt)
-        mlir_args = flatten_values_to_ir(args)
+        mlir_args = flatten_values_to_ir(worker_args)
         ws_op = builder.create_warp_specialize(result_types, mlir_args, worker_num_warps)
         ws_op.get_default_region().push_back(default_block)
         ws_op.set_requested_registers(worker_num_regs)
@@ -274,10 +366,11 @@ class GluonSemantic(TritonSemantic[TensorTy]):
         partitions_op = builder.create_warp_specialize_partitions(num_partitions)
         arg_types = [arg.get_type() for arg in mlir_args]
         for i in range(num_partitions):
+            caller_context = GluonCallerContext(num_warps=worker_num_warps[i])
            block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types)
            block_args = [block.get_argument(j) for j in range(len(mlir_args))]
-           block_args = unflatten_ir_values(block_args, [arg.type for arg in args])
-           generator.call_JitFunction(worker_partitions[i], block_args, kwargs={})
+           block_args = unflatten_ir_values(block_args, [arg.type for arg in worker_args])
+           generator.call_JitFunction(worker_partitions[i], block_args, kwargs={}, caller_context=caller_context)
            builder.create_warp_return()
 
        builder.set_insertion_point_after(ws_op.get_operation())
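
The _semantic.py changes above wire AutoLayout through broadcasting, arange, and full, and add set_auto_layout, associative_scan, and histogram entry points. A rough usage sketch follows; it assumes ttgl.set_auto_layout, ttgl.BlockedLayout, and ttgl.store are exposed on the public Gluon namespace in this release, so treat it as illustrative rather than an exact API reference.

```python
# Illustrative only: ttgl.set_auto_layout, ttgl.BlockedLayout and ttgl.store are
# assumed to be re-exported by the public Gluon namespace.
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def auto_layout_demo(out_ptr, BLOCK: ttgl.constexpr):
    layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32],
                                                warps_per_cta=[4], order=[0])
    offs = ttgl.arange(0, BLOCK)               # no layout given -> AutoLayout()
    ones = ttgl.full([BLOCK], 1, ttgl.int32)   # also AutoLayout()
    # Broadcasting two AutoLayout tensors keeps AutoLayout; pin it down explicitly
    # before the value reaches an op that needs a concrete distributed layout.
    vals = ttgl.set_auto_layout(offs + ones, layout)
    ttgl.store(out_ptr + ttgl.arange(0, BLOCK, layout=layout), vals)
```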

triton/experimental/gluon/language/_standard.py
@@ -1,39 +1,60 @@
-# flake8: noqa
-import triton
+from typing import TypeVar
+from triton.runtime.jit import JITFunction
 import triton.language.standard as tl_standard
-from .._runtime import jit
+from .._runtime import GluonJITFunction, jit
 from triton import knobs
 from . import _core as ttgl
 
-_IMPORT_FROM_TRITON = [
-    "sum",
-    "max",
-    "min",
-    "reduce_or",
-    "xor_sum",
-]
+T = TypeVar("T")
 
-__all__ = [
-    "full_like",
-    "zeros",
-    "zeros_like",
-    *_IMPORT_FROM_TRITON,
-]
 
-for name in _IMPORT_FROM_TRITON:
-    # Convert JITFunction -> GluonJitFunction
-    fn = getattr(tl_standard, name)
-    assert knobs.runtime.interpret or isinstance(fn, triton.runtime.JITFunction)
-    globals()[name] = jit(fn.fn)
+def _import_from_triton(fn: JITFunction[T]) -> GluonJITFunction[T]:
+    assert knobs.runtime.interpret or isinstance(fn, JITFunction)
+    # Wrap the function and preserve its original docstring
+    gluon_fn = jit(fn.fn)
+    gluon_fn.__doc__ = fn.__doc__
+    return gluon_fn
+
+
+cdiv = _import_from_triton(tl_standard.cdiv)
+sum = _import_from_triton(tl_standard.sum)
+max = _import_from_triton(tl_standard.max)
+min = _import_from_triton(tl_standard.min)
+reduce_or = _import_from_triton(tl_standard.reduce_or)
+xor_sum = _import_from_triton(tl_standard.xor_sum)
 
 
 @jit
-def zeros(shape, dtype, layout):
+def zeros(shape, dtype, layout=None):
+    """
+    Create a tensor filled with zeros.
+
+    Args:
+        shape (Sequence[int]): The shape of the tensor.
+        dtype (dtype): The data type for the tensor.
+        layout (Optional[DistributedLayout]): The distributed layout of the tensor, defaults to AutoLayout().
+
+    Returns:
+        tensor: A tensor where every element is zero.
+    """
     return ttgl.full(shape, 0, dtype, layout)
 
 
 @jit
 def full_like(input, value, shape=None, dtype=None, layout=None):
+    """
+    Create a tensor with the same properties as a given tensor, filled with a specified value.
+
+    Args:
+        input (tensor): Reference tensor to infer default shape, dtype, and layout.
+        value (int or float): The fill value.
+        shape (Sequence[int], optional): Target shape. Defaults to input.shape.
+        dtype (dtype, optional): Target data type. Defaults to input.dtype.
+        layout (DistributedLayout, optional): Target layout. Defaults to input.layout.
+
+    Returns:
+        tensor: A tensor where every element equals value.
+    """
     return ttgl.full(
         input.shape if shape is None else shape,
         value,
@@ -44,4 +65,16 @@ def full_like(input, value, shape=None, dtype=None, layout=None):
 
 @jit
 def zeros_like(input, shape=None, dtype=None, layout=None):
+    """
+    Create a tensor with the same properties as a given tensor, filled with zeros.
+
+    Args:
+        input (tensor): Reference tensor to infer default shape, dtype, and layout.
+        shape (Sequence[int], optional): Target shape. Defaults to input.shape.
+        dtype (dtype, optional): Target data type. Defaults to input.dtype.
+        layout (DistributedLayout, optional): Target layout. Defaults to input.layout.
+
+    Returns:
+        tensor: A tensor where every element is zero.
+    """
     return full_like(input, 0, shape=shape, dtype=dtype, layout=layout)
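
With these defaults, zeros, full_like, and zeros_like can be called without an explicit layout, and cdiv is now wrapped from triton.language.standard. A hedged sketch, assuming these helpers are re-exported on the Gluon language namespace as in earlier releases:

```python
# Sketch only: assumes zeros/full_like/cdiv are reachable through the public
# Gluon language namespace (ttgl) as in previous releases.
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def zeros_demo(N: ttgl.constexpr, BLOCK: ttgl.constexpr):
    x = ttgl.zeros([BLOCK], ttgl.float32)   # layout omitted -> AutoLayout()
    y = ttgl.full_like(x, 3.0)              # inherits shape, dtype and layout from x
    num_blocks = ttgl.cdiv(N, BLOCK)        # cdiv is wrapped via _import_from_triton
    return x, y, num_blocks
```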

triton/experimental/gluon/language/amd/__init__.py
@@ -0,0 +1,4 @@
+from ._layouts import AMDMFMALayout
+from . import cdna3, cdna4
+
+__all__ = ["AMDMFMALayout", "cdna3", "cdna4"]

triton/experimental/gluon/language/amd/_layouts.py
@@ -0,0 +1,96 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List, Optional
+from triton.language.core import _unwrap_if_constexpr
+
+from triton.experimental.gluon.language._layouts import _realize_cta_layout, DistributedLayout
+from triton.experimental.gluon import language as ttgl
+
+__all__ = [
+    "AMDMFMALayout",
+]
+
+
+@dataclass(frozen=True)
+class AMDMFMALayout(DistributedLayout):
+    """
+    Represents a layout for AMD MFMA (matrix core) operations.
+
+    Args:
+        version (int): Major and minor identifier for the MFMA instruction.
+        instr_shape: (M, N) dimension for the instrinsic shape.
+        transposed (bool): indicates the result tensor is transposed so that each thread holds consecutive elements in the same row instead of column, which is good for chained dot and global write.
+        warps_per_cta (List[int]): Number of warps per CTA.
+        elem_type Optional(ttgl.dtype): Supported types are int32, fp32 and fp64. Default is fp32.
+        tiles_per_warp Optional(List[int]): Number of tiles per WARP. For mfma layout, if missing, use the default where we have unit tile size on all dimensions.
+        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
+        cta_split_num (Optional[List[int]]): Split factors for CTAs.
+        cta_order (Optional[List[int]]): CTA ordering.
+    """
+    version: int
+    instr_shape: List[int]
+    transposed: bool
+    warps_per_cta: List[int]
+    elem_type: ttgl.dtype = ttgl.float32
+    tiles_per_warp: Optional[List[int]] = None
+    ctas_per_cga: Optional[List[int]] = None
+    cta_split_num: Optional[List[int]] = None
+    cta_order: Optional[List[int]] = None
+
+    def __post_init__(self):
+        super().__setattr__("version", _unwrap_if_constexpr(self.version))
+        super().__setattr__("instr_shape", _unwrap_if_constexpr(self.instr_shape))
+        super().__setattr__("transposed", _unwrap_if_constexpr(self.transposed))
+        super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta))
+        super().__setattr__("tiles_per_warp", _unwrap_if_constexpr(self.tiles_per_warp))
+        super().__setattr__("elem_type", _unwrap_if_constexpr(self.elem_type))
+        super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga))
+        super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num))
+        super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))
+
+        if self.tiles_per_warp is None:
+            object.__setattr__(self, "tiles_per_warp", [1] * len(self.warps_per_cta))
+
+        self.verify()
+
+    def _to_ir(self, builder):
+        type = self.elem_type.to_ir(builder)
+        return builder.get_amd_mfma_layout(self.version, self.instr_shape, self.transposed, self.warps_per_cta, type,
+                                           self.tiles_per_warp, self.ctas_per_cga, self.cta_split_num, self.cta_order)
+
+    def mangle(self) -> str:
+
+        def stringify(x):
+            if x is None:
+                return ""
+            return "_".join(map(str, x))
+
+        return f"MFMA_{self.version}_{stringify(self.instr_shape)}_{self.transposed}_{stringify(self.warps_per_cta)}_{stringify(self.tiles_per_warp)}_{self.elem_type}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_MFMA"
+
+    def verify(self):
+        assert self.version >= 1 and self.version <= 4, "version must be in the [1, 4] range"
+        valid_shapes = [[32, 32], [16, 16], [64, 4], [4, 64]]
+        assert self.instr_shape in valid_shapes, "invalid intrinsic shape; accepted shapes are " + str(valid_shapes)
+
+        assert self.elem_type.is_fp32() or self.elem_type.is_fp64() \
+            or self.elem_type.is_int32(), "element type must be float32, float64, or int32"
+
+        rank = len(self.warps_per_cta)
+        _realize_cta_layout(self, rank)
+        assert len(self.ctas_per_cga) == rank
+        assert len(self.cta_split_num) == rank
+        assert len(self.cta_order) == rank
+
+    def __hash__(self):
+        return hash((
+            self.version,
+            tuple(self.instr_shape),
+            self.transposed,
+            tuple(self.warps_per_cta),
+            self.elem_type,
+            tuple(self.tiles_per_warp) if self.tiles_per_warp else None,
+            tuple(self.ctas_per_cga) if self.ctas_per_cga else None,
+            tuple(self.cta_split_num) if self.cta_split_num else None,
+            tuple(self.cta_order) if self.cta_order else None,
+        ))
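
AMDMFMALayout is a frozen dataclass, so it can be constructed and inspected on the host. A minimal construction sketch, with illustrative values (version 3 corresponding roughly to CDNA3, within the [1, 4] range checked by verify()):

```python
# Host-side construction sketch; the concrete values are illustrative.
from triton.experimental.gluon.language.amd import AMDMFMALayout

# 32x32 intrinsic tile, transposed result, 4 warps along M and 1 along N.
# elem_type defaults to fp32, tiles_per_warp defaults to [1, 1], and the CTA
# fields are filled in via _realize_cta_layout() during verify().
layout = AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
                       warps_per_cta=[4, 1])
print(layout.mangle())   # e.g. "MFMA_3_32_32_True_4_1_..._MFMA"
```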

triton/experimental/gluon/language/amd/cdna3/__init__.py
@@ -0,0 +1,100 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+from triton import knobs
+from triton.experimental.gluon.language import _core as ttgl
+from triton._C.libtriton import ir
+from ..._core import builtin, _unwrap_if_constexpr
+
+if TYPE_CHECKING:
+    from ..._semantic import GluonSemantic
+
+__all__ = ["buffer_load", "buffer_store", "mfma"]
+
+
+def _verify_buffer_ops(ptr, offsets, mask=None, other=None):
+    assert ptr.type.is_ptr(), "ptr must be a scalar pointer type"
+
+    assert isinstance(offsets.type, ttgl.distributed_type), "expected offsets type to be a distributed_type"
+    assert offsets.dtype.is_int32() or offsets.dtype.is_uint32(), "offsets element type must be int32 or uint32"
+
+    element_type = ptr.type.scalar.element_ty
+
+    if other is not None:
+        assert mask is not None, "when other is not None, mask should not be None"
+        assert other.dtype == element_type, "other must have the same data type as ptr scalar type"
+
+
+@builtin
+def buffer_load(ptr, offsets, mask=None, other=None, cache=None, _semantic=None):
+    """
+    AMD buffer load from global memory via a scalar base pointer and a tensor of
+    offsets instead of a tensor of pointers. This operation will load data
+    directly into registers.
+
+    Args:
+        ptr (pointer to scalar): Global memory scalar base pointer to load from.
+        offsets (tensor): Offsets tensor for the load operation.
+        mask (tensor, optional): Mask tensor for predicated loads. Defaults to None.
+        other (tensor, optional): Tensor providing default values for masked elements. Defaults to None.
+        cache_modifier (str): Cache modifier specifier. Defaults to "".
+    """
+    _verify_buffer_ops(ptr, offsets, mask, other)
+
+    mask = _unwrap_if_constexpr(mask)
+    if mask is not None:
+        offsets, mask = _semantic.broadcast_impl_value(offsets, mask)
+
+    other = _unwrap_if_constexpr(other)
+    if other is not None:
+        offsets, other = _semantic.broadcast_impl_value(offsets, other)
+
+    other = other.handle if other is not None else ir.value()
+    mask = mask.handle if mask is not None else ir.value()
+    cache_modifier = _semantic._str_to_load_cache_modifier(cache) if cache is not None else ir.CACHE_MODIFIER.NONE
+
+    ret_ty = offsets.type.with_element_ty(ptr.type.scalar.element_ty)
+    builder = _semantic.builder
+    handle = builder.create_buffer_load(ret_ty.to_ir(builder), ptr.handle, offsets.handle, mask, other, cache_modifier)
+    return ttgl.tensor(handle, ret_ty)
+
+
+@builtin
+def buffer_store(stored_value, ptr, offsets, mask=None, cache=None, _semantic: GluonSemantic = None):
+    """
+    AMD buffer store a tensor directly to global memory via a scalar base pointer and a tensor of
+    offsets instead of a tensor of pointers.
+    Args:
+        stored_value (tensor to be stored): The tensor to be stored to global memory.
+        ptr (pointer to scalar): Global memory scalar base pointer to store to.
+        offsets (tensor): Offsets tensor for the store operation.
+        mask (tensor, optional): Mask tensor for predicated store. Defaults to None.
+        cache_modifier (str): Cache modifier specifier. Defaults to "".
+    """
+    _verify_buffer_ops(ptr, offsets, mask)
+
+    if mask is not None:
+        offsets, mask = _semantic.broadcast_impl_value(offsets, mask)
+
+    mask = mask.handle if mask is not None else ir.value()
+    cache_modifier = _semantic._str_to_store_cache_modifier(cache) if cache is not None else ir.CACHE_MODIFIER.NONE
+
+    _semantic.builder.create_buffer_store(stored_value.handle, ptr.handle, offsets.handle, mask, cache_modifier)
+
+
+@builtin
+def mfma(a, b, acc, _semantic: GluonSemantic = None):
+    """
+    Computes matrix-multiplication of a * b + acc using AMD native matrix core units.
+    Args:
+        a (tensor): The first operand of mfma.
+        b (tensor): The second operand of mfma.
+        acc (tensor): The accumulator tensor.
+    """
+    assert acc is not None, "acc is required"
+    ret_type = acc.type
+    acc = ttgl._unwrap_if_constexpr(acc)
+
+    handle = _semantic.dot(a, b, acc, input_precision=knobs.language.fp32_default, max_num_imprecise_acc=None,
+                           out_dtype=acc.dtype).handle
+    return ttgl.tensor(handle, ret_type)
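
A hedged sketch of the new cdna3 buffer ops in a Gluon kernel. The buffer_load and buffer_store signatures come from the diff above; the layout choice, ttgl.program_id, and the 64-lane warp size are assumptions about the surrounding Gluon API.

```python
# Illustrative kernel; only the cdna3.buffer_load/buffer_store calls are taken
# from this diff, the rest is assumed Gluon API.
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.amd import cdna3


@gluon.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: ttgl.constexpr):
    layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[64],
                                                warps_per_cta=[4], order=[0])
    pid = ttgl.program_id(0)
    offsets = pid * BLOCK + ttgl.arange(0, BLOCK, layout=layout)
    mask = offsets < n
    # Scalar base pointer plus an int32 offsets tensor; data lands in registers.
    vals = cdna3.buffer_load(src_ptr, offsets, mask=mask, other=0.0)
    cdna3.buffer_store(vals, dst_ptr, offsets, mask=mask)
```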

triton/experimental/gluon/language/amd/cdna4/__init__.py
@@ -0,0 +1,48 @@
+from triton.experimental.gluon.language import _core as ttgl
+from ..._core import builtin, float32
+from ..._layouts import DotOperandLayout
+from .._layouts import AMDMFMALayout
+from ..cdna3 import *  # NOQA: F403
+from ..cdna3 import __all__ as __cdna3_all
+from . import async_copy
+
+__all__ = [*__cdna3_all, "async_copy", "mfma_scaled"]
+
+
+@builtin
+def mfma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None):
+    """
+    AMD Scaled MFMA operation.
+
+    ```
+    c = a * a_scale @ b * b_scale + acc
+    ```
+
+    `a` and `b` use microscaling formats described in
+    "OCP Microscaling Formats (MX) Specification":
+    https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf.
+    Currently supported only on CDNA4 hardware.
+
+    Args:
+        a (tensor): The operand A to be multiplied.
+        a_scale (tensor): Scale factor for operand A.
+        a_format (str): Format of the operand A. Available formats: `e2m1`, `e4m3`, `e5m2`.
+        b (tensor): The operand B to be multiplied.
+        b_scale (tensor): Scale factor for operand B. Available formats: `e2m1`, `e4m3`, `e5m2`.
+        b_format (str): Format of the operand B.
+        acc (tensor): Accumulator tensor.
+    """
+    layout = acc.type.layout
+    assert isinstance(layout, AMDMFMALayout), "Expected layout to be an instance of AMDMFMALayout"
+    assert (isinstance(a.type.layout, DotOperandLayout) and a.type.layout.parent == layout), \
+        "Expected lhs layout to be a DotOperandLayout with parent matching MFMA layout"
+    assert (isinstance(b.type.layout, DotOperandLayout) and b.type.layout.parent == layout), \
+        "Expected rhs layout to be a DotOperandLayout with parent matching MFMA layout"
+
+    assert a_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported lhs_format: {a_format.value}"
+    assert b_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported rhs_format: {b_format.value}"
+
+    tensor = _semantic.dot_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, False, True, True, float32)
+
+    ret_ty = ttgl.distributed_type(tensor.dtype, tensor.shape, layout)
+    return ttgl.tensor(tensor.handle, ret_ty)
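
mfma_scaled enforces a specific layout contract: the accumulator carries an AMDMFMALayout, and both operands carry DotOperandLayouts whose parent is that same MFMA layout. The sketch below only illustrates that relationship; the DotOperandLayout constructor arguments and the k_width value are assumptions, and the operand/scale packing for the MX formats is not shown.

```python
# Layout-contract sketch only; DotOperandLayout's argument names and the
# k_width value are assumptions, not taken from this diff.
from triton.experimental.gluon.language.amd import AMDMFMALayout
from triton.experimental.gluon.language._layouts import DotOperandLayout

mfma_layout = AMDMFMALayout(version=4, instr_shape=[16, 16], transposed=True,
                            warps_per_cta=[4, 1])
a_layout = DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=16)
b_layout = DotOperandLayout(operand_index=1, parent=mfma_layout, k_width=16)

# Inside a @gluon.jit kernel, with a/b/acc carrying the layouts above, the call
# would then look like:
#     acc = cdna4.mfma_scaled(a, a_scale, "e2m1", b, b_scale, "e4m3", acc)
```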