triton-windows 3.3.1.post19__cp313-cp313-win_amd64.whl → 3.4.0.post20__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (166)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +149 -47
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +92 -93
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +303 -128
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +76 -12
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/testing.py +16 -12
  56. triton/tools/disasm.py +3 -4
  57. triton/tools/tensor_descriptor.py +36 -0
  58. triton/windows_utils.py +14 -6
  59. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
  60. triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
  61. triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
  62. triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
  63. triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
  64. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  65. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  66. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  67. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  68. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  69. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  70. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  71. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  72. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  73. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  74. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  75. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  76. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  77. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  78. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  79. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  80. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  81. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  82. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  83. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  84. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  85. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  86. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  87. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  88. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  89. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  90. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  91. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  92. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  93. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  94. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  95. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  96. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  97. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  98. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  99. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  100. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  101. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  102. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  103. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  104. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  105. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  106. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  107. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  108. triton/backends/amd/include/hip/device_functions.h +0 -38
  109. triton/backends/amd/include/hip/driver_types.h +0 -468
  110. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  111. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  112. triton/backends/amd/include/hip/hip_common.h +0 -100
  113. triton/backends/amd/include/hip/hip_complex.h +0 -38
  114. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  115. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  116. triton/backends/amd/include/hip/hip_ext.h +0 -161
  117. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  118. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  119. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  120. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  121. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  122. triton/backends/amd/include/hip/hip_profile.h +0 -27
  123. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  124. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  125. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  126. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  127. triton/backends/amd/include/hip/hip_version.h +0 -17
  128. triton/backends/amd/include/hip/hiprtc.h +0 -421
  129. triton/backends/amd/include/hip/library_types.h +0 -78
  130. triton/backends/amd/include/hip/math_functions.h +0 -42
  131. triton/backends/amd/include/hip/surface_types.h +0 -63
  132. triton/backends/amd/include/hip/texture_types.h +0 -194
  133. triton/backends/amd/include/hsa/Brig.h +0 -1131
  134. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  135. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  136. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  137. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  138. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  139. triton/backends/amd/include/hsa/hsa.h +0 -5738
  140. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  141. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  142. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  143. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  144. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  145. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  146. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  147. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  148. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  149. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  150. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  151. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  152. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  153. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  154. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  155. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  156. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  157. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  158. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  159. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  160. triton/backends/amd/include/roctracer/roctx.h +0 -229
  161. triton/language/_utils.py +0 -21
  162. triton/language/extra/cuda/_experimental_tma.py +0 -106
  163. triton/tools/experimental_descriptor.py +0 -32
  164. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  165. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  166. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +0 -0
triton/experimental/gluon/language/_semantic.py
@@ -0,0 +1,287 @@
+ from typing import Sequence, List, TypeVar, Tuple, Callable
+ from triton.language.semantic import TritonSemantic
+ from . import _core as ttgl
+ from ._layouts import SliceLayout
+ from triton._C.libtriton.gluon_ir import GluonOpBuilder
+ from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values
+
+ TensorTy = TypeVar("TensorTy")
+
+
+ def _check(cond: bool, msg_fn: Callable[[], str], category=ValueError):
+     if not cond:
+         raise category(msg_fn())
+
+
+ class GluonSemantic(TritonSemantic[TensorTy]):
+     tensor = ttgl.tensor
+     lang = ttgl
+
+     builder: GluonOpBuilder
+
+     def __init__(self, builder: GluonOpBuilder):
+         self.builder = builder
+
+     def _wrap_tensor_infer_layout(self, tensor):
+         ty = ttgl.distributed_type(tensor.type.scalar, tensor.shape,
+                                    self.builder.get_gluon_layout_from_tensor(tensor.handle))
+         return self.tensor(tensor.handle, ty)
+
+     def _broadcast_shapes(self, lhs_shape: List[int], rhs_shape: List[int]):
+         if len(lhs_shape) != len(rhs_shape):
+             raise ValueError(f"Cannot broadcast, rank mismatch: {lhs_shape}, {rhs_shape}")
+
+         ret_shape = []
+         for i, left in enumerate(lhs_shape):
+             right = rhs_shape[i]
+             if left == 1:
+                 ret_shape.append(right)
+             elif (right == 1) or (right == left):
+                 ret_shape.append(left)
+             else:
+                 raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
+                                  "at index " + str(i) + ": " + str(left) + " and " + str(right))
+         return ret_shape
+
+     def expand_dims(self, input: TensorTy, axis: int) -> TensorTy:
+         dst_shape = [ttgl._unwrap_if_constexpr(x) for x in input.shape]
+         dst_shape.insert(axis, 1)
+
+         if axis < 0:
+             axis += len(input.shape)
+
+         _check(isinstance(input.type, ttgl.distributed_type),
+                lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}")
+         layout = input.type.layout
+         _check(isinstance(layout, SliceLayout),
+                lambda: f"expected expand_dims input to have a SliceLayout, but got: {layout}")
+         _check(layout.dim == axis,
+                lambda: f"expected expand_dims input layout to be sliced in axis {axis} but got {layout.dim}")
+
+         ret_ty = ttgl.distributed_type(input.type.scalar, dst_shape, layout.parent)
+         handle = self.builder.create_expand_dims(input.handle, axis, ret_ty.to_ir(self.builder))
+         return self.tensor(handle, ret_ty)
+
+     def join(self, a: TensorTy, b: TensorTy) -> TensorTy:
+         a, b = self.broadcast_impl_value(a, b)
+         _check(a.shape != [], "Cannot join scalars in gluon")
+         value = super().join(a, b)
+         return self._wrap_tensor_infer_layout(value)
+
+     def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]:
+         lhs, rhs = super().split(a)
+         return self._wrap_tensor_infer_layout(lhs), self._wrap_tensor_infer_layout(rhs)
+
+     def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy:
+         value = super().permute(input, dims)
+         return self._wrap_tensor_infer_layout(value)
+
+     def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy:
+         _check(isinstance(input.type, ttgl.distributed_type),
+                lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}")
+         src_shape = input.type.get_block_shapes()
+         _check(len(src_shape) == len(shape), lambda: f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
+         if shape == src_shape:
+             return input
+         for i, item in enumerate(src_shape):
+             if shape[i] != item and item != 1:
+                 raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
+                                  f" must match the existing size ({item}) at non-singleton dimension"
+                                  f" {i}: {src_shape}, {shape}")
+         ret_ty = ttgl.distributed_type(input.type.scalar, shape, input.type.layout)
+         handle = self.builder.create_broadcast(input.handle, ret_ty.to_ir(self.builder))
+         return self.tensor(handle, ret_ty)
+
+     def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy:
+         lhs_ty = lhs.type
+         rhs_ty = rhs.type
+
+         if not lhs_ty.is_block() or not rhs_ty.is_block():
+             return super().broadcast_impl_value(lhs, rhs)
+
+         _check(isinstance(lhs_ty, ttgl.distributed_type),
+                lambda: f"expected broadcast left input to be a distributed_type but got: {lhs_ty!r}")
+         _check(isinstance(rhs_ty, ttgl.distributed_type),
+                lambda: f"expected broadcast right input to be a distributed_type but got: {rhs_ty!r}")
+
+         lhs_shape = lhs_ty.get_block_shapes()
+         rhs_shape = rhs_ty.get_block_shapes()
+         ret_shape = self._broadcast_shapes(lhs_shape, rhs_shape)
+         if lhs_ty.layout != rhs_ty.layout:
+             raise ValueError(f"Layout mismatch in broadcast: {lhs_ty.layout} vs {rhs_ty.layout}")
+
+         lhs = self.broadcast_impl_shape(lhs, ret_shape)
+         rhs = self.broadcast_impl_shape(rhs, ret_shape)
+         return lhs, rhs
+
+     def arange(self, start, end, layout):
+         shape = [end - start]
+         ret_ty = ttgl.distributed_type(ttgl.int32, shape, layout)
+         return super().arange(start, end, ret_ty=ret_ty)
+
+     def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool):
+         _check(not can_reorder, "can_reorder is not supported in gluon")
+         value = super().reshape(input, dst_shape, can_reorder)
+         return self._wrap_tensor_infer_layout(value)
+
+     def splat(self, value, shape, layout):
+         ret_ty = ttgl.distributed_type(value.dtype, shape, layout)
+         handle = self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle)
+         return ttgl.tensor(handle, ret_ty)
+
+     def full(self, shape, value, dtype, layout):
+         scalar = self.make_scalar(value, dtype)
+         return self.splat(scalar, shape, layout)
+
+     def convert_layout(self, value, layout):
+         ty = value.type
+         _check(isinstance(ty, ttgl.distributed_type),
+                lambda: f"expected convert_layout input to be a distributed_type but got: {ty!r}")
+         ret_ty = ttgl.distributed_type(ty.element_ty, ty.shape, layout)
+         handle = self.builder.create_convert_layout(ret_ty.to_ir(self.builder), value.handle)
+         return ttgl.tensor(handle, ret_ty)
+
+     def allocate_shared(self, element_ty, shape, layout, value):
+         ty = ttgl.shared_memory_descriptor_type(element_ty, shape, layout, shape)
+         if value is not None:
+             handle = self.builder.create_local_alloc(ty.to_ir(self.builder), value.handle)
+         else:
+             handle = self.builder.create_local_alloc(ty.to_ir(self.builder))
+         return ttgl.shared_memory_descriptor(handle, element_ty, shape, layout, shape)
+
+     def shared_load(self, mem_desc, layout):
+         ret_ty = ttgl.distributed_type(mem_desc.dtype, mem_desc.shape, layout)
+         handle = self.builder.create_local_load(ret_ty.to_ir(self.builder), mem_desc.handle)
+         return ttgl.tensor(handle, ret_ty)
+
+     def shared_store(self, mem_desc, value):
+         self.builder.create_local_store(mem_desc.handle, value.handle)
+
+     def shared_dealloc(self, mem_desc):
+         self.builder.create_local_dealloc(mem_desc.handle)
+
+     def _memdesc_subview(self, mem_desc, offsets, shape):
+         layout = mem_desc.layout
+         ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
+         builder = self.builder
+         handle = builder.create_memdesc_subview(ty.to_ir(builder), mem_desc.handle, offsets)
+         return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
+
+     def memdesc_slice(self, mem_desc, start, length, dim):
+         offsets = [self.builder.get_int32(0)] * mem_desc.rank
+         offsets[dim] = self.to_tensor(start).handle
+         shape = list(mem_desc.shape)
+         shape[dim] = length
+         return self._memdesc_subview(mem_desc, offsets, shape)
+
+     def memdesc_index(self, mem_desc, index):
+         shape = mem_desc.shape[1:]
+         offsets = [self.builder.get_int32(0)] * mem_desc.rank
+         offsets[0] = self.to_tensor(index).handle
+         return self._memdesc_subview(mem_desc, offsets, shape)
+
+     def memdesc_trans(self, mem_desc, order):
+         assert len(order) == len(
+             mem_desc.shape), f"source rank ({mem_desc.rank}) and order length ({len(order)}) must match"
+
+         shape = [mem_desc.shape[i] for i in order]
+         alloc_shape = mem_desc.type.alloc_shape
+         new_alloc_shape = alloc_shape[:len(alloc_shape) - mem_desc.rank]
+         new_alloc_shape += [alloc_shape[len(alloc_shape) - mem_desc.rank:][i] for i in order]
+
+         handle = self.builder.create_memdesc_trans(mem_desc.handle, order)
+         layout = self.builder.get_gluon_layout_from_memdesc(handle)
+         return ttgl.shared_memory_descriptor(handle, element_ty=mem_desc.dtype, shape=shape,
+                                              alloc_shape=new_alloc_shape, layout=layout)
+
+     def memdesc_reshape(self, mem_desc, shape, layout):
+         ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
+         handle = self.builder.create_memdesc_reshape(ty.to_ir(self.builder), mem_desc.handle)
+         return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
+
+     def memdesc_reinterpret(self, mem_desc, dtype, shape, layout):
+         ty = ttgl.shared_memory_descriptor_type(dtype, shape, layout, shape)
+         handle = self.builder.create_memdesc_reinterpret(ty.to_ir(self.builder), mem_desc.handle)
+         return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
+
+     def wrap_tensor(self, x, scalar_ty, ret_shape, layout):
+         if ret_shape:
+             res_ty = ttgl.distributed_type(scalar_ty, ret_shape, layout)
+         else:
+             res_ty = scalar_ty
+         return self.tensor(x, res_ty)
+
+     @staticmethod
+     def _check_same_layout(xs):
+         for x in xs:
+             _check(isinstance(x.type, ttgl.distributed_type), lambda: f"expected distributed_type but got: {x.type!r}")
+         layouts = [x.type.layout for x in xs]
+         l0 = layouts[0]
+         _check(all(l == l0 for l in layouts[1:]),
+                lambda: f"Expected inputs to have matching layouts, but got: {layouts}")
+
+     def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]:
+         _check(axis is not None, lambda: "All-reduce is not yet implemented in gluon")
+         # get result shape
+         shape = inputs[0].type.shape
+         rank = len(shape)
+         _check(0 <= axis < rank, lambda: f"expected reduction axis to be in the range [0, {rank}) but got {axis}")
+         self._check_same_layout(inputs)
+         ret_shape = [s for i, s in enumerate(shape) if i != axis]
+         ret_layout = SliceLayout(axis, inputs[0].type.layout)
+         assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape"
+
+         reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis)
+         region_builder_fn(reduce_op)
+         assert reduce_op.verify()
+
+         return tuple(
+             self.wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape, ret_layout)
+             for i in range(len(inputs)))
+
+     def warp_specialize(self, args, default_partition, worker_partitions, worker_num_warps: Sequence[int],
+                         worker_num_regs: Sequence[int], generator):
+         num_partitions = len(worker_partitions)
+         assert num_partitions == len(
+             worker_num_warps
+         ), f"warp specialize got {num_partitions} partitions but {len(worker_num_warps)} warp counts"
+         assert num_partitions == len(
+             worker_num_regs
+         ), f"warp specialize got {num_partitions} partitions but {len(worker_num_regs)} register counts"
+
+         builder = self.builder
+         insert_pt = builder.get_insertion_point()
+
+         # Emit the default partition to get the result types.
+         default_block = builder.new_block()
+         builder.set_insertion_point_to_start(default_block)
+         default_results = generator.call_JitFunction(default_partition, args, kwargs={})
+         mlir_results = []
+         if default_results is not None:
+             mlir_results = flatten_values_to_ir(default_results)
+         builder.create_warp_yield(mlir_results)
+         result_types = [r.get_type() for r in mlir_results]
+
+         # Create the warp specialize op.
+         builder.restore_insertion_point(insert_pt)
+         mlir_args = flatten_values_to_ir(args)
+         ws_op = builder.create_warp_specialize(result_types, mlir_args, worker_num_warps)
+         ws_op.get_default_region().push_back(default_block)
+         ws_op.set_requested_registers(worker_num_regs)
+
+         # Emit the partition regions.
+         builder.create_block_with_parent(ws_op.get_partition_op_holder(), [])
+         partitions_op = builder.create_warp_specialize_partitions(num_partitions)
+         arg_types = [arg.get_type() for arg in mlir_args]
+         for i in range(num_partitions):
+             block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types)
+             block_args = [block.get_argument(j) for j in range(len(mlir_args))]
+             block_args = unflatten_ir_values(block_args, [arg.type for arg in args])
+             generator.call_JitFunction(worker_partitions[i], block_args, kwargs={})
+             builder.create_warp_return()
+
+         builder.set_insertion_point_after(ws_op.get_operation())
+         mlir_results = [ws_op.get_result(i) for i in range(len(result_types))]
+         if default_results is None:
+             return
+         return tuple(unflatten_ir_values(mlir_results, [r.type for r in default_results]))
triton/experimental/gluon/language/_standard.py
@@ -0,0 +1,47 @@
+ # flake8: noqa
+ import triton
+ import triton.language.standard as tl_standard
+ from .._runtime import jit
+ from triton import knobs
+ from . import _core as ttgl
+
+ _IMPORT_FROM_TRITON = [
+     "sum",
+     "max",
+     "min",
+     "reduce_or",
+     "xor_sum",
+ ]
+
+ __all__ = [
+     "full_like",
+     "zeros",
+     "zeros_like",
+     *_IMPORT_FROM_TRITON,
+ ]
+
+ for name in _IMPORT_FROM_TRITON:
+     # Convert JITFunction -> GluonJitFunction
+     fn = getattr(tl_standard, name)
+     assert knobs.runtime.interpret or isinstance(fn, triton.runtime.JITFunction)
+     globals()[name] = jit(fn.fn)
+
+
+ @jit
+ def zeros(shape, dtype, layout):
+     return ttgl.full(shape, 0, dtype, layout)
+
+
+ @jit
+ def full_like(input, value, shape=None, dtype=None, layout=None):
+     return ttgl.full(
+         input.shape if shape is None else shape,
+         value,
+         input.dtype if dtype is None else dtype,
+         input.type.layout if layout is None else layout,
+     )
+
+
+ @jit
+ def zeros_like(input, shape=None, dtype=None, layout=None):
+     return full_like(input, 0, shape=shape, dtype=dtype, layout=layout)
triton/experimental/gluon/language/nvidia/__init__.py
@@ -0,0 +1,4 @@
+ from . import blackwell
+ from . import hopper
+
+ __all__ = ["blackwell", "hopper"]
triton/experimental/gluon/language/nvidia/blackwell/__init__.py
@@ -0,0 +1,202 @@
+ from __future__ import annotations
+ from typing import Optional, Tuple, List, TYPE_CHECKING
+
+ from dataclasses import dataclass
+ from triton.experimental.gluon.language import _core as ttgl
+ from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr
+ from triton.experimental.gluon.language._semantic import _check
+
+ from . import tma
+ from ..hopper import mbarrier, fence_async_shared
+
+ if TYPE_CHECKING:
+     from triton._C.libtriton.gluon_ir import GluonOpBuilder
+     from triton._C.libtriton import gluon_ir as ir
+     from ..._semantic import GluonSemantic
+
+ __all__ = [
+     "allocate_tensor_memory",
+     "fence_async_shared",
+     "mbarrier",
+     "tensor_memory_descriptor",
+     "TensorMemoryLayout",
+     "tma",
+ ]
+
+
+ @dataclass(frozen=True, eq=True)
+ class TensorMemoryLayout:
+     block: Tuple[int, int]
+     unpacked: bool
+     cta_split_num: Optional[Tuple[int, int]] = None
+
+     def __post_init__(self):
+         assert len(self.block) == 2
+         assert self.cta_split_num is None or len(self.cta_split_num) == 2
+
+     def _to_ir(self, builder):
+         cta_split_num = self.cta_split_num or [1, 1]
+         return builder.get_tensor_memory_layout(
+             self.block,
+             self.unpacked,
+             cta_split_num,
+         )
+
+     def mangle(self) -> str:
+         block_str = f"{self.block[0]}x{self.block[1]}"
+         unpacked_str = "U" if self.unpacked else "P"
+         cta_split_str = f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else ""
+         return f"TL{block_str}{unpacked_str}{cta_split_str}TL"
+
+
+ class tensor_memory_descriptor_type(base_type):
+
+     def __init__(self, element_ty, shape, layout, alloc_shape):
+         self.element_ty = element_ty
+         self.shape = shape
+         self.layout = layout
+         self.alloc_shape = alloc_shape
+         assert isinstance(layout, TensorMemoryLayout)
+
+     def to_ir(self, builder: GluonOpBuilder) -> None:
+         return builder.get_tensor_mem_desc_ty(
+             self.element_ty.to_ir(builder),
+             self.shape,
+             self.layout._to_ir(builder),
+             self.alloc_shape,
+         )
+
+     def _unflatten_ir(self, handles: List[ir.Value], cursor: int) -> Tuple[tensor_memory_descriptor, int]:
+         value = tensor_memory_descriptor(handles[cursor], self.element_ty, self.shape, self.layout, self.alloc_shape)
+         return value, cursor + 1
+
+     def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None:
+         out.append(self.to_ir(builder))
+
+     def __str__(self) -> str:
+         return f"tensor_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}>"
+
+     def __eq__(self, other) -> bool:
+         return (type(self) is type(other) and self.shape == other.shape and self.layout == other.layout
+                 and self.alloc_shape == other.alloc_shape)
+
+     def __neq__(self, other) -> bool:
+         return not (self == other)
+
+     def mangle(self) -> str:
+         shape_str = "_".join([str(s) for s in self.shape])
+         return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{self.alloc_shape}ASMD"
+
+
+ class tensor_memory_descriptor(base_value):
+
+     def __init__(self, handle, element_ty, shape, layout, alloc_shape):
+         self.handle = handle
+         self.type = tensor_memory_descriptor_type(element_ty, shape, layout, alloc_shape)
+
+     def _flatten_ir(self, handles: List[ir.value]) -> None:
+         handles.append(self.handle)
+
+     @property
+     def dtype(self):
+         return self.type.element_ty
+
+     @property
+     def shape(self):
+         return self.type.shape
+
+     @property
+     def rank(self):
+         return len(self.shape)
+
+     @property
+     def layout(self):
+         return self.type.layout
+
+     def __str__(self) -> str:
+         return str(self.type)
+
+     @builtin
+     def load(self, layout, _semantic: GluonSemantic) -> ttgl.tensor:
+         layout = _unwrap_if_constexpr(layout)
+         ret_ty = ttgl.distributed_type(self.dtype, self.shape, layout)
+         builder = _semantic.builder
+         handle = builder.create_tmem_load(ret_ty.to_ir(builder), self.handle)
+         return ttgl.tensor(handle, ret_ty)
+
+     @builtin
+     def store(self, value, pred=True, _semantic: GluonSemantic = None) -> None:
+         pred = _unwrap_if_constexpr(pred)
+         pred = _semantic.to_tensor(pred)
+         _semantic.builder.create_tmem_store(self.handle, value.handle, pred.handle)
+
+     @builtin
+     def slice(self, start, length, _semantic: GluonSemantic) -> None:
+         start = _unwrap_if_constexpr(start)
+         length = _unwrap_if_constexpr(length)
+         _check(isinstance(start, int), lambda: "start must be a constant int")
+         _check(isinstance(length, int), lambda: "length must be a constant int")
+         shape = self.shape[:-1] + [length]
+         layout = self.type.layout
+         layout = TensorMemoryLayout((layout.block[0], min(layout.block[1], length)), layout.unpacked,
+                                     layout.cta_split_num)
+         ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape)
+         builder = _semantic.builder
+         ret.handle = builder.create_tmem_subslice(ret.type.to_ir(builder), self.handle, start)
+         return ret
+
+     @builtin
+     def index(self, index, _semantic: GluonSemantic = None) -> tensor_memory_descriptor:
+         index = _semantic.to_tensor(index)
+         builder = _semantic.builder
+         offsets = [builder.get_int32(0)] * self.rank
+         offsets[0] = index.handle
+         shape = self.shape[1:]
+         layout = self.layout
+         ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape)
+         ret.handle = builder.create_memdesc_subview(ret.type.to_ir(builder), self.handle, offsets)
+         return ret
+
+     @builtin
+     def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> tensor_memory_descriptor:
+         dtype = _unwrap_if_constexpr(dtype)
+         shape = [_unwrap_if_constexpr(s) for s in shape]
+         layout = _unwrap_if_constexpr(layout)
+
+         ty = tensor_memory_descriptor_type(dtype, shape, layout, shape)
+         handle = _semantic.builder.create_memdesc_reinterpret(ty.to_ir(_semantic.builder), self.handle)
+         return tensor_memory_descriptor(handle, **ty.__dict__)
+
+
+ @builtin
+ def allocate_tensor_memory(element_ty, shape, layout, value=None, _semantic=None):
+     element_ty = _unwrap_if_constexpr(element_ty)
+     shape = _unwrap_if_constexpr(shape)
+     layout = _unwrap_if_constexpr(layout)
+     value = value.handle if value is not None else None
+
+     ty = tensor_memory_descriptor_type(element_ty, shape, layout, shape)
+     builder = _semantic.builder
+     handle = builder.create_tmem_alloc(ty.to_ir(builder), value)
+     return tensor_memory_descriptor(handle, element_ty, shape, layout, shape)
+
+
+ @builtin
+ def tcgen05_mma(a, b, acc, *, use_acc=True, pred=True, mbarriers=None, mbarrier_preds=None, _semantic=None):
+     use_acc = _semantic.to_tensor(use_acc)
+     pred = _semantic.to_tensor(pred)
+
+     if mbarriers is None:
+         assert mbarrier_preds is None
+         mbarriers = []
+         mbarrier_preds = []
+     else:
+         mbarriers = [bar.handle for bar in mbarriers]
+         if mbarrier_preds is None:
+             true = _semantic.to_tensor(True)
+             mbarrier_preds = [true] * len(mbarriers)
+         else:
+             mbarrier_preds = _semantic._convert_to_ir_values(mbarrier_preds, require_i64=False)
+
+     _semantic.builder.create_tcgen05_mma(a.handle, b.handle, acc.handle, use_acc.handle, pred.handle, mbarriers,
+                                          mbarrier_preds)
triton/experimental/gluon/language/nvidia/blackwell/tma.py
@@ -0,0 +1,32 @@
+ from triton.experimental.gluon.language._core import builtin
+ from triton.experimental.gluon.language.nvidia.hopper.tma import (
+     async_copy_global_to_shared,
+     async_copy_shared_to_global,
+     store_wait,
+     tensor_descriptor,
+     tensor_descriptor_type,
+ )
+
+ __all__ = [
+     "async_gather",
+     "async_scatter",
+     "async_copy_global_to_shared",
+     "async_copy_shared_to_global",
+     "store_wait",
+     "tensor_descriptor",
+     "tensor_descriptor_type",
+ ]
+
+
+ @builtin
+ def async_gather(tensor_desc, x_offsets, y_offset, barrier, result, pred=True, _semantic=None):
+     pred = _semantic.to_tensor(pred)
+     y_offset = _semantic.to_tensor(y_offset)
+     _semantic.builder.create_async_tma_gather(tensor_desc.handle, x_offsets.handle, y_offset.handle, barrier.handle,
+                                               result.handle, pred.handle)
+
+
+ @builtin
+ def async_scatter(tensor_desc, x_offsets, y_offset, src, _semantic=None):
+     y_offset = _semantic.to_tensor(y_offset)
+     _semantic.builder.create_async_tma_scatter(tensor_desc.handle, x_offsets.handle, y_offset.handle, src.handle)
triton/experimental/gluon/language/nvidia/hopper/__init__.py
@@ -0,0 +1,11 @@
+ from . import mbarrier
+ from . import tma
+ from ... import _core
+
+ __all__ = ["fence_async_shared", "mbarrier", "tma"]
+
+
+ @_core.builtin
+ def fence_async_shared(cluster=False, _semantic=None):
+     cluster = _core._unwrap_if_constexpr(cluster)
+     _semantic.builder.create_fence_async_shared(cluster)
triton/experimental/gluon/language/nvidia/hopper/mbarrier.py
@@ -0,0 +1,51 @@
+ from triton.experimental.gluon.language._layouts import SwizzledSharedLayout
+ from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
+
+ __all__ = ["MBarrierLayout", "init", "invalidate", "expect", "wait", "arrive"]
+
+
+ class MBarrierLayout(SwizzledSharedLayout):
+
+     def __init__(self, ctas_per_cga: int = 1, cta_split_num: int = 1):
+         super().__init__(
+             vec=1,
+             per_phase=1,
+             max_phase=1,
+             order=[0],
+             ctas_per_cga=[ctas_per_cga],
+             cta_split_num=[cta_split_num],
+             cta_order=[0],
+         )
+
+
+ @builtin
+ def init(mbarrier, count, _semantic=None):
+     count = _unwrap_if_constexpr(count)
+     _semantic.builder.create_mbarrier_init(mbarrier.handle, count)
+
+
+ @builtin
+ def invalidate(mbarrier, _semantic=None):
+     _semantic.builder.create_mbarrier_inval(mbarrier.handle)
+
+
+ @builtin
+ def expect(mbarrier, bytes, pred=True, _semantic=None):
+     bytes = _unwrap_if_constexpr(bytes)
+     pred = _semantic.to_tensor(pred)
+     _semantic.builder.create_mbarrier_expect(mbarrier.handle, bytes, pred.handle)
+
+
+ @builtin
+ def wait(mbarrier, phase, pred=True, deps=(), _semantic=None):
+     phase = _semantic.to_tensor(phase)
+     pred = _semantic.to_tensor(pred)
+     deps = [x.handle for x in deps]
+     _semantic.builder.create_mbarrier_wait(mbarrier.handle, phase.handle, pred.handle, deps)
+
+
+ @builtin
+ def arrive(mbarrier, count, pred=True, _semantic=None):
+     count = _unwrap_if_constexpr(count)
+     pred = _semantic.to_tensor(pred)
+     _semantic.builder.create_mbarrier_arrive(mbarrier.handle, count, pred.handle)