triton-windows 3.3.1.post19__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic.

Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/experimental/gluon/language/nvidia/blackwell/__init__.py (new file)
@@ -0,0 +1,387 @@
+ from __future__ import annotations
+ from typing import Optional, Tuple, List, TYPE_CHECKING
+
+ from dataclasses import dataclass
+ from triton.runtime.jit import constexpr_function
+ from triton.experimental.gluon.language import _core as ttgl
+ from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr
+ from triton.experimental.gluon.language._layouts import BlockedLayout, _get_shape_per_cta
+ from triton.experimental.gluon.language._semantic import _check
+
+ from . import tma
+ from ..hopper import fence_async_shared, mbarrier
+ from ..ampere import async_copy
+
+ from triton._C.libtriton import ir
+ if TYPE_CHECKING:
+     from triton._C.libtriton.gluon_ir import GluonOpBuilder
+     from ..._semantic import GluonSemantic
+
+ __all__ = [
+     "allocate_tensor_memory",
+     "async_copy",
+     "fence_async_shared",
+     "get_tmem_32x32b_reg_layout",
+     "mbarrier",
+     "tensor_memory_descriptor",
+     "TensorMemoryLayout",
+     "tma",
+ ]
+
+
+ @dataclass(frozen=True, eq=True)
+ class TensorMemoryLayout:
+     """
+     Describes the layout for tensor memory in Blackwell architecture.
+
+     Args:
+         block (Tuple[int, int]): Tiling block dimensions (M/rows, N/cols).
+         unpacked (bool): For sub-32 bit elements, whether they are unpacked to 32 bits.
+         cta_split_num (Optional[Tuple[int, int]]): CTA split factors. Defaults to None.
+     """
+     block: Tuple[int, int]
+     unpacked: bool
+     cta_split_num: Optional[Tuple[int, int]] = None
+
+     def __post_init__(self):
+         assert len(self.block) == 2
+         assert self.cta_split_num is None or len(self.cta_split_num) == 2
+
+     def _to_ir(self, builder):
+         cta_split_num = self.cta_split_num or [1, 1]
+         return builder.get_tensor_memory_layout(
+             self.block,
+             self.unpacked,
+             cta_split_num,
+         )
+
+     def mangle(self) -> str:
+         block_str = f"{self.block[0]}x{self.block[1]}"
+         unpacked_str = "U" if self.unpacked else "P"
+         cta_split_str = f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else ""
+         return f"TL{block_str}{unpacked_str}{cta_split_str}TL"
+
+
+ @dataclass(frozen=True, eq=True)
+ class TensorMemoryScalesLayout:
+     """
+     Describes the layout for tensor memory scales in Blackwell architecture.
+
+     Args:
+         cta_split_num (Optional[Tuple[int, int]]): CTA split factors. Defaults to None.
+     """
+     cta_split_num: Optional[Tuple[int, int]] = None
+
+     def __post_init__(self):
+         assert self.cta_split_num is None or len(self.cta_split_num) == 2
+
+     def _to_ir(self, builder):
+         cta_split_num = self.cta_split_num or [1, 1]
+         return builder.get_tensor_memory_scales_layout(cta_split_num, )
+
+     def mangle(self) -> str:
+         cta_split_str = f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else ""
+         return f"TLS{cta_split_str}TLS"
+
+
+ @constexpr_function
+ def _cdiv(x, div):
+     return (x + div - 1) // div
+
+
+ @constexpr_function
+ def get_tmem_32x32b_reg_layout(M, N, shape, num_warps, ctas_per_cga=None, cta_split_num=None, cta_order=None):
+     """Returns a BlockedLayout compatible with load/store on tensor memory with the 32x32b instruction variant.
+     """
+     assert len(shape) == 2, "expected a 2D tensor"
+     assert num_warps in [4, 8], "expected 4 or 8 warps"
+
+     shape_per_cta = _get_shape_per_cta(shape, cta_split_num)
+     blocks_per_tile = [shape_per_cta[0] // M, shape_per_cta[1] // N]
+     num_blocks = blocks_per_tile[0] * blocks_per_tile[1]
+
+     num_warp_groups = num_warps // 4
+     if M == 64:
+         threads_per_warp = [16, 2]
+         if num_blocks == 1:
+             size_per_thread = [1, _cdiv(N, num_warp_groups * 2)]
+             warps_per_cta = [4, num_warp_groups]
+         else:
+             size_per_thread = [1, _cdiv(N, 2)]
+             warps_per_cta = [4 * min(blocks_per_tile[0], num_warp_groups)]
+             warps_per_cta.append(_cdiv(num_warp_groups, warps_per_cta[0] // 4))
+     else:
+         if shape[0] > 128:
+             size_per_thread = [1, N]
+             threads_per_warp = [32, 1]
+             warps_per_cta = [4 * num_warp_groups, 1]
+         else:
+             size_per_thread = [1, _cdiv(N, num_warp_groups)]
+             threads_per_warp = [32, 1]
+             warps_per_cta = [4, num_warp_groups]
+     return BlockedLayout(
+         size_per_thread=size_per_thread,
+         threads_per_warp=threads_per_warp,
+         warps_per_cta=warps_per_cta,
+         order=[0, 1],
+         ctas_per_cga=ctas_per_cga,
+         cta_split_num=cta_split_num,
+         cta_order=cta_order,
+     )
+
+
+ class tensor_memory_descriptor_type(base_type):
+
+     def __init__(self, element_ty, shape, layout, alloc_shape):
+         self.element_ty = element_ty
+         self.shape = shape
+         self.layout = layout
+         self.alloc_shape = alloc_shape
+         assert isinstance(layout, TensorMemoryLayout) or isinstance(layout, TensorMemoryScalesLayout)
+
+     def to_ir(self, builder: GluonOpBuilder) -> None:
+         return builder.get_tensor_mem_desc_ty(
+             self.element_ty.to_ir(builder),
+             self.shape,
+             self.layout._to_ir(builder),
+             self.alloc_shape,
+         )
+
+     def _unflatten_ir(self, handles: List[ir.Value], cursor: int) -> Tuple[tensor_memory_descriptor, int]:
+         value = tensor_memory_descriptor(handles[cursor], self.element_ty, self.shape, self.layout, self.alloc_shape)
+         return value, cursor + 1
+
+     def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None:
+         out.append(self.to_ir(builder))
+
+     def __str__(self) -> str:
+         return f"tensor_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}>"
+
+     def __eq__(self, other) -> bool:
+         return (type(self) is type(other) and self.shape == other.shape and self.layout == other.layout
+                 and self.alloc_shape == other.alloc_shape)
+
+     def __neq__(self, other) -> bool:
+         return not (self == other)
+
+     def mangle(self) -> str:
+         shape_str = "_".join([str(s) for s in self.shape])
+         return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{self.alloc_shape}ASMD"
+
+
+ class tensor_memory_descriptor(base_value):
+     """
+     Represents a tensor memory descriptor handle for Tensor Core Gen5 operations.
+     """
+
+     def __init__(self, handle, element_ty, shape, layout, alloc_shape):
+         self.handle = handle
+         self.type = tensor_memory_descriptor_type(element_ty, shape, layout, alloc_shape)
+
+     def _flatten_ir(self, handles: List[ir.value]) -> None:
+         handles.append(self.handle)
+
+     @property
+     def dtype(self):
+         return self.type.element_ty
+
+     @property
+     def shape(self):
+         return self.type.shape
+
+     @property
+     def rank(self):
+         return len(self.shape)
+
+     @property
+     def layout(self):
+         return self.type.layout
+
+     def __str__(self) -> str:
+         return str(self.type)
+
+     @builtin
+     def load(self, layout, _semantic: GluonSemantic) -> ttgl.tensor:
+         """
+         Load a tensor from tensor memory.
+
+         Args:
+             layout (DistributedLayout): Destination layout of the tensor.
+
+         Returns:
+             tensor: A distributed tensor containing the loaded data.
+         """
+         layout = _unwrap_if_constexpr(layout)
+         ret_ty = ttgl.distributed_type(self.dtype, self.shape, layout)
+         builder = _semantic.builder
+         handle = builder.create_tmem_load(ret_ty.to_ir(builder), self.handle)
+         return ttgl.tensor(handle, ret_ty)
+
+     @builtin
+     def store(self, value, pred=True, _semantic: GluonSemantic = None) -> None:
+         """
+         Store a tensor into tensor memory.
+
+         Args:
+             value (tensor): The tensor to store.
+             pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+         """
+         pred = _unwrap_if_constexpr(pred)
+         pred = _semantic.to_tensor(pred)
+         assert value.shape == self.shape, f"source shape {value.shape} does not match destination shape {self.shape}"
+         assert value.dtype == self.dtype, f"source dtype {value.dtype} does not match destination dtype {self.dtype}"
+         _semantic.builder.create_tmem_store(self.handle, value.handle, pred.handle)
+
+     @builtin
+     def slice(self, start, length, _semantic: GluonSemantic) -> None:
+         """
+         Create a slice of the tensor memory descriptor along the last dimension.
+
+         Args:
+             start (int): The starting index for subslice.
+             length (int): The length of the subslice.
+
+         Returns:
+             tensor_memory_descriptor: Descriptor for the subslice.
+         """
+         start = _unwrap_if_constexpr(start)
+         length = _unwrap_if_constexpr(length)
+         _check(isinstance(start, int), lambda: "start must be a constant int")
+         _check(isinstance(length, int), lambda: "length must be a constant int")
+         shape = self.shape[:-1] + [length]
+         layout = self.type.layout
+         layout = TensorMemoryLayout((layout.block[0], min(layout.block[1], length)), layout.unpacked,
+                                     layout.cta_split_num)
+         ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape)
+         builder = _semantic.builder
+         ret.handle = builder.create_tmem_subslice(ret.type.to_ir(builder), self.handle, start)
+         return ret
+
+     @builtin
+     def index(self, index, _semantic: GluonSemantic = None) -> tensor_memory_descriptor:
+         """
+         Create a subview of tensor memory by indexing the first dimension.
+
+         Args:
+             index (tensor): The index tensor for the subview.
+
+         Returns:
+             tensor_memory_descriptor: Descriptor for the indexed subview.
+         """
+         index = _semantic.to_tensor(index)
+         builder = _semantic.builder
+         shape = self.shape[1:]
+         layout = self.layout
+         ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape)
+         ret.handle = builder.create_memdesc_index(ret.type.to_ir(builder), self.handle, index.handle)
+         return ret
+
+     @builtin
+     def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> tensor_memory_descriptor:
+         """
+         Reinterpret tensor memory descriptor with a new dtype, shape, and layout.
+
+         Args:
+             dtype (dtype): The new data type.
+             shape (Sequence[int]): The new shape.
+             layout (TensorMemoryLayout): The new layout.
+
+         Returns:
+             tensor_memory_descriptor: Descriptor with updated type and layout.
+         """
+         dtype = _unwrap_if_constexpr(dtype)
+         shape = [_unwrap_if_constexpr(s) for s in shape]
+         layout = _unwrap_if_constexpr(layout)
+
+         ty = tensor_memory_descriptor_type(dtype, shape, layout, shape)
+         handle = _semantic.builder.create_memdesc_reinterpret(ty.to_ir(_semantic.builder), self.handle)
+         return tensor_memory_descriptor(handle, **ty.__dict__)
+
+
+ @builtin
+ def allocate_tensor_memory(element_ty, shape, layout, value=None, _semantic=None):
+     """
+     Allocate tensor memory.
+
+     Args:
+         element_ty (dtype): The element data type.
+         shape (Sequence[int]): The descriptor shape.
+         layout (TensorMemoryLayout): The layout of the tensor memory.
+         value (tensor, optional): Initial tensor to copy. Defaults to None.
+
+     Returns:
+         tensor_memory_descriptor: Descriptor for the allocated memory.
+     """
+     element_ty = _unwrap_if_constexpr(element_ty)
+     shape = _unwrap_if_constexpr(shape)
+     layout = _unwrap_if_constexpr(layout)
+     value = value.handle if value is not None else None
+
+     ty = tensor_memory_descriptor_type(element_ty, shape, layout, shape)
+     builder = _semantic.builder
+     handle = builder.create_tmem_alloc(ty.to_ir(builder), value)
+     return tensor_memory_descriptor(handle, element_ty, shape, layout, shape)
+
+
+ @builtin
+ def tcgen05_copy(src, dst, _semantic=None):
+     """
+     Start an asynchronous copy from shared memory to tensor memory.
+
+     WARNING: The current semantics of the instruction are not well defined and
+     the API will change in the future. Use at your own risk.
+
+     Args:
+         src (shared_memory_descriptor): Shared memory to copy from.
+         dst (tensor_memory_descriptor): Tensor memory to copy to.
+     """
+     assert isinstance(src, ttgl.shared_memory_descriptor), "source must be a shared memory descriptor"
+     assert isinstance(dst, tensor_memory_descriptor), "destination must be a tensor memory descriptor"
+     _semantic.builder.create_tmem_copy(src.handle, dst.handle)
+
+
+ @builtin
+ def tcgen05_mma(a, b, acc, *, use_acc=True, pred=True, mbarriers=None, mbarrier_preds=None, _semantic=None):
+     """
+     Emit a 5th generation TensorCore MMA instruction.
+     acc = a * b + (acc if use_acc else 0)
+
+     Args:
+         a (shared_memory_descriptor): Left hand side operand in shared memory.
+         b (shared_memory_descriptor or tensor_memory_descriptor): Right hand side operand in shared or tensor memory.
+         acc (tensor_memory_descriptor): Accumulator value in tensor memory (mutated).
+         use_acc (bool): Whether to use the initial value of the accumulator. Defaults to True.
+         pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+         mbarriers (Sequence[shared_memory_descriptor], optional): Barriers to signal when the operation is complete. If None, mma is synchronous. Defaults to None.
+         mbarrier_preds (Sequence[bool], optional): Predicates for barriers. Defaults to None.
+     """
+     use_acc = _semantic.to_tensor(use_acc)
+     pred = _semantic.to_tensor(pred)
+
+     if mbarriers is None:
+         assert mbarrier_preds is None
+         mbarriers = []
+         mbarrier_preds = []
+     else:
+         mbarriers = [bar.handle for bar in mbarriers]
+         if mbarrier_preds is None:
+             true = _semantic.to_tensor(True)
+             mbarrier_preds = [true.handle] * len(mbarriers)
+         else:
+             mbarrier_preds = _semantic._convert_to_ir_values(mbarrier_preds, require_i64=False)
+
+     _semantic.builder.create_tcgen05_mma(a.handle, b.handle, acc.handle, use_acc.handle, pred.handle, mbarriers,
+                                          mbarrier_preds)
+
+
+ @builtin
+ def tcgen05_commit(barrier, _semantic=None):
+     """
+     This instruction causes the provided mbarrier to be arrived-on with a count
+     of 1 when all async tcgen05 MMA and copy instructions previously issued by
+     the thread are complete.
+
+     Args:
+         barrier (shared_memory_descriptor): The barrier to track completion of tcgen05 MMA and copy instructions.
+     """
+     _semantic.builder.create_tcgen05_commit(barrier.handle)
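
Taken together, the new Blackwell module gives Gluon kernels a small tensor-memory workflow: describe the allocation with TensorMemoryLayout, allocate it with allocate_tensor_memory, drive it with tcgen05_mma and tcgen05_commit, and read results back with load using a register layout from get_tmem_32x32b_reg_layout. The fragment below is a hedged sketch of how those pieces might compose; it is not part of the diff. It assumes a surrounding Gluon-JITed kernel, prebuilt shared-memory operands smem_a/smem_b, an initialized mbarrier bar, and that ttgl.float32 and the Ampere-style mbarrier.wait signature behave as in upstream Triton.

# Hypothetical fragment, not from this package: composes the APIs added in the
# hunk above. Must run inside a Gluon-JITed kernel; `smem_a`, `smem_b`, `bar`,
# and `num_warps` are assumed to be prepared by the caller.
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.nvidia import blackwell as bw

BLOCK_M, BLOCK_N = 128, 128  # would be constexpr values in a real kernel

def blackwell_matmul_tile(smem_a, smem_b, bar, num_warps):
    # Accumulator lives in tensor memory (Tensor Core Gen5).
    tmem_layout = bw.TensorMemoryLayout(block=(BLOCK_M, BLOCK_N), unpacked=True)
    acc = bw.allocate_tensor_memory(ttgl.float32, [BLOCK_M, BLOCK_N], tmem_layout)

    # acc = smem_a @ smem_b; completion is reported through the mbarrier.
    bw.tcgen05_mma(smem_a, smem_b, acc, use_acc=False, mbarriers=[bar])
    bw.tcgen05_commit(bar)
    bw.mbarrier.wait(bar, 0)  # wait comes from the Ampere module (signature assumed)

    # Pull the result into registers with a compatible blocked layout.
    reg_layout = bw.get_tmem_32x32b_reg_layout(BLOCK_M, BLOCK_N, [BLOCK_M, BLOCK_N], num_warps)
    return acc.load(reg_layout)
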
triton/experimental/gluon/language/nvidia/blackwell/tma.py (new file)
@@ -0,0 +1,52 @@
+ from triton.experimental.gluon.language._core import builtin
+ from triton.experimental.gluon.language.nvidia.hopper.tma import (
+     async_copy_global_to_shared,
+     async_copy_shared_to_global,
+     store_wait,
+     tensor_descriptor,
+     tensor_descriptor_type,
+ )
+
+ __all__ = [
+     "async_gather",
+     "async_scatter",
+     "async_copy_global_to_shared",
+     "async_copy_shared_to_global",
+     "store_wait",
+     "tensor_descriptor",
+     "tensor_descriptor_type",
+ ]
+
+
+ @builtin
+ def async_gather(tensor_desc, x_offsets, y_offset, barrier, result, pred=True, _semantic=None):
+     """
+     Asynchronously gather elements from global memory to shared memory using TMA.
+
+     Args:
+         tensor_desc (tensor_descriptor): The tensor descriptor.
+         x_offsets (tensor): 1D tensor of X offsets.
+         y_offset (int): Scalar Y offset.
+         barrier (shared_memory_descriptor): Barrier that will be signaled when the operation is complete.
+         result (shared_memory_descriptor): Result shared memory, must have NVMMASharedLayout.
+         pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+     """
+     pred = _semantic.to_tensor(pred)
+     y_offset = _semantic.to_tensor(y_offset)
+     _semantic.builder.create_async_tma_gather(tensor_desc.handle, x_offsets.handle, y_offset.handle, barrier.handle,
+                                               result.handle, pred.handle)
+
+
+ @builtin
+ def async_scatter(tensor_desc, x_offsets, y_offset, src, _semantic=None):
+     """
+     Asynchronously scatter elements from shared memory to global memory using TMA.
+
+     Args:
+         tensor_desc (tensor_descriptor): The tensor descriptor.
+         x_offsets (tensor): 1D tensor of X offsets.
+         y_offset (int): Scalar Y offset.
+         src (shared_memory_descriptor): The source data, must be in NVMMASharedLayout.
+     """
+     y_offset = _semantic.to_tensor(y_offset)
+     _semantic.builder.create_async_tma_scatter(tensor_desc.handle, x_offsets.handle, y_offset.handle, src.handle)
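
async_gather and async_scatter extend the Hopper TMA helpers (re-exported above) with offset-driven row transfers. The sketch below shows the call pattern only; it is not part of the diff and assumes the kernel has already built a tensor_descriptor desc, a shared-memory tile smem in NVMMASharedLayout, an initialized mbarrier bar, and a 1D x_offsets tensor.

# Hypothetical fragment, not from this package: call pattern for the new
# gather/scatter builtins. `desc`, `x_offsets`, `smem`, and `bar` are assumed
# to be prepared by the surrounding Gluon kernel.
from triton.experimental.gluon.language.nvidia.blackwell import tma, mbarrier

def gather_modify_scatter(desc, x_offsets, y_offset, smem, bar, nbytes):
    # Tell the barrier how many bytes the gather will deposit, then issue it.
    mbarrier.expect(bar, nbytes)
    tma.async_gather(desc, x_offsets, y_offset, bar, smem)
    mbarrier.wait(bar, 0)  # wait comes from the Ampere module (signature assumed)

    # ... the tile in `smem` could be transformed here ...

    # Write the tile back out through the matching scatter and drain the stores.
    tma.async_scatter(desc, x_offsets, y_offset, smem)
    tma.store_wait(0)  # re-exported from hopper.tma; argument meaning assumed
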
triton/experimental/gluon/language/nvidia/hopper/__init__.py (new file)
@@ -0,0 +1,132 @@
+ from __future__ import annotations
+ from triton.compiler.code_generator import unflatten_ir_values
+ from ..ampere import async_copy
+ from . import mbarrier, tma
+ from ... import _core
+
+ from typing import List, Tuple, TYPE_CHECKING
+ if TYPE_CHECKING:
+     from triton._C.libtriton import ir
+
+ __all__ = ["async_copy", "fence_async_shared", "mbarrier", "tma", "warpgroup_mma", "warpgroup_mma_wait"]
+
+
+ @_core.builtin
+ def fence_async_shared(cluster=False, _semantic=None):
+     """
+     Issue a fence to complete asynchronous shared memory operations.
+
+     Args:
+         cluster (bool): Whether to fence across cluster. Defaults to False.
+     """
+     cluster = _core._unwrap_if_constexpr(cluster)
+     _semantic.builder.create_fence_async_shared(cluster)
+
+
+ class warpgroup_mma_accumulator_type(_core.base_type):
+     tensor_type: _core.dtype
+
+     def __init__(self, tensor_type: _core.dtype):
+         self.tensor_type = tensor_type
+
+     def __str__(self) -> str:
+         return f"warpgroup_mma_accumulator<{self.tensor_type}>"
+
+     def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[warpgroup_mma_accumulator, int]:
+         return warpgroup_mma_accumulator(handles[cursor], self.tensor_type), cursor + 1
+
+     def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+         self.tensor_type._flatten_ir_types(builder, out)
+
+     def __eq__(self, other) -> bool:
+         return type(self) is type(other) and self.tensor_type == other.tensor_type
+
+     def mangle(self) -> str:
+         return f"FT{self.tensor_type.mangle()}FT"
+
+
+ class warpgroup_mma_accumulator(_core.base_value):
+     handle: ir.value
+     type: warpgroup_mma_accumulator_type
+
+     def __init__(self, handle, tensor_type: _core.dtype):
+         self.handle = handle
+         self.type = warpgroup_mma_accumulator_type(tensor_type)
+
+     def _flatten_ir(self, handles: List[ir.value]) -> None:
+         handles.append(self.handle)
+
+
+ @_core.builtin
+ def warpgroup_mma_init(value, _semantic):
+     assert isinstance(value, _core.tensor)
+     return warpgroup_mma_accumulator(value.handle, value.type)
+
+
+ @_core.builtin
+ def warpgroup_mma(a, b, acc, *, use_acc=True, precision=None, max_num_imprecise_acc=None, is_async=False,
+                   _semantic=None):
+     """
+     Perform warpgroup MMA (Tensor Core) operations.
+     acc = a * b + (acc if use_acc else 0)
+
+     Args:
+         a (tensor or shared_memory_descriptor): Left hand side operand.
+         b (shared_memory_descriptor): Right hand side operand.
+         acc (tensor): Accumulator tensor.
+         use_acc (bool): Whether to use the initial value of the accumulator. Defaults to True.
+         precision (str, optional): Dot input precision. Defaults to builder default.
+         max_num_imprecise_acc (int): Max imprecise accumulations. Used for fp8 -> fp32 dot. Determines how many accumulations are done in limited precision. Defaults to None, which means no upcasting is done.
+         is_async (bool): Whether the operation is asynchronous. Defaults to False.
+
+     Returns:
+         tensor or warpgroup_mma_accumulator: Returns the result if synchronous, or a token to load the value once computed if asynchronous.
+     """
+     use_acc = _semantic.to_tensor(use_acc)
+
+     if precision is None:
+         precision = _semantic.builder.options.default_dot_input_precision
+
+     precision = _semantic._str_to_dot_input_precision(precision)
+
+     K = a.type.shape[-1]
+     if max_num_imprecise_acc is None:
+         if a.dtype.is_fp8() and b.dtype.is_fp8():
+             max_num_imprecise_acc = _semantic.builder.options.max_num_imprecise_acc_default
+         else:
+             max_num_imprecise_acc = 0
+     else:
+         if a.dtype.is_fp8() and b.dtype.is_fp8() and max_num_imprecise_acc > K:
+             raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})")
+
+     max_num_imprecise_acc = _core._unwrap_if_constexpr(max_num_imprecise_acc)
+     is_async = _core._unwrap_if_constexpr(is_async)
+
+     handle = _semantic.builder.create_warpgroup_mma(a.handle, b.handle, acc.handle, use_acc.handle, precision,
+                                                     max_num_imprecise_acc, is_async)
+     tensor_ty = acc.type.tensor_type if isinstance(acc, warpgroup_mma_accumulator) else acc.type
+     if is_async:
+         return warpgroup_mma_accumulator(handle, tensor_ty)
+     else:
+         return _core.tensor(handle, tensor_ty)
+
+
+ @_core.builtin
+ def warpgroup_mma_wait(num_outstanding=0, deps=None, _semantic=None):
+     """
+     Wait until `num_outstanding` or fewer warpgroup MMA operations are in-flight.
+
+     Args:
+         num_outstanding (int): Number of outstanding warpgroup MMA operations to wait for. Defaults to 0.
+         deps (Sequence[tensor]): List of dependencies that need to be kept alive while the mma is unfinished.
+     """
+     if deps is None:
+         raise ValueError("warpgroup_mma_wait deps must be given")
+     deps_handles = [x.handle for x in deps] if deps is not None else []
+     num_outstanding = _core._unwrap_if_constexpr(num_outstanding)
+     results = _semantic.builder.create_warpgroup_mma_wait(deps_handles, num_outstanding)
+     result_types = [dep.type.tensor_type if isinstance(dep, warpgroup_mma_accumulator) else dep.type for dep in deps]
+     results = unflatten_ir_values(results, result_types)
+     if len(deps) == 1:
+         return next(results)
+     return tuple(results)
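
The asynchronous path is the interesting part of warpgroup_mma: with is_async=True it returns a warpgroup_mma_accumulator token instead of a tensor, and warpgroup_mma_wait later materializes the value while keeping its dependencies alive. A hedged sketch of that issue/wait pattern, not part of the diff, assuming shared-memory operands a_smem/b_smem and an accumulator tensor acc prepared by the surrounding Gluon kernel:

# Hypothetical fragment, not from this package: async warpgroup MMA pattern
# built from the two builtins added in the hunk above.
from triton.experimental.gluon.language.nvidia import hopper

def async_wgmma(a_smem, b_smem, acc):
    # Issue the MMA without blocking; the token stands in for the accumulator.
    token = hopper.warpgroup_mma(a_smem, b_smem, acc, is_async=True)

    # ... independent work can overlap with the in-flight MMA here ...

    # Block until no MMAs are outstanding; `deps` keeps the token alive and,
    # since there is a single dependency, the computed tensor is returned.
    return hopper.warpgroup_mma_wait(num_outstanding=0, deps=[token])
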
triton/experimental/gluon/language/nvidia/hopper/mbarrier.py (new file)
@@ -0,0 +1,34 @@
+ from ..ampere.mbarrier import MBarrierLayout, init, invalidate, wait
+ from ..._core import _unwrap_if_constexpr, builtin
+
+ __all__ = ["arrive", "expect", "init", "invalidate", "MBarrierLayout", "wait"]
+
+
+ @builtin
+ def expect(mbarrier, bytes, pred=True, _semantic=None):
+     """
+     Expect a specific number of bytes being copied. When they are copied, the barrier is signaled.
+
+     Args:
+         mbarrier (shared_memory_descriptor): Barrier that will be signaled when the operation is complete.
+         bytes (int): Expected byte count.
+         pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+     """
+     bytes = _unwrap_if_constexpr(bytes)
+     pred = _semantic.to_tensor(pred)
+     _semantic.builder.create_mbarrier_expect(mbarrier.handle, bytes, pred.handle)
+
+
+ @builtin
+ def arrive(mbarrier, *, count=1, pred=True, _semantic=None):
+     """
+     Arrive at an mbarrier with a specified count.
+
+     Args:
+         mbarrier (shared_memory_descriptor): Barrier to be signalled.
+         count (int): Count to arrive with. Defaults to 1.
+         pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+     """
+     count = _unwrap_if_constexpr(count)
+     pred = _semantic.to_tensor(pred)
+     _semantic.builder.create_mbarrier_arrive(mbarrier.handle, count, pred.handle)
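
expect and arrive complete the mbarrier lifecycle whose init/wait/invalidate half is re-exported from the Ampere module. Below is a hedged sketch of the usual handshake around an asynchronous copy; it is not part of the diff, and the barrier allocation, the copy itself, and the init/wait signatures are assumptions.

# Hypothetical fragment, not from this package: typical mbarrier handshake.
# `bar` is an mbarrier in shared memory and `nbytes` is the transfer size;
# both are assumed to be set up by the surrounding Gluon kernel.
from triton.experimental.gluon.language.nvidia.hopper import mbarrier

def barrier_handshake(bar, nbytes):
    mbarrier.init(bar, count=1)    # re-exported from Ampere (signature assumed)
    mbarrier.expect(bar, nbytes)   # arm the barrier with the expected byte count
    # ... issue an async copy (e.g. TMA) that signals `bar` here ...
    mbarrier.wait(bar, 0)          # wait for phase 0 (signature assumed)

    # A purely thread-driven phase can instead be completed explicitly:
    mbarrier.arrive(bar, count=1)

    mbarrier.invalidate(bar)       # re-exported from Ampere (signature assumed)
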