triton-windows 3.3.1.post19__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of triton-windows might be problematic.

Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/experimental/gluon/language/_semantic.py
@@ -0,0 +1,380 @@
1
+ from typing import Sequence, List, TypeVar, Tuple, Callable
2
+ import math
3
+ from triton.language.semantic import TritonSemantic
4
+ from . import _core as ttgl
5
+ from ._layouts import AutoLayout, DistributedLayout, SliceLayout
6
+ from triton._C.libtriton.gluon_ir import GluonOpBuilder
7
+ from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values
8
+
9
+ TensorTy = TypeVar("TensorTy")
10
+
11
+
12
+ def _check(cond: bool, msg_fn: Callable[[], str], category=ValueError):
13
+ if not cond:
14
+ raise category(msg_fn())
15
+
16
+
17
+ class GluonCallerContext:
18
+
19
+ def __init__(self, num_warps: int):
20
+ self.num_warps = num_warps
21
+
22
+ def mangle(self):
23
+ return f"_NW{self.num_warps}"
24
+
25
+ def initialize_callee(self, fn, builder):
26
+ fn.set_attr("ttg.num-warps", builder.get_int32_attr(self.num_warps))
27
+
28
+
29
+ class GluonSemantic(TritonSemantic[TensorTy]):
30
+ tensor = ttgl.tensor
31
+ lang = ttgl
32
+
33
+ builder: GluonOpBuilder
34
+
35
+ def __init__(self, builder: GluonOpBuilder):
36
+ self.builder = builder
37
+
38
+ def _wrap_handle_infer_layout(self, handle, scalar_ty, shape):
39
+ if shape == []:
40
+ ty = scalar_ty
41
+ else:
42
+ ty = ttgl.distributed_type(scalar_ty, shape, self.builder.get_gluon_layout_from_tensor(handle))
43
+ return self.tensor(handle, ty)
44
+
45
+ def _wrap_tensor_infer_layout(self, tensor):
46
+ return self._wrap_handle_infer_layout(tensor.handle, tensor.type.scalar, tensor.shape)
47
+
48
+ def _broadcast_shapes(self, lhs_shape: List[int], rhs_shape: List[int]):
49
+ if len(lhs_shape) != len(rhs_shape):
50
+ raise ValueError(f"Cannot broadcast, rank mismatch: {lhs_shape}, {rhs_shape}")
51
+
52
+ ret_shape = []
53
+ for i, left in enumerate(lhs_shape):
54
+ right = rhs_shape[i]
55
+ if left == 1:
56
+ ret_shape.append(right)
57
+ elif (right == 1) or (right == left):
58
+ ret_shape.append(left)
59
+ else:
60
+ raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
61
+ "at index " + str(i) + ": " + str(left) + " and " + str(right))
62
+ return ret_shape
63
+
64
+ def expand_dims(self, input: TensorTy, axis: int) -> TensorTy:
65
+ dst_shape = [ttgl._unwrap_if_constexpr(x) for x in input.shape]
66
+ dst_shape.insert(axis, 1)
67
+
68
+ if axis < 0:
69
+ axis += len(input.shape)
70
+
71
+ _check(isinstance(input.type, ttgl.distributed_type),
72
+ lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}")
73
+ layout = input.type.layout
74
+ _check(isinstance(layout, (SliceLayout, AutoLayout)),
75
+ lambda: f"expected expand_dims input to have a SliceLayout, but got: {layout}")
76
+ _check(
77
+ isinstance(layout, AutoLayout) or layout.dim == axis,
78
+ lambda: f"expected expand_dims input layout to be sliced in axis {axis} but got {layout.dim}")
79
+
80
+ handle = self.builder.create_expand_dims(input.handle, axis)
81
+ return self._wrap_handle_infer_layout(handle, input.type.scalar, dst_shape)
82
+
83
+ def join(self, a: TensorTy, b: TensorTy) -> TensorTy:
84
+ a, b = self.broadcast_impl_value(a, b)
85
+ _check(a.shape != [], lambda: "Cannot join scalars in gluon")
86
+ value = super().join(a, b)
87
+ return self._wrap_tensor_infer_layout(value)
88
+
89
+ def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]:
90
+ lhs, rhs = super().split(a)
91
+ return self._wrap_tensor_infer_layout(lhs), self._wrap_tensor_infer_layout(rhs)
92
+
93
+ def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy:
94
+ value = super().permute(input, dims)
95
+ return self._wrap_tensor_infer_layout(value)
96
+
97
+ def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy:
98
+ _check(isinstance(input.type, ttgl.distributed_type),
99
+ lambda: f"expected broadcast input to be a distributed_type but got: {input.type!r}")
100
+ src_shape = input.type.get_block_shapes()
101
+ _check(len(src_shape) == len(shape), lambda: f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
102
+ if shape == src_shape:
103
+ return input
104
+ for i, item in enumerate(src_shape):
105
+ if shape[i] != item and item != 1:
106
+ raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
107
+ f" must match the existing size ({item}) at non-singleton dimension"
108
+ f" {i}: {src_shape}, {shape}")
109
+ ret_ty = ttgl.distributed_type(input.type.scalar, shape, input.type.layout)
110
+ handle = self.builder.create_broadcast(input.handle, ret_ty.to_ir(self.builder))
111
+ return self.tensor(handle, ret_ty)
112
+
113
+ def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy:
114
+ lhs_ty = lhs.type
115
+ rhs_ty = rhs.type
116
+
117
+ if not lhs_ty.is_block() or not rhs_ty.is_block():
118
+ return super().broadcast_impl_value(lhs, rhs)
119
+
120
+ _check(isinstance(lhs_ty, ttgl.distributed_type),
121
+ lambda: f"expected broadcast left input to be a distributed_type but got: {lhs_ty!r}")
122
+ _check(isinstance(rhs_ty, ttgl.distributed_type),
123
+ lambda: f"expected broadcast right input to be a distributed_type but got: {rhs_ty!r}")
124
+
125
+ lhs_shape = lhs_ty.get_block_shapes()
126
+ rhs_shape = rhs_ty.get_block_shapes()
127
+ ret_shape = self._broadcast_shapes(lhs_shape, rhs_shape)
128
+
129
+ is_lhs_auto = isinstance(lhs_ty.layout, AutoLayout)
130
+ is_rhs_auto = isinstance(rhs_ty.layout, AutoLayout)
131
+ if is_lhs_auto and not is_rhs_auto:
132
+ lhs = self.set_auto_layout(lhs, rhs_ty.layout)
133
+ elif is_rhs_auto and not is_lhs_auto:
134
+ rhs = self.set_auto_layout(rhs, lhs_ty.layout)
135
+ elif lhs_ty.layout != rhs_ty.layout:
136
+ raise ValueError(f"Layout mismatch in broadcast: {lhs_ty.layout} vs {rhs_ty.layout}")
137
+
138
+ lhs = self.broadcast_impl_shape(lhs, ret_shape)
139
+ rhs = self.broadcast_impl_shape(rhs, ret_shape)
140
+ return lhs, rhs
141
+
142
+ def arange(self, start, end, layout):
143
+ shape = [end - start]
144
+ if layout is None:
145
+ layout = AutoLayout()
146
+ ret_ty = ttgl.distributed_type(ttgl.int32, shape, layout)
147
+ return super().arange(start, end, ret_ty=ret_ty)
148
+
149
+ def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool):
150
+ _check(not can_reorder, lambda: "can_reorder is not supported in gluon")
151
+ value = super().reshape(input, dst_shape, can_reorder)
152
+ return self._wrap_tensor_infer_layout(value)
153
+
154
+ def splat(self, value, shape, layout):
155
+ ret_ty = ttgl.distributed_type(value.dtype, shape, layout)
156
+ handle = self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle)
157
+ return ttgl.tensor(handle, ret_ty)
158
+
159
+ def full(self, shape, value, dtype, layout):
160
+ scalar = self.make_scalar(value, dtype)
161
+ if layout is None:
162
+ layout = AutoLayout()
163
+ return self.splat(scalar, shape, layout)
164
+
165
+ def convert_layout(self, value, layout, assert_trivial=False):
166
+ ty = value.type
167
+ _check(isinstance(ty, ttgl.distributed_type),
168
+ lambda: f"expected convert_layout input to be a distributed_type but got: {ty!r}")
169
+ ret_ty = ttgl.distributed_type(ty.element_ty, ty.shape, layout)
170
+ ret_ty_ir = ret_ty.to_ir(self.builder)
171
+ if assert_trivial and not self.builder.is_convert_layout_trivial(ret_ty_ir, value.handle):
172
+ raise TypeError(f"layout conversion from {ty.layout} to {layout} is not trivial")
173
+ handle = self.builder.create_convert_layout(ret_ty_ir, value.handle)
174
+ return ttgl.tensor(handle, ret_ty)
175
+
176
+ def allocate_shared(self, element_ty, shape, layout, value):
177
+ ty = ttgl.shared_memory_descriptor_type(element_ty, shape, layout, shape)
178
+ if value is not None:
179
+ handle = self.builder.create_local_alloc(ty.to_ir(self.builder), value.handle)
180
+ else:
181
+ handle = self.builder.create_local_alloc(ty.to_ir(self.builder))
182
+ return ttgl.shared_memory_descriptor(handle, element_ty, shape, layout, shape)
183
+
184
+ def shared_load(self, mem_desc, layout):
185
+ ret_ty = ttgl.distributed_type(mem_desc.dtype, mem_desc.shape, layout)
186
+ handle = self.builder.create_local_load(ret_ty.to_ir(self.builder), mem_desc.handle)
187
+ return ttgl.tensor(handle, ret_ty)
188
+
189
+ def shared_store(self, mem_desc, value):
190
+ assert value.shape == mem_desc.shape, f"source shape {value.shape} and destination shape {mem_desc.shape} must match"
191
+ assert value.dtype == mem_desc.dtype, f"source dtype {value.dtype} and destination dtype {mem_desc.dtype} must match"
192
+ self.builder.create_local_store(mem_desc.handle, value.handle)
193
+
194
+ def shared_dealloc(self, mem_desc):
195
+ self.builder.create_local_dealloc(mem_desc.handle)
196
+
197
+ def set_auto_layout(self, value, layout):
198
+ src_ty = value.type
199
+ assert isinstance(layout,
200
+ DistributedLayout), f"set_auto_layout must set to a distributed layout but got {layout}"
201
+ assert isinstance(src_ty.layout,
202
+ AutoLayout), f"set_auto_layout input must have auto layout but got {value.type.layout}"
203
+ handle = self.builder.create_set_auto_layout(layout._to_ir(self.builder), value.handle)
204
+ res_ty = ttgl.distributed_type(src_ty.element_ty, src_ty.shape, layout)
205
+ return self.tensor(handle, res_ty)
206
+
207
+ def memdesc_slice(self, mem_desc, start, length, dim):
208
+ offsets = [0] * mem_desc.rank
209
+ offsets[dim] = start
210
+ shape = list(mem_desc.shape)
211
+ shape[dim] = length
212
+ layout = mem_desc.layout
213
+ ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
214
+ builder = self.builder
215
+ handle = builder.create_memdesc_subslice(ty.to_ir(builder), mem_desc.handle, offsets)
216
+ return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
217
+
218
+ def memdesc_index(self, mem_desc, index):
219
+ shape = mem_desc.shape[1:]
220
+ index = self.to_tensor(index).handle
221
+ layout = mem_desc.layout
222
+ ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
223
+ builder = self.builder
224
+ handle = builder.create_memdesc_index(ty.to_ir(builder), mem_desc.handle, index)
225
+ return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
226
+
227
+ def memdesc_trans(self, mem_desc, order):
228
+ assert len(order) == len(
229
+ mem_desc.shape), f"source rank ({mem_desc.rank}) and order length ({len(order)}) must match"
230
+
231
+ shape = [mem_desc.shape[i] for i in order]
232
+ alloc_shape = mem_desc.type.alloc_shape
233
+ new_alloc_shape = alloc_shape[:len(alloc_shape) - mem_desc.rank]
234
+ new_alloc_shape += [alloc_shape[len(alloc_shape) - mem_desc.rank:][i] for i in order]
235
+
236
+ handle = self.builder.create_memdesc_trans(mem_desc.handle, order)
237
+ layout = self.builder.get_gluon_layout_from_memdesc(handle)
238
+ return ttgl.shared_memory_descriptor(handle, element_ty=mem_desc.dtype, shape=shape,
239
+ alloc_shape=new_alloc_shape, layout=layout)
240
+
241
+ def memdesc_reshape(self, mem_desc, shape):
242
+ _check(
243
+ math.prod(shape) == math.prod(mem_desc.shape),
244
+ lambda: (f"memdesc_reshape total elements mismatch: "
245
+ f"{mem_desc.shape} -> {shape}"),
246
+ )
247
+
248
+ handle = self.builder.create_memdesc_reshape(mem_desc.handle, shape)
249
+ layout = self.builder.get_gluon_layout_from_memdesc(handle)
250
+ alloc_shape = mem_desc.type.alloc_shape
251
+ prefix_len = len(alloc_shape) - mem_desc.rank
252
+ new_alloc_shape = alloc_shape[:prefix_len] + list(shape)
253
+
254
+ return ttgl.shared_memory_descriptor(
255
+ handle,
256
+ element_ty=mem_desc.dtype,
257
+ shape=shape,
258
+ alloc_shape=new_alloc_shape,
259
+ layout=layout,
260
+ )
261
+
262
+ def memdesc_reinterpret(self, mem_desc, dtype, shape, layout):
263
+ ty = ttgl.shared_memory_descriptor_type(dtype, shape, layout, shape)
264
+ handle = self.builder.create_memdesc_reinterpret(ty.to_ir(self.builder), mem_desc.handle)
265
+ return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
266
+
267
+ def wrap_tensor(self, x, scalar_ty, ret_shape, layout):
268
+ if ret_shape:
269
+ res_ty = ttgl.distributed_type(scalar_ty, ret_shape, layout)
270
+ else:
271
+ res_ty = scalar_ty
272
+ return self.tensor(x, res_ty)
273
+
274
+ @staticmethod
275
+ def _check_same_layout(xs):
276
+ for x in xs:
277
+ _check(isinstance(x.type, ttgl.distributed_type), lambda: f"expected distributed_type but got: {x.type!r}")
278
+ layouts = [x.type.layout for x in xs]
279
+ l0 = layouts[0]
280
+ _check(all(l == l0 for l in layouts[1:]),
281
+ lambda: f"Expected inputs to have matching layouts, but got: {layouts}")
282
+
283
+ def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn,
284
+ reverse: bool) -> Tuple[TensorTy, ...]:
285
+ shape = inputs[0].type.shape
286
+ rank = len(shape)
287
+
288
+ assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})"
289
+
290
+ if axis < 0:
291
+ axis += rank
292
+
293
+ for t in inputs:
294
+ assert t.type.shape == shape, "all scan inputs must have the same shape"
295
+
296
+ scan_op = self.builder.create_scan([t.handle for t in inputs], axis, reverse)
297
+ region_builder_fn(scan_op)
298
+ assert scan_op.verify()
299
+
300
+ return tuple(
301
+ self._wrap_handle_infer_layout(scan_op.get_result(i), inputs[i].type.scalar, shape)
302
+ for i in range(len(inputs)))
303
+
304
+ def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]:
305
+ _check(axis is not None, lambda: "All-reduce is not yet implemented in gluon")
306
+ # get result shape
307
+ shape = inputs[0].type.shape
308
+ rank = len(shape)
309
+ _check(0 <= axis < rank, lambda: f"expected reduction axis to be in the range [0, {rank}) but got {axis}")
310
+ self._check_same_layout(inputs)
311
+ ret_shape = [s for i, s in enumerate(shape) if i != axis]
312
+ assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape"
313
+
314
+ reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis)
315
+ region_builder_fn(reduce_op)
316
+ assert reduce_op.verify()
317
+
318
+ return tuple(
319
+ self._wrap_handle_infer_layout(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape)
320
+ for i in range(len(inputs)))
321
+
322
+ def histogram(self, input: TensorTy, num_bins: int, mask: TensorTy, layout) -> TensorTy:
323
+ _check(len(input.shape) == 1, lambda: "histogram only supports 1D input")
324
+ _check(input.dtype.is_int(), lambda: "histogram only supports integer input")
325
+ _check(layout is not None, lambda: "histogram requires a destination layout")
326
+ if mask is not None:
327
+ mask, input = self.broadcast_impl_value(mask, input)
328
+ _check(mask.type.scalar.is_bool(), lambda: "Mask must have boolean scalar type")
329
+ mask = mask.handle
330
+ layout_attr = layout._to_ir(self.builder)
331
+ handle = self.builder.create_histogram(input.handle, num_bins, mask, layout_attr)
332
+ return self.wrap_tensor(handle, ttgl.int32, [num_bins], layout)
333
+
334
+ def warp_specialize(self, default_args, default_partition, worker_args, worker_partitions,
335
+ worker_num_warps: Sequence[int], worker_num_regs: Sequence[int], generator):
336
+ num_partitions = len(worker_partitions)
337
+ assert num_partitions == len(
338
+ worker_num_warps
339
+ ), f"warp specialize got {num_partitions} partitions but {len(worker_num_warps)} warp counts"
340
+ assert num_partitions == len(
341
+ worker_num_regs
342
+ ), f"warp specialize got {num_partitions} partitions but {len(worker_num_regs)} register counts"
343
+
344
+ builder = self.builder
345
+ insert_pt = builder.get_insertion_point()
346
+
347
+ # Emit the default partition to get the result types.
348
+ default_block = builder.new_block()
349
+ builder.set_insertion_point_to_start(default_block)
350
+ default_results = generator.call_JitFunction(default_partition, default_args, kwargs={})
351
+ mlir_results = []
352
+ if default_results is not None:
353
+ mlir_results = flatten_values_to_ir(default_results)
354
+ builder.create_warp_yield(mlir_results)
355
+ result_types = [r.get_type() for r in mlir_results]
356
+
357
+ # Create the warp specialize op.
358
+ builder.restore_insertion_point(insert_pt)
359
+ mlir_args = flatten_values_to_ir(worker_args)
360
+ ws_op = builder.create_warp_specialize(result_types, mlir_args, worker_num_warps)
361
+ ws_op.get_default_region().push_back(default_block)
362
+ ws_op.set_requested_registers(worker_num_regs)
363
+
364
+ # Emit the partition regions.
365
+ builder.create_block_with_parent(ws_op.get_partition_op_holder(), [])
366
+ partitions_op = builder.create_warp_specialize_partitions(num_partitions)
367
+ arg_types = [arg.get_type() for arg in mlir_args]
368
+ for i in range(num_partitions):
369
+ caller_context = GluonCallerContext(num_warps=worker_num_warps[i])
370
+ block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types)
371
+ block_args = [block.get_argument(j) for j in range(len(mlir_args))]
372
+ block_args = unflatten_ir_values(block_args, [arg.type for arg in worker_args])
373
+ generator.call_JitFunction(worker_partitions[i], block_args, kwargs={}, caller_context=caller_context)
374
+ builder.create_warp_return()
375
+
376
+ builder.set_insertion_point_after(ws_op.get_operation())
377
+ mlir_results = [ws_op.get_result(i) for i in range(len(result_types))]
378
+ if default_results is None:
379
+ return
380
+ return tuple(unflatten_ir_values(mlir_results, [r.type for r in default_results]))
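
To make the broadcasting rule above easier to follow, here is a minimal standalone sketch (not part of the diff) of the logic in GluonSemantic._broadcast_shapes; the helper name and the example shapes are illustrative only.

def broadcast_shapes_sketch(lhs_shape, rhs_shape):
    # Ranks must match exactly; Gluon does not implicitly add leading dimensions.
    if len(lhs_shape) != len(rhs_shape):
        raise ValueError(f"Cannot broadcast, rank mismatch: {lhs_shape}, {rhs_shape}")
    out = []
    for i, (left, right) in enumerate(zip(lhs_shape, rhs_shape)):
        if left == 1:
            out.append(right)            # a 1 expands to the other operand's size
        elif right == 1 or right == left:
            out.append(left)
        else:
            raise ValueError(f"incompatible dimensions at index {i}: {left} and {right}")
    return out

assert broadcast_shapes_sketch([1, 128], [64, 1]) == [64, 128]
assert broadcast_shapes_sketch([64, 128], [64, 128]) == [64, 128]

On top of this shape rule, broadcast_impl_value also reconciles layouts: an operand with AutoLayout adopts the concrete layout of the other operand, and two mismatched concrete layouts raise a ValueError.
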
triton/experimental/gluon/language/_standard.py
@@ -0,0 +1,80 @@
1
+ from typing import TypeVar
2
+ from triton.runtime.jit import JITFunction
3
+ import triton.language.standard as tl_standard
4
+ from .._runtime import GluonJITFunction, jit
5
+ from triton import knobs
6
+ from . import _core as ttgl
7
+
8
+ T = TypeVar("T")
9
+
10
+
11
+ def _import_from_triton(fn: JITFunction[T]) -> GluonJITFunction[T]:
12
+ assert knobs.runtime.interpret or isinstance(fn, JITFunction)
13
+ # Wrap the function and preserve its original docstring
14
+ gluon_fn = jit(fn.fn)
15
+ gluon_fn.__doc__ = fn.__doc__
16
+ return gluon_fn
17
+
18
+
19
+ cdiv = _import_from_triton(tl_standard.cdiv)
20
+ sum = _import_from_triton(tl_standard.sum)
21
+ max = _import_from_triton(tl_standard.max)
22
+ min = _import_from_triton(tl_standard.min)
23
+ reduce_or = _import_from_triton(tl_standard.reduce_or)
24
+ xor_sum = _import_from_triton(tl_standard.xor_sum)
25
+
26
+
27
+ @jit
28
+ def zeros(shape, dtype, layout=None):
29
+ """
30
+ Create a tensor filled with zeros.
31
+
32
+ Args:
33
+ shape (Sequence[int]): The shape of the tensor.
34
+ dtype (dtype): The data type for the tensor.
35
+ layout (Optional[DistributedLayout]): The distributed layout of the tensor, defaults to AutoLayout().
36
+
37
+ Returns:
38
+ tensor: A tensor where every element is zero.
39
+ """
40
+ return ttgl.full(shape, 0, dtype, layout)
41
+
42
+
43
+ @jit
44
+ def full_like(input, value, shape=None, dtype=None, layout=None):
45
+ """
46
+ Create a tensor with the same properties as a given tensor, filled with a specified value.
47
+
48
+ Args:
49
+ input (tensor): Reference tensor to infer default shape, dtype, and layout.
50
+ value (int or float): The fill value.
51
+ shape (Sequence[int], optional): Target shape. Defaults to input.shape.
52
+ dtype (dtype, optional): Target data type. Defaults to input.dtype.
53
+ layout (DistributedLayout, optional): Target layout. Defaults to input.layout.
54
+
55
+ Returns:
56
+ tensor: A tensor where every element equals value.
57
+ """
58
+ return ttgl.full(
59
+ input.shape if shape is None else shape,
60
+ value,
61
+ input.dtype if dtype is None else dtype,
62
+ input.type.layout if layout is None else layout,
63
+ )
64
+
65
+
66
+ @jit
67
+ def zeros_like(input, shape=None, dtype=None, layout=None):
68
+ """
69
+ Create a tensor with the same properties as a given tensor, filled with zeros.
70
+
71
+ Args:
72
+ input (tensor): Reference tensor to infer default shape, dtype, and layout.
73
+ shape (Sequence[int], optional): Target shape. Defaults to input.shape.
74
+ dtype (dtype, optional): Target data type. Defaults to input.dtype.
75
+ layout (DistributedLayout, optional): Target layout. Defaults to input.layout.
76
+
77
+ Returns:
78
+ tensor: A tensor where every element is zero.
79
+ """
80
+ return full_like(input, 0, shape=shape, dtype=dtype, layout=layout)
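
As a usage illustration (not part of the diff), a Gluon helper might call these functions roughly as follows. The gluon.jit decorator, ttgl.constexpr, and the re-export of zeros/full_like/zeros_like under the ttgl namespace are assumptions based on the modules listed above and on upstream Triton's Gluon API, not verified against this wheel.

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl

@gluon.jit
def init_accumulators(BLOCK: ttgl.constexpr):
    # layout is omitted, so both helpers fall back to AutoLayout() per the docstrings above
    acc = ttgl.zeros([BLOCK, BLOCK], ttgl.float32)
    bias = ttgl.full_like(acc, 0.5)                 # inherits shape, dtype, and layout from acc
    flags = ttgl.zeros_like(acc, dtype=ttgl.int32)  # override dtype, keep shape and layout
    return acc, bias, flags
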
triton/experimental/gluon/language/amd/__init__.py
@@ -0,0 +1,4 @@
1
+ from ._layouts import AMDMFMALayout
2
+ from . import cdna3, cdna4
3
+
4
+ __all__ = ["AMDMFMALayout", "cdna3", "cdna4"]
triton/experimental/gluon/language/amd/_layouts.py
@@ -0,0 +1,96 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional
5
+ from triton.language.core import _unwrap_if_constexpr
6
+
7
+ from triton.experimental.gluon.language._layouts import _realize_cta_layout, DistributedLayout
8
+ from triton.experimental.gluon import language as ttgl
9
+
10
+ __all__ = [
11
+ "AMDMFMALayout",
12
+ ]
13
+
14
+
15
+ @dataclass(frozen=True)
16
+ class AMDMFMALayout(DistributedLayout):
17
+ """
18
+ Represents a layout for AMD MFMA (matrix core) operations.
19
+
20
+ Args:
21
+ version (int): Major and minor identifier for the MFMA instruction.
22
+ instr_shape (List[int]): (M, N) dimensions of the intrinsic shape.
23
+ transposed (bool): Whether the result tensor is transposed so that each thread holds consecutive elements in the same row instead of the same column, which benefits chained dot operations and global writes.
24
+ warps_per_cta (List[int]): Number of warps per CTA.
25
+ elem_type (Optional[ttgl.dtype]): Supported types are int32, fp32, and fp64. Defaults to fp32.
26
+ tiles_per_warp (Optional[List[int]]): Number of tiles per warp. If omitted, defaults to a unit tile size in every dimension.
27
+ ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
28
+ cta_split_num (Optional[List[int]]): Split factors for CTAs.
29
+ cta_order (Optional[List[int]]): CTA ordering.
30
+ """
31
+ version: int
32
+ instr_shape: List[int]
33
+ transposed: bool
34
+ warps_per_cta: List[int]
35
+ elem_type: ttgl.dtype = ttgl.float32
36
+ tiles_per_warp: Optional[List[int]] = None
37
+ ctas_per_cga: Optional[List[int]] = None
38
+ cta_split_num: Optional[List[int]] = None
39
+ cta_order: Optional[List[int]] = None
40
+
41
+ def __post_init__(self):
42
+ super().__setattr__("version", _unwrap_if_constexpr(self.version))
43
+ super().__setattr__("instr_shape", _unwrap_if_constexpr(self.instr_shape))
44
+ super().__setattr__("transposed", _unwrap_if_constexpr(self.transposed))
45
+ super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta))
46
+ super().__setattr__("tiles_per_warp", _unwrap_if_constexpr(self.tiles_per_warp))
47
+ super().__setattr__("elem_type", _unwrap_if_constexpr(self.elem_type))
48
+ super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga))
49
+ super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num))
50
+ super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))
51
+
52
+ if self.tiles_per_warp is None:
53
+ object.__setattr__(self, "tiles_per_warp", [1] * len(self.warps_per_cta))
54
+
55
+ self.verify()
56
+
57
+ def _to_ir(self, builder):
58
+ type = self.elem_type.to_ir(builder)
59
+ return builder.get_amd_mfma_layout(self.version, self.instr_shape, self.transposed, self.warps_per_cta, type,
60
+ self.tiles_per_warp, self.ctas_per_cga, self.cta_split_num, self.cta_order)
61
+
62
+ def mangle(self) -> str:
63
+
64
+ def stringify(x):
65
+ if x is None:
66
+ return ""
67
+ return "_".join(map(str, x))
68
+
69
+ return f"MFMA_{self.version}_{stringify(self.instr_shape)}_{self.transposed}_{stringify(self.warps_per_cta)}_{stringify(self.tiles_per_warp)}_{self.elem_type}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_MFMA"
70
+
71
+ def verify(self):
72
+ assert self.version >= 1 and self.version <= 4, "version must be in the [1, 4] range"
73
+ valid_shapes = [[32, 32], [16, 16], [64, 4], [4, 64]]
74
+ assert self.instr_shape in valid_shapes, "invalid intrinsic shape; accepted shapes are " + str(valid_shapes)
75
+
76
+ assert self.elem_type.is_fp32() or self.elem_type.is_fp64() \
77
+ or self.elem_type.is_int32() , "element type must be float32, float64, or int32"
78
+
79
+ rank = len(self.warps_per_cta)
80
+ _realize_cta_layout(self, rank)
81
+ assert len(self.ctas_per_cga) == rank
82
+ assert len(self.cta_split_num) == rank
83
+ assert len(self.cta_order) == rank
84
+
85
+ def __hash__(self):
86
+ return hash((
87
+ self.version,
88
+ tuple(self.instr_shape),
89
+ self.transposed,
90
+ tuple(self.warps_per_cta),
91
+ self.elem_type,
92
+ tuple(self.tiles_per_warp) if self.tiles_per_warp else None,
93
+ tuple(self.ctas_per_cga) if self.ctas_per_cga else None,
94
+ tuple(self.cta_split_num) if self.cta_split_num else None,
95
+ tuple(self.cta_order) if self.cta_order else None,
96
+ ))
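
For illustration (not part of the diff), constructing this layout might look like the sketch below; the version number and warp arrangement are example values, and the CTA fields are left to the defaults filled in by _realize_cta_layout.

from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.amd import AMDMFMALayout

mfma_layout = AMDMFMALayout(
    version=3,               # must be in the [1, 4] range per verify()
    instr_shape=[32, 32],    # one of [32, 32], [16, 16], [64, 4], [4, 64]
    transposed=True,         # row-contiguous results, useful for chained dots and global writes
    warps_per_cta=[4, 1],
    elem_type=ttgl.float32,  # fp32 accumulation (the default)
)
print(mfma_layout.mangle())  # deterministic string of the form "MFMA_3_32_32_True_..._MFMA"
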
triton/experimental/gluon/language/amd/cdna3/__init__.py
@@ -0,0 +1,100 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING
3
+
4
+ from triton import knobs
5
+ from triton.experimental.gluon.language import _core as ttgl
6
+ from triton._C.libtriton import ir
7
+ from ..._core import builtin, _unwrap_if_constexpr
8
+
9
+ if TYPE_CHECKING:
10
+ from ..._semantic import GluonSemantic
11
+
12
+ __all__ = ["buffer_load", "buffer_store", "mfma"]
13
+
14
+
15
+ def _verify_buffer_ops(ptr, offsets, mask=None, other=None):
16
+ assert ptr.type.is_ptr(), "ptr must be a scalar pointer type"
17
+
18
+ assert isinstance(offsets.type, ttgl.distributed_type), "expected offsets type to be a distributed_type"
19
+ assert offsets.dtype.is_int32() or offsets.dtype.is_uint32(), "offsets element type must be int32 or uint32"
20
+
21
+ element_type = ptr.type.scalar.element_ty
22
+
23
+ if other is not None:
24
+ assert mask is not None, "when other is not None, mask should not be None"
25
+ assert other.dtype == element_type, "other must have the same data type as ptr scalar type"
26
+
27
+
28
+ @builtin
29
+ def buffer_load(ptr, offsets, mask=None, other=None, cache=None, _semantic=None):
30
+ """
31
+ AMD buffer load from global memory via a scalar base pointer and a tensor of
32
+ offsets instead of a tensor of pointers. This operation will load data
33
+ directly into registers.
34
+
35
+ Args:
36
+ ptr (pointer to scalar): Global memory scalar base pointer to load from.
37
+ offsets (tensor): Offsets tensor for the load operation.
38
+ mask (tensor, optional): Mask tensor for predicated loads. Defaults to None.
39
+ other (tensor, optional): Tensor providing default values for masked elements. Defaults to None.
40
+ cache (str, optional): Cache modifier specifier. Defaults to None.
41
+ """
42
+ _verify_buffer_ops(ptr, offsets, mask, other)
43
+
44
+ mask = _unwrap_if_constexpr(mask)
45
+ if mask is not None:
46
+ offsets, mask = _semantic.broadcast_impl_value(offsets, mask)
47
+
48
+ other = _unwrap_if_constexpr(other)
49
+ if other is not None:
50
+ offsets, other = _semantic.broadcast_impl_value(offsets, other)
51
+
52
+ other = other.handle if other is not None else ir.value()
53
+ mask = mask.handle if mask is not None else ir.value()
54
+ cache_modifier = _semantic._str_to_load_cache_modifier(cache) if cache is not None else ir.CACHE_MODIFIER.NONE
55
+
56
+ ret_ty = offsets.type.with_element_ty(ptr.type.scalar.element_ty)
57
+ builder = _semantic.builder
58
+ handle = builder.create_buffer_load(ret_ty.to_ir(builder), ptr.handle, offsets.handle, mask, other, cache_modifier)
59
+ return ttgl.tensor(handle, ret_ty)
60
+
61
+
62
+ @builtin
63
+ def buffer_store(stored_value, ptr, offsets, mask=None, cache=None, _semantic: GluonSemantic = None):
64
+ """
65
+ AMD buffer store: store a tensor directly to global memory via a scalar base pointer and a tensor of
66
+ offsets instead of a tensor of pointers.
67
+ Args:
68
+ stored_value (tensor): The tensor to be stored to global memory.
69
+ ptr (pointer to scalar): Global memory scalar base pointer to store to.
70
+ offsets (tensor): Offsets tensor for the store operation.
71
+ mask (tensor, optional): Mask tensor for predicated store. Defaults to None.
72
+ cache (str, optional): Cache modifier specifier. Defaults to None.
73
+ """
74
+ _verify_buffer_ops(ptr, offsets, mask)
75
+
76
+ if mask is not None:
77
+ offsets, mask = _semantic.broadcast_impl_value(offsets, mask)
78
+
79
+ mask = mask.handle if mask is not None else ir.value()
80
+ cache_modifier = _semantic._str_to_store_cache_modifier(cache) if cache is not None else ir.CACHE_MODIFIER.NONE
81
+
82
+ _semantic.builder.create_buffer_store(stored_value.handle, ptr.handle, offsets.handle, mask, cache_modifier)
83
+
84
+
85
+ @builtin
86
+ def mfma(a, b, acc, _semantic: GluonSemantic = None):
87
+ """
88
+ Computes the matrix multiplication a * b + acc using AMD's native matrix core units.
89
+ Args:
90
+ a (tensor): The first operand of mfma.
91
+ b (tensor): The second operand of mfma.
92
+ acc (tensor): The accumulator tensor.
93
+ """
94
+ assert acc is not None, "acc is required"
95
+ ret_type = acc.type
96
+ acc = ttgl._unwrap_if_constexpr(acc)
97
+
98
+ handle = _semantic.dot(a, b, acc, input_precision=knobs.language.fp32_default, max_num_imprecise_acc=None,
99
+ out_dtype=acc.dtype).handle
100
+ return ttgl.tensor(handle, ret_type)
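
Finally, a hypothetical end-to-end sketch (not part of the diff) of how these cdna3 buffer ops might be used from a Gluon kernel; gluon.jit, ttgl.arange, and ttgl.BlockedLayout, including their parameters, are assumed to match upstream Triton's Gluon API and are not verified against this wheel.

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.amd import cdna3

@gluon.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: ttgl.constexpr):
    # One wavefront-64 blocked layout; buffer ops take a scalar base pointer plus an
    # int32 offsets tensor, as enforced by _verify_buffer_ops above.
    layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[64],
                                                warps_per_cta=[4], order=[0])
    offsets = ttgl.arange(0, BLOCK, layout=layout)
    mask = offsets < n
    vals = cdna3.buffer_load(src_ptr, offsets, mask=mask)
    cdna3.buffer_store(vals, dst_ptr, offsets, mask=mask)
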