triton-windows 3.3.0.post19__cp311-cp311-win_amd64.whl → 3.4.0.post20__cp311-cp311-win_amd64.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of triton-windows has been flagged as potentially problematic.

Files changed (173)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +149 -47
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +92 -93
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +303 -128
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +76 -12
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/runtime/tcc/lib/python310.def +1610 -0
  56. triton/runtime/tcc/lib/python311.def +1633 -0
  57. triton/runtime/tcc/lib/python312.def +1703 -0
  58. triton/runtime/tcc/lib/python313.def +1651 -0
  59. triton/runtime/tcc/lib/python313t.def +1656 -0
  60. triton/runtime/tcc/lib/python39.def +1644 -0
  61. triton/runtime/tcc/lib/python3t.def +905 -0
  62. triton/testing.py +16 -12
  63. triton/tools/disasm.py +3 -4
  64. triton/tools/tensor_descriptor.py +36 -0
  65. triton/windows_utils.py +14 -6
  66. {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
  67. triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
  68. {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +1 -1
  69. triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
  70. triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
  71. triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
  72. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  73. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  74. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  75. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  76. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  77. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  78. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  79. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  80. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  81. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  82. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  83. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  84. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  85. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  86. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  87. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  88. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  89. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  90. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  91. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  92. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  93. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  94. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  95. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  96. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  97. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  98. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  99. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  100. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  101. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  102. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  103. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  104. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  105. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  106. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  107. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  108. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  109. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  110. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  111. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  112. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  113. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  114. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  115. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  116. triton/backends/amd/include/hip/device_functions.h +0 -38
  117. triton/backends/amd/include/hip/driver_types.h +0 -468
  118. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  119. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  120. triton/backends/amd/include/hip/hip_common.h +0 -100
  121. triton/backends/amd/include/hip/hip_complex.h +0 -38
  122. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  123. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  124. triton/backends/amd/include/hip/hip_ext.h +0 -161
  125. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  126. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  127. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  128. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  129. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  130. triton/backends/amd/include/hip/hip_profile.h +0 -27
  131. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  132. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  133. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  134. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  135. triton/backends/amd/include/hip/hip_version.h +0 -17
  136. triton/backends/amd/include/hip/hiprtc.h +0 -421
  137. triton/backends/amd/include/hip/library_types.h +0 -78
  138. triton/backends/amd/include/hip/math_functions.h +0 -42
  139. triton/backends/amd/include/hip/surface_types.h +0 -63
  140. triton/backends/amd/include/hip/texture_types.h +0 -194
  141. triton/backends/amd/include/hsa/Brig.h +0 -1131
  142. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  143. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  144. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  145. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  146. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  147. triton/backends/amd/include/hsa/hsa.h +0 -5738
  148. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  149. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  150. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  151. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  152. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  153. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  154. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  155. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  156. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  157. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  158. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  159. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  160. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  161. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  162. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  163. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  164. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  165. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  166. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  167. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  168. triton/backends/amd/include/roctracer/roctx.h +0 -229
  169. triton/language/_utils.py +0 -21
  170. triton/language/extra/cuda/_experimental_tma.py +0 -106
  171. triton/tools/experimental_descriptor.py +0 -32
  172. triton_windows-3.3.0.post19.dist-info/RECORD +0 -253
  173. triton_windows-3.3.0.post19.dist-info/top_level.txt +0 -14
triton/experimental/gluon/language/_semantic.py
@@ -0,0 +1,287 @@
+ from typing import Sequence, List, TypeVar, Tuple, Callable
+ from triton.language.semantic import TritonSemantic
+ from . import _core as ttgl
+ from ._layouts import SliceLayout
+ from triton._C.libtriton.gluon_ir import GluonOpBuilder
+ from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values
+
+ TensorTy = TypeVar("TensorTy")
+
+
+ def _check(cond: bool, msg_fn: Callable[[], str], category=ValueError):
+     if not cond:
+         raise category(msg_fn())
+
+
+ class GluonSemantic(TritonSemantic[TensorTy]):
+     tensor = ttgl.tensor
+     lang = ttgl
+
+     builder: GluonOpBuilder
+
+     def __init__(self, builder: GluonOpBuilder):
+         self.builder = builder
+
+     def _wrap_tensor_infer_layout(self, tensor):
+         ty = ttgl.distributed_type(tensor.type.scalar, tensor.shape,
+                                    self.builder.get_gluon_layout_from_tensor(tensor.handle))
+         return self.tensor(tensor.handle, ty)
+
+     def _broadcast_shapes(self, lhs_shape: List[int], rhs_shape: List[int]):
+         if len(lhs_shape) != len(rhs_shape):
+             raise ValueError(f"Cannot broadcast, rank mismatch: {lhs_shape}, {rhs_shape}")
+
+         ret_shape = []
+         for i, left in enumerate(lhs_shape):
+             right = rhs_shape[i]
+             if left == 1:
+                 ret_shape.append(right)
+             elif (right == 1) or (right == left):
+                 ret_shape.append(left)
+             else:
+                 raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
+                                  "at index " + str(i) + ": " + str(left) + " and " + str(right))
+         return ret_shape
+
+     def expand_dims(self, input: TensorTy, axis: int) -> TensorTy:
+         dst_shape = [ttgl._unwrap_if_constexpr(x) for x in input.shape]
+         dst_shape.insert(axis, 1)
+
+         if axis < 0:
+             axis += len(input.shape)
+
+         _check(isinstance(input.type, ttgl.distributed_type),
+                lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}")
+         layout = input.type.layout
+         _check(isinstance(layout, SliceLayout),
+                lambda: f"expected expand_dims input to have a SliceLayout, but got: {layout}")
+         _check(layout.dim == axis,
+                lambda: f"expected expand_dims input layout to be sliced in axis {axis} but got {layout.dim}")
+
+         ret_ty = ttgl.distributed_type(input.type.scalar, dst_shape, layout.parent)
+         handle = self.builder.create_expand_dims(input.handle, axis, ret_ty.to_ir(self.builder))
+         return self.tensor(handle, ret_ty)
+
+     def join(self, a: TensorTy, b: TensorTy) -> TensorTy:
+         a, b = self.broadcast_impl_value(a, b)
+         _check(a.shape != [], "Cannot join scalars in gluon")
+         value = super().join(a, b)
+         return self._wrap_tensor_infer_layout(value)
+
+     def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]:
+         lhs, rhs = super().split(a)
+         return self._wrap_tensor_infer_layout(lhs), self._wrap_tensor_infer_layout(rhs)
+
+     def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy:
+         value = super().permute(input, dims)
+         return self._wrap_tensor_infer_layout(value)
+
+     def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy:
+         _check(isinstance(input.type, ttgl.distributed_type),
+                lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}")
+         src_shape = input.type.get_block_shapes()
+         _check(len(src_shape) == len(shape), lambda: f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
+         if shape == src_shape:
+             return input
+         for i, item in enumerate(src_shape):
+             if shape[i] != item and item != 1:
+                 raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
+                                  f" must match the existing size ({item}) at non-singleton dimension"
+                                  f" {i}: {src_shape}, {shape}")
+         ret_ty = ttgl.distributed_type(input.type.scalar, shape, input.type.layout)
+         handle = self.builder.create_broadcast(input.handle, ret_ty.to_ir(self.builder))
+         return self.tensor(handle, ret_ty)
+
+     def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy:
+         lhs_ty = lhs.type
+         rhs_ty = rhs.type
+
+         if not lhs_ty.is_block() or not rhs_ty.is_block():
+             return super().broadcast_impl_value(lhs, rhs)
+
+         _check(isinstance(lhs_ty, ttgl.distributed_type),
+                lambda: f"expected broadcast left input to be a distributed_type but got: {lhs_ty!r}")
+         _check(isinstance(rhs_ty, ttgl.distributed_type),
+                lambda: f"expected broadcast right input to be a distributed_type but got: {rhs_ty!r}")
+
+         lhs_shape = lhs_ty.get_block_shapes()
+         rhs_shape = rhs_ty.get_block_shapes()
+         ret_shape = self._broadcast_shapes(lhs_shape, rhs_shape)
+         if lhs_ty.layout != rhs_ty.layout:
+             raise ValueError(f"Layout mismatch in broadcast: {lhs_ty.layout} vs {rhs_ty.layout}")
+
+         lhs = self.broadcast_impl_shape(lhs, ret_shape)
+         rhs = self.broadcast_impl_shape(rhs, ret_shape)
+         return lhs, rhs
+
+     def arange(self, start, end, layout):
+         shape = [end - start]
+         ret_ty = ttgl.distributed_type(ttgl.int32, shape, layout)
+         return super().arange(start, end, ret_ty=ret_ty)
+
+     def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool):
+         _check(not can_reorder, "can_reorder is not supported in gluon")
+         value = super().reshape(input, dst_shape, can_reorder)
+         return self._wrap_tensor_infer_layout(value)
+
+     def splat(self, value, shape, layout):
+         ret_ty = ttgl.distributed_type(value.dtype, shape, layout)
+         handle = self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle)
+         return ttgl.tensor(handle, ret_ty)
+
+     def full(self, shape, value, dtype, layout):
+         scalar = self.make_scalar(value, dtype)
+         return self.splat(scalar, shape, layout)
+
+     def convert_layout(self, value, layout):
+         ty = value.type
+         _check(isinstance(ty, ttgl.distributed_type),
+                lambda: f"expected convert_layout input to be a distributed_type but got: {ty!r}")
+         ret_ty = ttgl.distributed_type(ty.element_ty, ty.shape, layout)
+         handle = self.builder.create_convert_layout(ret_ty.to_ir(self.builder), value.handle)
+         return ttgl.tensor(handle, ret_ty)
+
+     def allocate_shared(self, element_ty, shape, layout, value):
+         ty = ttgl.shared_memory_descriptor_type(element_ty, shape, layout, shape)
+         if value is not None:
+             handle = self.builder.create_local_alloc(ty.to_ir(self.builder), value.handle)
+         else:
+             handle = self.builder.create_local_alloc(ty.to_ir(self.builder))
+         return ttgl.shared_memory_descriptor(handle, element_ty, shape, layout, shape)
+
+     def shared_load(self, mem_desc, layout):
+         ret_ty = ttgl.distributed_type(mem_desc.dtype, mem_desc.shape, layout)
+         handle = self.builder.create_local_load(ret_ty.to_ir(self.builder), mem_desc.handle)
+         return ttgl.tensor(handle, ret_ty)
+
+     def shared_store(self, mem_desc, value):
+         self.builder.create_local_store(mem_desc.handle, value.handle)
+
+     def shared_dealloc(self, mem_desc):
+         self.builder.create_local_dealloc(mem_desc.handle)
+
+     def _memdesc_subview(self, mem_desc, offsets, shape):
+         layout = mem_desc.layout
+         ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
+         builder = self.builder
+         handle = builder.create_memdesc_subview(ty.to_ir(builder), mem_desc.handle, offsets)
+         return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
+
+     def memdesc_slice(self, mem_desc, start, length, dim):
+         offsets = [self.builder.get_int32(0)] * mem_desc.rank
+         offsets[dim] = self.to_tensor(start).handle
+         shape = list(mem_desc.shape)
+         shape[dim] = length
+         return self._memdesc_subview(mem_desc, offsets, shape)
+
+     def memdesc_index(self, mem_desc, index):
+         shape = mem_desc.shape[1:]
+         offsets = [self.builder.get_int32(0)] * mem_desc.rank
+         offsets[0] = self.to_tensor(index).handle
+         return self._memdesc_subview(mem_desc, offsets, shape)
+
+     def memdesc_trans(self, mem_desc, order):
+         assert len(order) == len(
+             mem_desc.shape), f"source rank ({mem_desc.rank}) and order length ({len(order)}) must match"
+
+         shape = [mem_desc.shape[i] for i in order]
+         alloc_shape = mem_desc.type.alloc_shape
+         new_alloc_shape = alloc_shape[:len(alloc_shape) - mem_desc.rank]
+         new_alloc_shape += [alloc_shape[len(alloc_shape) - mem_desc.rank:][i] for i in order]
+
+         handle = self.builder.create_memdesc_trans(mem_desc.handle, order)
+         layout = self.builder.get_gluon_layout_from_memdesc(handle)
+         return ttgl.shared_memory_descriptor(handle, element_ty=mem_desc.dtype, shape=shape,
+                                              alloc_shape=new_alloc_shape, layout=layout)
+
+     def memdesc_reshape(self, mem_desc, shape, layout):
+         ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
+         handle = self.builder.create_memdesc_reshape(ty.to_ir(self.builder), mem_desc.handle)
+         return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
+
+     def memdesc_reinterpret(self, mem_desc, dtype, shape, layout):
+         ty = ttgl.shared_memory_descriptor_type(dtype, shape, layout, shape)
+         handle = self.builder.create_memdesc_reinterpret(ty.to_ir(self.builder), mem_desc.handle)
+         return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
+
+     def wrap_tensor(self, x, scalar_ty, ret_shape, layout):
+         if ret_shape:
+             res_ty = ttgl.distributed_type(scalar_ty, ret_shape, layout)
+         else:
+             res_ty = scalar_ty
+         return self.tensor(x, res_ty)
+
+     @staticmethod
+     def _check_same_layout(xs):
+         for x in xs:
+             _check(isinstance(x.type, ttgl.distributed_type), lambda: f"expected distributed_type but got: {x.type!r}")
+         layouts = [x.type.layout for x in xs]
+         l0 = layouts[0]
+         _check(all(l == l0 for l in layouts[1:]),
+                lambda: f"Expected inputs to have matching layouts, but got: {layouts}")
+
+     def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]:
+         _check(axis is not None, lambda: "All-reduce is not yet implemented in gluon")
+         # get result shape
+         shape = inputs[0].type.shape
+         rank = len(shape)
+         _check(0 <= axis < rank, lambda: f"expected reduction axis to be in the range [0, {rank}) but got {axis}")
+         self._check_same_layout(inputs)
+         ret_shape = [s for i, s in enumerate(shape) if i != axis]
+         ret_layout = SliceLayout(axis, inputs[0].type.layout)
+         assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape"
+
+         reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis)
+         region_builder_fn(reduce_op)
+         assert reduce_op.verify()
+
+         return tuple(
+             self.wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape, ret_layout)
+             for i in range(len(inputs)))
+
+     def warp_specialize(self, args, default_partition, worker_partitions, worker_num_warps: Sequence[int],
+                         worker_num_regs: Sequence[int], generator):
+         num_partitions = len(worker_partitions)
+         assert num_partitions == len(
+             worker_num_warps
+         ), f"warp specialize got {num_partitions} partitions but {len(worker_num_warps)} warp counts"
+         assert num_partitions == len(
+             worker_num_regs
+         ), f"warp specialize got {num_partitions} partitions but {len(worker_num_regs)} register counts"
+
+         builder = self.builder
+         insert_pt = builder.get_insertion_point()
+
+         # Emit the default partition to get the result types.
+         default_block = builder.new_block()
+         builder.set_insertion_point_to_start(default_block)
+         default_results = generator.call_JitFunction(default_partition, args, kwargs={})
+         mlir_results = []
+         if default_results is not None:
+             mlir_results = flatten_values_to_ir(default_results)
+         builder.create_warp_yield(mlir_results)
+         result_types = [r.get_type() for r in mlir_results]
+
+         # Create the warp specialize op.
+         builder.restore_insertion_point(insert_pt)
+         mlir_args = flatten_values_to_ir(args)
+         ws_op = builder.create_warp_specialize(result_types, mlir_args, worker_num_warps)
+         ws_op.get_default_region().push_back(default_block)
+         ws_op.set_requested_registers(worker_num_regs)
+
+         # Emit the partition regions.
+         builder.create_block_with_parent(ws_op.get_partition_op_holder(), [])
+         partitions_op = builder.create_warp_specialize_partitions(num_partitions)
+         arg_types = [arg.get_type() for arg in mlir_args]
+         for i in range(num_partitions):
+             block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types)
+             block_args = [block.get_argument(j) for j in range(len(mlir_args))]
+             block_args = unflatten_ir_values(block_args, [arg.type for arg in args])
+             generator.call_JitFunction(worker_partitions[i], block_args, kwargs={})
+             builder.create_warp_return()
+
+         builder.set_insertion_point_after(ws_op.get_operation())
+         mlir_results = [ws_op.get_result(i) for i in range(len(result_types))]
+         if default_results is None:
+             return
+         return tuple(unflatten_ir_values(mlir_results, [r.type for r in default_results]))
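
To make the rank rules above easier to scan, here is a minimal standalone sketch of the broadcasting rule that _broadcast_shapes and broadcast_impl_value enforce (broadcast_impl_value additionally requires both operands to carry the same layout). This is an illustration only, not part of the diff.

def broadcast_shapes(lhs, rhs):
    # Same per-dimension rule as GluonSemantic._broadcast_shapes: ranks must match,
    # and a size-1 dimension stretches to the other operand's size.
    if len(lhs) != len(rhs):
        raise ValueError(f"Cannot broadcast, rank mismatch: {lhs}, {rhs}")
    out = []
    for i, (left, right) in enumerate(zip(lhs, rhs)):
        if left == 1:
            out.append(right)
        elif right == 1 or right == left:
            out.append(left)
        else:
            raise ValueError(f"incompatible dimensions at index {i}: {left} and {right}")
    return out

assert broadcast_shapes([1, 128], [64, 128]) == [64, 128]
assert broadcast_shapes([64, 1], [64, 128]) == [64, 128]
# broadcast_shapes([32, 128], [64, 128]) raises ValueError
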
triton/experimental/gluon/language/_standard.py
@@ -0,0 +1,47 @@
+ # flake8: noqa
+ import triton
+ import triton.language.standard as tl_standard
+ from .._runtime import jit
+ from triton import knobs
+ from . import _core as ttgl
+
+ _IMPORT_FROM_TRITON = [
+     "sum",
+     "max",
+     "min",
+     "reduce_or",
+     "xor_sum",
+ ]
+
+ __all__ = [
+     "full_like",
+     "zeros",
+     "zeros_like",
+     *_IMPORT_FROM_TRITON,
+ ]
+
+ for name in _IMPORT_FROM_TRITON:
+     # Convert JITFunction -> GluonJitFunction
+     fn = getattr(tl_standard, name)
+     assert knobs.runtime.interpret or isinstance(fn, triton.runtime.JITFunction)
+     globals()[name] = jit(fn.fn)
+
+
+ @jit
+ def zeros(shape, dtype, layout):
+     return ttgl.full(shape, 0, dtype, layout)
+
+
+ @jit
+ def full_like(input, value, shape=None, dtype=None, layout=None):
+     return ttgl.full(
+         input.shape if shape is None else shape,
+         value,
+         input.dtype if dtype is None else dtype,
+         input.type.layout if layout is None else layout,
+     )
+
+
+ @jit
+ def zeros_like(input, shape=None, dtype=None, layout=None):
+     return full_like(input, 0, shape=shape, dtype=dtype, layout=layout)
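
The helpers above are plain Gluon-jitted wrappers around ttgl.full. A hypothetical usage sketch follows; the import paths mirror the files added in this diff, while the layout argument is left abstract because layout construction lives in _layouts.py, which is not shown in this section. The re-exports of constexpr and the float dtypes from _core are assumptions.

# Hypothetical sketch only; assumes ttgl re-exports constexpr and the float dtypes,
# and that `layout` is a distributed layout supplied by the caller.
from triton.experimental.gluon._runtime import jit
from triton.experimental.gluon.language import _core as ttgl
from triton.experimental.gluon.language import _standard as gl_std


@jit
def fill_example(layout: ttgl.constexpr):
    x = gl_std.zeros([128, 128], ttgl.float32, layout)   # full(shape, 0, dtype, layout)
    y = gl_std.zeros_like(x)                              # shape/dtype/layout taken from x
    z = gl_std.full_like(x, 1.0, dtype=ttgl.float16)      # override only the dtype
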
triton/experimental/gluon/language/nvidia/__init__.py
@@ -0,0 +1,4 @@
+ from . import blackwell
+ from . import hopper
+
+ __all__ = ["blackwell", "hopper"]
triton/experimental/gluon/language/nvidia/blackwell/__init__.py
@@ -0,0 +1,202 @@
+ from __future__ import annotations
+ from typing import Optional, Tuple, List, TYPE_CHECKING
+
+ from dataclasses import dataclass
+ from triton.experimental.gluon.language import _core as ttgl
+ from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr
+ from triton.experimental.gluon.language._semantic import _check
+
+ from . import tma
+ from ..hopper import mbarrier, fence_async_shared
+
+ if TYPE_CHECKING:
+     from triton._C.libtriton.gluon_ir import GluonOpBuilder
+     from triton._C.libtriton import gluon_ir as ir
+     from ..._semantic import GluonSemantic
+
+ __all__ = [
+     "allocate_tensor_memory",
+     "fence_async_shared",
+     "mbarrier",
+     "tensor_memory_descriptor",
+     "TensorMemoryLayout",
+     "tma",
+ ]
+
+
+ @dataclass(frozen=True, eq=True)
+ class TensorMemoryLayout:
+     block: Tuple[int, int]
+     unpacked: bool
+     cta_split_num: Optional[Tuple[int, int]] = None
+
+     def __post_init__(self):
+         assert len(self.block) == 2
+         assert self.cta_split_num is None or len(self.cta_split_num) == 2
+
+     def _to_ir(self, builder):
+         cta_split_num = self.cta_split_num or [1, 1]
+         return builder.get_tensor_memory_layout(
+             self.block,
+             self.unpacked,
+             cta_split_num,
+         )
+
+     def mangle(self) -> str:
+         block_str = f"{self.block[0]}x{self.block[1]}"
+         unpacked_str = "U" if self.unpacked else "P"
+         cta_split_str = f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else ""
+         return f"TL{block_str}{unpacked_str}{cta_split_str}TL"
+
+
+ class tensor_memory_descriptor_type(base_type):
+
+     def __init__(self, element_ty, shape, layout, alloc_shape):
+         self.element_ty = element_ty
+         self.shape = shape
+         self.layout = layout
+         self.alloc_shape = alloc_shape
+         assert isinstance(layout, TensorMemoryLayout)
+
+     def to_ir(self, builder: GluonOpBuilder) -> None:
+         return builder.get_tensor_mem_desc_ty(
+             self.element_ty.to_ir(builder),
+             self.shape,
+             self.layout._to_ir(builder),
+             self.alloc_shape,
+         )
+
+     def _unflatten_ir(self, handles: List[ir.Value], cursor: int) -> Tuple[tensor_memory_descriptor, int]:
+         value = tensor_memory_descriptor(handles[cursor], self.element_ty, self.shape, self.layout, self.alloc_shape)
+         return value, cursor + 1
+
+     def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None:
+         out.append(self.to_ir(builder))
+
+     def __str__(self) -> str:
+         return f"tensor_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}>"
+
+     def __eq__(self, other) -> bool:
+         return (type(self) is type(other) and self.shape == other.shape and self.layout == other.layout
+                 and self.alloc_shape == other.alloc_shape)
+
+     def __neq__(self, other) -> bool:
+         return not (self == other)
+
+     def mangle(self) -> str:
+         shape_str = "_".join([str(s) for s in self.shape])
+         return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{self.alloc_shape}ASMD"
+
+
+ class tensor_memory_descriptor(base_value):
+
+     def __init__(self, handle, element_ty, shape, layout, alloc_shape):
+         self.handle = handle
+         self.type = tensor_memory_descriptor_type(element_ty, shape, layout, alloc_shape)
+
+     def _flatten_ir(self, handles: List[ir.value]) -> None:
+         handles.append(self.handle)
+
+     @property
+     def dtype(self):
+         return self.type.element_ty
+
+     @property
+     def shape(self):
+         return self.type.shape
+
+     @property
+     def rank(self):
+         return len(self.shape)
+
+     @property
+     def layout(self):
+         return self.type.layout
+
+     def __str__(self) -> str:
+         return str(self.type)
+
+     @builtin
+     def load(self, layout, _semantic: GluonSemantic) -> ttgl.tensor:
+         layout = _unwrap_if_constexpr(layout)
+         ret_ty = ttgl.distributed_type(self.dtype, self.shape, layout)
+         builder = _semantic.builder
+         handle = builder.create_tmem_load(ret_ty.to_ir(builder), self.handle)
+         return ttgl.tensor(handle, ret_ty)
+
+     @builtin
+     def store(self, value, pred=True, _semantic: GluonSemantic = None) -> None:
+         pred = _unwrap_if_constexpr(pred)
+         pred = _semantic.to_tensor(pred)
+         _semantic.builder.create_tmem_store(self.handle, value.handle, pred.handle)
+
+     @builtin
+     def slice(self, start, length, _semantic: GluonSemantic) -> None:
+         start = _unwrap_if_constexpr(start)
+         length = _unwrap_if_constexpr(length)
+         _check(isinstance(start, int), lambda: "start must be a constant int")
+         _check(isinstance(length, int), lambda: "length must be a constant int")
+         shape = self.shape[:-1] + [length]
+         layout = self.type.layout
+         layout = TensorMemoryLayout((layout.block[0], min(layout.block[1], length)), layout.unpacked,
+                                     layout.cta_split_num)
+         ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape)
+         builder = _semantic.builder
+         ret.handle = builder.create_tmem_subslice(ret.type.to_ir(builder), self.handle, start)
+         return ret
+
+     @builtin
+     def index(self, index, _semantic: GluonSemantic = None) -> tensor_memory_descriptor:
+         index = _semantic.to_tensor(index)
+         builder = _semantic.builder
+         offsets = [builder.get_int32(0)] * self.rank
+         offsets[0] = index.handle
+         shape = self.shape[1:]
+         layout = self.layout
+         ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape)
+         ret.handle = builder.create_memdesc_subview(ret.type.to_ir(builder), self.handle, offsets)
+         return ret
+
+     @builtin
+     def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> tensor_memory_descriptor:
+         dtype = _unwrap_if_constexpr(dtype)
+         shape = [_unwrap_if_constexpr(s) for s in shape]
+         layout = _unwrap_if_constexpr(layout)
+
+         ty = tensor_memory_descriptor_type(dtype, shape, layout, shape)
+         handle = _semantic.builder.create_memdesc_reinterpret(ty.to_ir(_semantic.builder), self.handle)
+         return tensor_memory_descriptor(handle, **ty.__dict__)
+
+
+ @builtin
+ def allocate_tensor_memory(element_ty, shape, layout, value=None, _semantic=None):
+     element_ty = _unwrap_if_constexpr(element_ty)
+     shape = _unwrap_if_constexpr(shape)
+     layout = _unwrap_if_constexpr(layout)
+     value = value.handle if value is not None else None
+
+     ty = tensor_memory_descriptor_type(element_ty, shape, layout, shape)
+     builder = _semantic.builder
+     handle = builder.create_tmem_alloc(ty.to_ir(builder), value)
+     return tensor_memory_descriptor(handle, element_ty, shape, layout, shape)
+
+
+ @builtin
+ def tcgen05_mma(a, b, acc, *, use_acc=True, pred=True, mbarriers=None, mbarrier_preds=None, _semantic=None):
+     use_acc = _semantic.to_tensor(use_acc)
+     pred = _semantic.to_tensor(pred)
+
+     if mbarriers is None:
+         assert mbarrier_preds is None
+         mbarriers = []
+         mbarrier_preds = []
+     else:
+         mbarriers = [bar.handle for bar in mbarriers]
+         if mbarrier_preds is None:
+             true = _semantic.to_tensor(True)
+             mbarrier_preds = [true] * len(mbarriers)
+         else:
+             mbarrier_preds = _semantic._convert_to_ir_values(mbarrier_preds, require_i64=False)
+
+     _semantic.builder.create_tcgen05_mma(a.handle, b.handle, acc.handle, use_acc.handle, pred.handle, mbarriers,
+                                          mbarrier_preds)
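
Putting the new Blackwell pieces together: TensorMemoryLayout describes how a tile is laid out in tensor memory, allocate_tensor_memory returns a tensor_memory_descriptor, and store/load move data between registers and TMEM. The sketch below is illustrative only; `x`, `a_smem`, `b_smem`, and `acc_layout` (a distributed layout for the loaded result) are assumed placeholders, not taken from the diff.

# Hypothetical sketch; the kernel-body part is shown as comments because it needs
# a builder context. Only TensorMemoryLayout is constructed eagerly here.
from triton.experimental.gluon.language.nvidia import blackwell
from triton.experimental.gluon.language import _core as ttgl

tmem_layout = blackwell.TensorMemoryLayout(block=(128, 128), unpacked=True)

# Inside a Gluon-jitted kernel body:
#     acc_tmem = blackwell.allocate_tensor_memory(ttgl.float32, [128, 128], tmem_layout)
#     acc_tmem.store(x)                # x: a distributed fp32 tensor of shape [128, 128]
#     acc = acc_tmem.load(acc_layout)  # read back into registers with layout acc_layout
#     blackwell.tcgen05_mma(a_smem, b_smem, acc_tmem)  # a_smem/b_smem: shared-memory operands (assumed)
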
triton/experimental/gluon/language/nvidia/blackwell/tma.py
@@ -0,0 +1,32 @@
+ from triton.experimental.gluon.language._core import builtin
+ from triton.experimental.gluon.language.nvidia.hopper.tma import (
+     async_copy_global_to_shared,
+     async_copy_shared_to_global,
+     store_wait,
+     tensor_descriptor,
+     tensor_descriptor_type,
+ )
+
+ __all__ = [
+     "async_gather",
+     "async_scatter",
+     "async_copy_global_to_shared",
+     "async_copy_shared_to_global",
+     "store_wait",
+     "tensor_descriptor",
+     "tensor_descriptor_type",
+ ]
+
+
+ @builtin
+ def async_gather(tensor_desc, x_offsets, y_offset, barrier, result, pred=True, _semantic=None):
+     pred = _semantic.to_tensor(pred)
+     y_offset = _semantic.to_tensor(y_offset)
+     _semantic.builder.create_async_tma_gather(tensor_desc.handle, x_offsets.handle, y_offset.handle, barrier.handle,
+                                               result.handle, pred.handle)
+
+
+ @builtin
+ def async_scatter(tensor_desc, x_offsets, y_offset, src, _semantic=None):
+     y_offset = _semantic.to_tensor(y_offset)
+     _semantic.builder.create_async_tma_scatter(tensor_desc.handle, x_offsets.handle, y_offset.handle, src.handle)
triton/experimental/gluon/language/nvidia/hopper/__init__.py
@@ -0,0 +1,11 @@
+ from . import mbarrier
+ from . import tma
+ from ... import _core
+
+ __all__ = ["fence_async_shared", "mbarrier", "tma"]
+
+
+ @_core.builtin
+ def fence_async_shared(cluster=False, _semantic=None):
+     cluster = _core._unwrap_if_constexpr(cluster)
+     _semantic.builder.create_fence_async_shared(cluster)
triton/experimental/gluon/language/nvidia/hopper/mbarrier.py
@@ -0,0 +1,51 @@
+ from triton.experimental.gluon.language._layouts import SwizzledSharedLayout
+ from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
+
+ __all__ = ["MBarrierLayout", "init", "invalidate", "expect", "wait", "arrive"]
+
+
+ class MBarrierLayout(SwizzledSharedLayout):
+
+     def __init__(self, ctas_per_cga: int = 1, cta_split_num: int = 1):
+         super().__init__(
+             vec=1,
+             per_phase=1,
+             max_phase=1,
+             order=[0],
+             ctas_per_cga=[ctas_per_cga],
+             cta_split_num=[cta_split_num],
+             cta_order=[0],
+         )
+
+
+ @builtin
+ def init(mbarrier, count, _semantic=None):
+     count = _unwrap_if_constexpr(count)
+     _semantic.builder.create_mbarrier_init(mbarrier.handle, count)
+
+
+ @builtin
+ def invalidate(mbarrier, _semantic=None):
+     _semantic.builder.create_mbarrier_inval(mbarrier.handle)
+
+
+ @builtin
+ def expect(mbarrier, bytes, pred=True, _semantic=None):
+     bytes = _unwrap_if_constexpr(bytes)
+     pred = _semantic.to_tensor(pred)
+     _semantic.builder.create_mbarrier_expect(mbarrier.handle, bytes, pred.handle)
+
+
+ @builtin
+ def wait(mbarrier, phase, pred=True, deps=(), _semantic=None):
+     phase = _semantic.to_tensor(phase)
+     pred = _semantic.to_tensor(pred)
+     deps = [x.handle for x in deps]
+     _semantic.builder.create_mbarrier_wait(mbarrier.handle, phase.handle, pred.handle, deps)
+
+
+ @builtin
+ def arrive(mbarrier, count, pred=True, _semantic=None):
+     count = _unwrap_if_constexpr(count)
+     pred = _semantic.to_tensor(pred)
+     _semantic.builder.create_mbarrier_arrive(mbarrier.handle, count, pred.handle)
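
For orientation, here is a hedged sketch of the expect/wait handshake these builtins support. The shared-memory allocation of the barrier is assumed (it is provided elsewhere in the diff, not in this section), and the byte count and predicate names are arbitrary placeholders.

# Hypothetical producer/consumer sequence inside a Gluon-jitted kernel, shown as
# comments because it needs a builder context; only the import is executable here.
from triton.experimental.gluon.language.nvidia.hopper import mbarrier

# Inside the kernel:
#     bar = <allocate a shared-memory barrier with layout mbarrier.MBarrierLayout()>
#     mbarrier.init(bar, count=1)
#     mbarrier.expect(bar, 128 * 64 * 2, pred=is_producer)  # announce expected bytes
#     ... issue the async copy that will arrive on `bar` ...
#     mbarrier.wait(bar, phase)                              # block until the phase completes
#     mbarrier.invalidate(bar)
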