triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic; see the registry page for more details.

Files changed (217) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
@@ -0,0 +1,48 @@
1
+ from triton.experimental.gluon.language import _core as ttgl
2
+ from ..._core import builtin, float32
3
+ from ..._layouts import DotOperandLayout
4
+ from .._layouts import AMDMFMALayout
5
+ from ..cdna3 import * # NOQA: F403
6
+ from ..cdna3 import __all__ as __cdna3_all
7
+ from . import async_copy
8
+
9
+ __all__ = [*__cdna3_all, "async_copy", "mfma_scaled"]
10
+
11
+
12
@builtin
def mfma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None):
    """
    AMD Scaled MFMA operation.

    ```
    c = a * a_scale @ b * b_scale + acc
    ```

    `a` and `b` use microscaling formats described in
    "OCP Microscaling Formats (MX) Specification":
    https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf.
    Currently supported only on CDNA4 hardware.

    Args:
        a (tensor): The operand A to be multiplied.
        a_scale (tensor): Scale factor for operand A.
        a_format (str): Format of the operand A. Available formats: `e2m1`, `e4m3`, `e5m2`.
        b (tensor): The operand B to be multiplied.
        b_scale (tensor): Scale factor for operand B.
        b_format (str): Format of the operand B. Available formats: `e2m1`, `e4m3`, `e5m2`.
        acc (tensor): Accumulator tensor.

    Returns:
        tensor: Result tensor with float32 elements and the same MFMA layout as `acc`.
    """
    layout = acc.type.layout
    assert isinstance(layout, AMDMFMALayout), "Expected layout to be an instance of AMDMFMALayout"
    # Both operands must be dot-operand layouts whose parent is the accumulator's MFMA layout.
    assert (isinstance(a.type.layout, DotOperandLayout) and a.type.layout.parent == layout), \
        "Expected lhs layout to be a DotOperandLayout with parent matching MFMA layout"
    assert (isinstance(b.type.layout, DotOperandLayout) and b.type.layout.parent == layout), \
        "Expected rhs layout to be a DotOperandLayout with parent matching MFMA layout"

    assert a_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported lhs_format: {a_format.value}"
    assert b_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported rhs_format: {b_format.value}"

    tensor = _semantic.dot_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, False, True, True, float32)

    # Re-wrap the result so its type carries the accumulator's MFMA layout.
    ret_ty = ttgl.distributed_type(tensor.dtype, tensor.shape, layout)
    return ttgl.tensor(tensor.handle, ret_ty)
@@ -0,0 +1,151 @@
1
+ from ..._core import ir, builtin, _unwrap_if_constexpr
2
+ from ..._semantic import _check
3
+ from ..._layouts import BlockedLayout, SliceLayout
4
+ from ..cdna3 import _verify_buffer_ops
5
+
6
+ __all__ = [
7
+ "global_load_to_shared",
8
+ "buffer_load_to_shared",
9
+ "async_wait",
10
+ "load_shared_relaxed",
11
+ ]
12
+
13
+
14
@builtin
def global_load_to_shared(dest, ptr, mask=None, other=None, cache_modifier="", _semantic=None):
    """
    AMD direct global-to-shared load.

    Copies data from global memory straight into shared memory, bypassing
    registers. The copy is asynchronous: a later `async_wait` is required
    before the data in shared memory may be read.

    Unlike `buffer_load_to_shared`, this variant takes a tensor of pointers,
    giving each thread a full 64-bit indexing range. That flexibility costs
    extra register pressure and forfeits hardware out-of-bounds masking, so
    prefer `buffer_load_to_shared` when it is applicable.

    Hardware constraints (violations fail at LLVM lowering): the instruction
    uses per-thread registers for the global address but a single per-warp
    register for the local address, therefore:

    - For the `ptr` layout, size per thread * bits per element must be 128 or 32.
      128 bits per element is recommended for best performance.
    - Writes to `dest` must be coalesced.
    - If `dest` is swizzled, the swizzle must stay within a warp boundary.

    Args:
        dest (shared_memory_descriptor): Destination shared memory descriptor.
        ptr (pointer tensor): Tensor of pointers to global memory to load from.
        mask (tensor, optional): Mask tensor for predicated loads. Defaults to None.
        other (tensor, optional): Tensor providing default values for masked elements. Defaults to None.
        cache_modifier (str): Cache modifier specifier. Defaults to "".
    """
    _check(ptr.type.is_block(), lambda: "expected ptr to be a tensor")
    _check(isinstance(ptr.type.layout, (BlockedLayout, SliceLayout)),
           lambda: "expected ptr type layout to be BlockedLayout or SliceLayout")
    _check(
        dest.shape == ptr.shape, lambda:
        f"expected dest shape to match pointer shape but got dest.shape = {dest.shape}, pointer.shape = {ptr.shape}")

    mask = _unwrap_if_constexpr(mask)
    other = _unwrap_if_constexpr(other)
    # Broadcast the optional operands so they agree with the pointer shape.
    if mask is not None:
        ptr, mask = _semantic.broadcast_impl_value(ptr, mask)
    if other is not None:
        ptr, other = _semantic.broadcast_impl_value(ptr, other)

    _semantic.builder.create_async_copy_global_to_local(
        dest.handle,
        ptr.handle,
        mask.handle if mask is not None else ir.value(),
        other.handle if other is not None else ir.value(),
        _semantic._str_to_load_cache_modifier(cache_modifier),
        ir.EVICTION_POLICY.NORMAL,
        False,
    )
63
+
64
+
65
@builtin
def buffer_load_to_shared(dest, ptr, offsets, mask=None, other=None, cache_modifier="", _semantic=None):
    """
    AMD buffer load to shared memory.

    A buffer load addresses global memory through a scalar base pointer plus
    a tensor of 32-bit offsets rather than a tensor of pointers. Data is
    moved directly into shared memory without passing through registers. The
    copy is asynchronous: a later `async_wait` is required before the data in
    shared memory may be read.

    Compared to `global_load_to_shared`, this variant performs better and
    supports hardware out-of-bounds masking, but is limited to 32-bit
    offsets instead of 64-bit tensor pointers.

    Hardware constraints (violations fail at LLVM lowering): the instruction
    uses per-thread registers for the global address but a single per-warp
    register for the local address, therefore:

    - For the `offsets` layout, size per thread * bits per element must be 128 or 32.
      128 bits per element is recommended for best performance.
    - Writes to `dest` must be coalesced.
    - If `dest` is swizzled, the swizzle must stay within a warp boundary.

    Args:
        dest (shared_memory_descriptor): Destination shared memory descriptor.
        ptr (pointer to scalar): Global memory scalar base pointer to load from.
        offsets (tensor): Offsets tensor for the load operation.
        mask (tensor, optional): Mask tensor for predicated loads. Defaults to None.
        other (tensor, optional): Tensor providing default values for masked elements. Defaults to None.
        cache_modifier (str): Cache modifier specifier. Defaults to "".
    """
    _check(isinstance(offsets.type.layout, (BlockedLayout, SliceLayout)),
           lambda: "expected offsets type layout to be BlockedLayout or SliceLayout")
    _verify_buffer_ops(ptr, offsets, mask, other)

    mask = _unwrap_if_constexpr(mask)
    other = _unwrap_if_constexpr(other)
    # Broadcast the optional operands so they agree with the offsets shape.
    if mask is not None:
        offsets, mask = _semantic.broadcast_impl_value(offsets, mask)
    if other is not None:
        offsets, other = _semantic.broadcast_impl_value(offsets, other)

    mask_handle = mask.handle if mask is not None else ir.value()
    other_handle = other.handle if other is not None else ir.value()
    _semantic.builder.create_buffer_load_to_local(dest.handle, ptr.handle, offsets.handle, mask_handle,
                                                  other_handle, ir.value(),
                                                  _semantic._str_to_load_cache_modifier(cache_modifier))
114
+
115
+
116
@builtin
def async_wait(num_outstanding=0, _semantic=None):
    """
    Block until at most `num_outstanding` memory operations remain in flight.

    This covers regular loads such as `load` and `buffer_load` as well as the
    direct-to-shared loads `global_load_to_shared` and `buffer_load_to_shared`.

    Args:
        num_outstanding (int): The number of outstanding operations to wait for. Defaults to 0.
    """
    _semantic.builder.create_async_wait_group(_unwrap_if_constexpr(num_outstanding))
130
+
131
+
132
@builtin
def load_shared_relaxed(smem, layout, _semantic=None):
    """
    Load a tensor from shared memory, hinting to the compiler that no extra
    waits need to be emitted before reading the target shared memory.

    Args:
        smem (shared_memory_descriptor): Shared memory descriptor to load from.
        layout (DistributedLayout): The destination layout of the tensor.

    Returns:
        tensor: A Gluon tensor containing the loaded data.
    """
    result = _semantic.shared_load(smem, _unwrap_if_constexpr(layout))
    # Tag the load so the AMD backend knows synchronization already happened
    # via an async wait and can skip emitting its own.
    result.handle.set_attr("ttg.amdgpu.syncedViaAsyncWait", _semantic.builder.get_bool_attr(True))
    return result
@@ -0,0 +1,3 @@
1
+ from triton.language.extra import libdevice
2
+
3
+ __all__ = ["libdevice"]
@@ -0,0 +1,4 @@
1
+ from . import blackwell
2
+ from . import hopper
3
+
4
+ __all__ = ["blackwell", "hopper"]
@@ -0,0 +1,3 @@
1
+ from . import async_copy, mbarrier
2
+
3
+ __all__ = ["async_copy", "mbarrier"]
@@ -0,0 +1,74 @@
1
+ from ..._semantic import _check
2
+ from ..._core import _unwrap_if_constexpr, builtin
3
+ from triton._C.libtriton import ir
4
+
5
+ __all__ = [
6
+ "async_copy_global_to_shared",
7
+ "mbarrier_arrive",
8
+ "commit_group",
9
+ "wait_group",
10
+ ]
11
+
12
+
13
+ @builtin
14
+ def async_copy_global_to_shared(smem, pointer, mask=None, cache_modifier="", eviction_policy="", volatile=False,
15
+ _semantic=None):
16
+ """
17
+ Asynchronously copy elements from global memory to shared memory.
18
+
19
+ Args:
20
+ smem (shared_memory_descriptor): Destination shared memory descriptor.
21
+ pointer (tensor): Source pointer tensor.
22
+ mask (tensor, optional): Mask tensor for predicated loads. Defaults to None.
23
+ cache_modifier (str): Cache modifier specifier. Defaults to "".
24
+ eviction_policy (str): Eviction policy specifier. Defaults to "".
25
+ volatile (bool): Whether the load is volatile. Defaults to False.
26
+ """
27
+ mask = _unwrap_if_constexpr(mask)
28
+ cache_modifier = _semantic._str_to_load_cache_modifier(cache_modifier)
29
+ eviction_policy = _semantic._str_to_eviction_policy(eviction_policy)
30
+ volatile = _unwrap_if_constexpr(volatile)
31
+ if mask is not None:
32
+ pointer, mask = _semantic.broadcast_impl_value(pointer, mask)
33
+ _check(
34
+ smem.shape == pointer.shape, lambda:
35
+ f"expected smem shape to match pointer shape but got smem.shape = {smem.shape}, pointer.shape = {pointer.shape}"
36
+ )
37
+ mask_handle = mask.handle if mask is not None else ir.value()
38
+ _semantic.builder.create_async_copy_global_to_local(smem.handle, pointer.handle, mask_handle, ir.value(),
39
+ cache_modifier, eviction_policy, volatile)
40
+
41
+
42
@builtin
def mbarrier_arrive(mbarrier, increment_count=True, _semantic=None):
    """
    Signal arrival on the mbarrier once all outstanding async copies finish.

    Args:
        mbarrier (shared_memory_descriptor): Barrier object to arrive on.
        increment_count (bool): Whether to increment the arrival count. Defaults to True.
    """
    _semantic.builder.create_async_copy_mbarrier_arrive(mbarrier.handle, _unwrap_if_constexpr(increment_count))
53
+
54
+
55
@builtin
def commit_group(_semantic=None):
    """
    Close the currently open asynchronous copy group.

    All async copies issued so far are finalized into one group that can
    later be waited on.
    """
    _semantic.builder.create_async_commit_group()
63
+
64
+
65
@builtin
def wait_group(num_outstanding=0, _semantic=None):
    """
    Block until at most `num_outstanding` async copy groups remain in flight.

    Args:
        num_outstanding (int): Wait until `num_outstanding` or less async copy groups in-flight. Defaults to 0.
    """
    _semantic.builder.create_async_wait_group(_unwrap_if_constexpr(num_outstanding))
@@ -0,0 +1,80 @@
1
+ from triton.experimental.gluon.language._layouts import SwizzledSharedLayout
2
+ from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
3
+
4
+ __all__ = ["arrive", "init", "invalidate", "MBarrierLayout", "wait"]
5
+
6
+
7
class MBarrierLayout(SwizzledSharedLayout):
    """
    Shared-memory layout used for mbarrier objects on Ampere and later
    architectures: a trivial 1-D swizzled layout (no vectorization, no
    phase swizzling).

    Args:
        ctas_per_cga (int): CTAs per CGA grouping. Defaults to 1.
        cta_split_num (int): CTA split factor. Defaults to 1.
    """

    def __init__(self, ctas_per_cga: int = 1, cta_split_num: int = 1):
        # Fixed 1-D layout; only the CTA distribution parameters vary.
        super().__init__(vec=1, per_phase=1, max_phase=1, order=[0],
                         ctas_per_cga=[ctas_per_cga],
                         cta_split_num=[cta_split_num],
                         cta_order=[0])
26
+
27
+
28
@builtin
def init(mbarrier, count, _semantic=None):
    """
    Initialize an mbarrier with the given arrival count.

    Args:
        mbarrier (shared_memory_descriptor): The barrier object to initialize.
        count (int): The initial count for the barrier.
    """
    _semantic.builder.create_mbarrier_init(mbarrier.handle, _unwrap_if_constexpr(count))
39
+
40
+
41
@builtin
def invalidate(mbarrier, _semantic=None):
    """
    Reset an mbarrier's state by invalidating it.

    Args:
        mbarrier (shared_memory_descriptor): The barrier object to invalidate.
    """
    _semantic.builder.create_mbarrier_inval(mbarrier.handle)
50
+
51
+
52
@builtin
def wait(mbarrier, phase, pred=True, deps=(), _semantic=None):
    """
    Block until the mbarrier completes its current phase.

    Args:
        mbarrier (shared_memory_descriptor): The barrier object to wait on.
        phase (int): The phase index to wait for.
        pred (bool): Predicate. Operation is skipped if predicate is False. Defaults to True.
        deps (Sequence[shared_memory_descriptor]): Dependent allocations barrier is waiting on. Used to track liveness of dependent allocations. Defaults to ().
    """
    phase_tensor = _semantic.to_tensor(phase)
    pred_tensor = _semantic.to_tensor(pred)
    dep_handles = [d.handle for d in deps]
    _semantic.builder.create_mbarrier_wait(mbarrier.handle, phase_tensor.handle, pred_tensor.handle, dep_handles)
67
+
68
+
69
@builtin
def arrive(mbarrier, *, pred=True, _semantic=None):
    """
    Signal that a thread has reached the mbarrier.

    Args:
        mbarrier (shared_memory_descriptor): The barrier object to arrive on.
        pred (bool): Predicate. Operation is skipped if predicate is False. Defaults to True.
    """
    pred_tensor = _semantic.to_tensor(pred)
    # Each arrive contributes a count of exactly 1.
    _semantic.builder.create_mbarrier_arrive(mbarrier.handle, 1, pred_tensor.handle)