triton-windows 3.3.0.post19__cp39-cp39-win_amd64.whl → 3.4.0.post20__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release has been flagged as potentially problematic. Review the advisory details for this version of triton-windows before upgrading.

Files changed (173)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +149 -47
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +92 -93
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +303 -128
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +76 -12
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/runtime/tcc/lib/python310.def +1610 -0
  56. triton/runtime/tcc/lib/python311.def +1633 -0
  57. triton/runtime/tcc/lib/python312.def +1703 -0
  58. triton/runtime/tcc/lib/python313.def +1651 -0
  59. triton/runtime/tcc/lib/python313t.def +1656 -0
  60. triton/runtime/tcc/lib/python39.def +1644 -0
  61. triton/runtime/tcc/lib/python3t.def +905 -0
  62. triton/testing.py +16 -12
  63. triton/tools/disasm.py +3 -4
  64. triton/tools/tensor_descriptor.py +36 -0
  65. triton/windows_utils.py +14 -6
  66. {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
  67. triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
  68. {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +1 -1
  69. triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
  70. triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
  71. triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
  72. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  73. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  74. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  75. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  76. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  77. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  78. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  79. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  80. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  81. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  82. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  83. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  84. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  85. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  86. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  87. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  88. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  89. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  90. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  91. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  92. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  93. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  94. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  95. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  96. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  97. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  98. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  99. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  100. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  101. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  102. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  103. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  104. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  105. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  106. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  107. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  108. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  109. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  110. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  111. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  112. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  113. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  114. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  115. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  116. triton/backends/amd/include/hip/device_functions.h +0 -38
  117. triton/backends/amd/include/hip/driver_types.h +0 -468
  118. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  119. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  120. triton/backends/amd/include/hip/hip_common.h +0 -100
  121. triton/backends/amd/include/hip/hip_complex.h +0 -38
  122. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  123. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  124. triton/backends/amd/include/hip/hip_ext.h +0 -161
  125. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  126. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  127. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  128. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  129. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  130. triton/backends/amd/include/hip/hip_profile.h +0 -27
  131. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  132. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  133. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  134. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  135. triton/backends/amd/include/hip/hip_version.h +0 -17
  136. triton/backends/amd/include/hip/hiprtc.h +0 -421
  137. triton/backends/amd/include/hip/library_types.h +0 -78
  138. triton/backends/amd/include/hip/math_functions.h +0 -42
  139. triton/backends/amd/include/hip/surface_types.h +0 -63
  140. triton/backends/amd/include/hip/texture_types.h +0 -194
  141. triton/backends/amd/include/hsa/Brig.h +0 -1131
  142. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  143. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  144. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  145. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  146. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  147. triton/backends/amd/include/hsa/hsa.h +0 -5738
  148. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  149. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  150. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  151. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  152. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  153. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  154. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  155. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  156. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  157. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  158. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  159. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  160. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  161. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  162. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  163. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  164. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  165. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  166. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  167. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  168. triton/backends/amd/include/roctracer/roctx.h +0 -229
  169. triton/language/_utils.py +0 -21
  170. triton/language/extra/cuda/_experimental_tma.py +0 -106
  171. triton/tools/experimental_descriptor.py +0 -32
  172. triton_windows-3.3.0.post19.dist-info/RECORD +0 -253
  173. triton_windows-3.3.0.post19.dist-info/top_level.txt +0 -14
@@ -0,0 +1,312 @@
1
+ from __future__ import annotations
2
+ from typing import TypeVar, List, TYPE_CHECKING, Tuple
3
+ from functools import wraps
4
+
5
+ if TYPE_CHECKING:
6
+ from triton._C.libtriton.gluon_ir import GluonOpBuilder
7
+ from ._semantic import GluonSemantic
8
+
9
+ from ._layouts import SharedLayout, DistributedLayout
10
+ from triton._C.libtriton import ir
11
+ import triton.language.core as tl_core
12
+ from triton.language.core import (
13
+ constexpr,
14
+ base_value,
15
+ base_type,
16
+ dtype,
17
+ block_type, # TODO: block type with layout info
18
+ pointer_type,
19
+ void,
20
+ int1,
21
+ int8,
22
+ int16,
23
+ int32,
24
+ int64,
25
+ uint8,
26
+ uint16,
27
+ uint32,
28
+ uint64,
29
+ float8e5,
30
+ float8e5b16,
31
+ float8e4nv,
32
+ float8e4b8,
33
+ float8e4b15,
34
+ float16,
35
+ bfloat16,
36
+ float32,
37
+ float64,
38
+ _unwrap_if_constexpr,
39
+ _unwrap_shape,
40
+ tensor,
41
+ tuple,
42
+ tuple_type,
43
+ )
44
+
45
# Names re-exported from triton.language and wrapped as Gluon builtins at the
# bottom of this module.
_IMPORT_FROM_TRITON: List[str] = [
    "expand_dims",
    "join",
    "load",
    "maximum",
    "minimum",
    "permute",
    "program_id",
    "reduce",
    "reshape",
    "split",
    "static_assert",
    "static_print",
    "store",
    "to_tensor",
    "where",
    "inline_asm_elementwise",
]

# Public API of this module.
# FIX: the original list contained duplicate entries ("float8e4b8" and
# "tuple_type" each appeared twice). Duplicates in __all__ are harmless at
# runtime but trip linters and obscure the exported set; the set of exported
# names is unchanged.
__all__ = [
    "constexpr",
    "base_value",
    "base_type",
    "dtype",
    "block_type",
    "pointer_type",
    "tuple_type",
    "void",
    "int1",
    "int8",
    "int16",
    "int32",
    "int64",
    "uint8",
    "uint16",
    "uint32",
    "uint64",
    "float8e5",
    "float8e5b16",
    "float8e4nv",
    "float8e4b8",
    "float8e4b15",
    "float16",
    "bfloat16",
    "float32",
    "float64",
    "_unwrap_if_constexpr",
    "tensor",
    "tuple",
    "thread_barrier",
    "arange",
    "full",
    "convert_layout",
    "allocate_shared_memory",
    "shared_memory_descriptor",
    "warp_specialize",
    *_IMPORT_FROM_TRITON,
]
105
+
106
T = TypeVar("T")

# TODO: split these
GLUON_BUILTIN = "__triton_builtin__"


class distributed_type(block_type):
    """Block (tensor) type that additionally carries a distributed layout."""

    def __init__(self, element_ty: dtype, shape: List[int], layout):
        super().__init__(element_ty, shape)
        self.layout = layout
        self.name = f"<{self.shape}, {self.element_ty}, {self.layout}>"
        assert isinstance(layout, DistributedLayout)

    def to_ir(self, builder: ir.builder) -> ir.type:
        """Lower to the IR distributed tensor type."""
        ir_elem = self.element_ty.to_ir(builder)
        ir_layout = self.layout._to_ir(builder)
        return builder.get_distributed_ty(ir_elem, self.shape, ir_layout)

    def mangle(self) -> str:
        """Encode element type, shape and layout into a mangled name string."""
        dims = "_".join(str(d) for d in self.shape)
        return f"{self.scalar.mangle()}S{dims}SL{self.layout.mangle()}L"

    def with_element_ty(self, scalar_ty: dtype) -> block_type:
        """Return a copy of this type with a different scalar element type."""
        return distributed_type(scalar_ty, self.shape, self.layout)
133
+
134
+
135
def builtin(fn: T) -> T:
    """Mark a function as a Gluon builtin.

    The returned wrapper requires the ``_semantic`` keyword argument (supplied
    by the JIT frontend); calling a builtin without it raises ``ValueError``.
    """
    assert callable(fn)

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # Treat both "missing" and "explicitly None" the same way.
        if kwargs.get("_semantic") is None:
            raise ValueError("Did you forget to add @triton.gluon.jit ? "
                             "(`_semantic` argument must be provided outside of JIT functions.)")
        return fn(*args, **kwargs)

    setattr(wrapper, GLUON_BUILTIN, True)

    return wrapper
149
+
150
+
151
class shared_memory_descriptor_type(base_type):
    """Type of a shared-memory descriptor.

    Records the element type, the logical shape of the view, the shared-memory
    layout, and the shape of the underlying allocation.
    """

    def __init__(self, element_ty, shape, layout, alloc_shape):
        self.element_ty = element_ty
        self.shape = shape
        self.layout = layout
        self.alloc_shape = alloc_shape
        assert isinstance(layout, SharedLayout)

    def to_ir(self, builder: GluonOpBuilder) -> None:
        """Lower this type to its IR shared-memory-descriptor type."""
        return builder.get_shared_mem_desc_ty(
            self.element_ty.to_ir(builder),
            self.shape,
            self.layout._to_ir(builder),
            self.alloc_shape,
        )

    def _unflatten_ir(self, handles: List[ir.Value], cursor: int) -> Tuple[shared_memory_descriptor, int]:
        """Rebuild a value from the flat handle list; returns (value, advanced cursor)."""
        value = shared_memory_descriptor(handles[cursor], self.element_ty, self.shape, self.layout, self.alloc_shape)
        return value, cursor + 1

    def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None:
        out.append(self.to_ir(builder))

    def __str__(self) -> str:
        return f"shared_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}, {self.alloc_shape}>"

    def __eq__(self, other) -> bool:
        # NOTE(review): element_ty is not compared here — confirm whether two
        # descriptors differing only in element type should compare equal.
        return (type(self) is type(other) and self.shape == other.shape and self.layout == other.layout
                and self.alloc_shape == other.alloc_shape)

    # BUG FIX: this method was previously named __neq__, which Python never
    # invokes; the inequality hook is __ne__, so `!=` silently fell back to
    # the default behavior.
    def __ne__(self, other) -> bool:
        return not (self == other)

    def mangle(self) -> str:
        """Encode this type into a stable string for kernel-name mangling."""
        shape_str = "_".join([str(s) for s in self.shape])
        return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{self.alloc_shape}ASMD"
188
+
189
+
190
class shared_memory_descriptor(base_value):
    """Value wrapping a handle to a shared-memory allocation.

    The wrapped IR handle is paired with a `shared_memory_descriptor_type`
    describing the element type, logical shape, layout, and allocated shape.
    Methods decorated with @builtin are only callable from JIT'ed code (they
    require the injected `_semantic` argument).
    """

    def __init__(self, handle, element_ty, shape, layout, alloc_shape):
        # `handle` is the underlying IR value; the type captures the logical view.
        self.handle = handle
        self.type = shared_memory_descriptor_type(element_ty, shape, layout, alloc_shape)

    def _flatten_ir(self, handles: List[ir.value]) -> None:
        # Inverse of shared_memory_descriptor_type._unflatten_ir: contributes
        # the single underlying IR handle.
        handles.append(self.handle)

    @property
    def dtype(self):
        # Element type of the descriptor.
        return self.type.element_ty

    @property
    def shape(self):
        # Logical shape of this view (may differ from the allocated shape).
        return self.type.shape

    @property
    def rank(self):
        # Number of logical dimensions.
        return len(self.shape)

    @property
    def layout(self):
        # Shared-memory layout of the descriptor.
        return self.type.layout

    def __str__(self) -> str:
        return str(self.type)

    @builtin
    def load(self, layout, _semantic: GluonSemantic) -> tensor:
        """Load the shared-memory contents into a tensor with the given layout."""
        layout = _unwrap_if_constexpr(layout)
        return _semantic.shared_load(self, layout)

    @builtin
    def store(self, value, _semantic: GluonSemantic) -> None:
        """Store `value` into this shared-memory allocation."""
        return _semantic.shared_store(self, value)

    @builtin
    def slice(self, start, length, dim=0, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """Return a descriptor for the [start, start+length) sub-view along `dim`."""
        start = _unwrap_if_constexpr(start)
        length = _unwrap_if_constexpr(length)
        dim = _unwrap_if_constexpr(dim)
        return _semantic.memdesc_slice(self, start, length, dim)

    @builtin
    def index(self, index, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """Index into the descriptor (presumably along the leading dimension —
        see `_semantic.memdesc_index`; TODO confirm axis semantics)."""
        index = _unwrap_if_constexpr(index)
        return _semantic.memdesc_index(self, index)

    @builtin
    def permute(self, order, _semantic: GluonSemantic) -> shared_memory_descriptor:
        """Return a transposed view with dimensions reordered by `order`."""
        order = [_unwrap_if_constexpr(o) for o in order]
        return _semantic.memdesc_trans(self, order)

    @builtin
    def reshape(self, shape, layout, _semantic: GluonSemantic) -> shared_memory_descriptor:
        """Return a view with a new shape and layout."""
        shape = [_unwrap_if_constexpr(s) for s in shape]
        layout = _unwrap_if_constexpr(layout)

        return _semantic.memdesc_reshape(self, shape, layout)

    @builtin
    def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """Reinterpret the underlying memory with a new dtype, shape and layout."""
        dtype = _unwrap_if_constexpr(dtype)
        shape = [_unwrap_if_constexpr(s) for s in shape]
        layout = _unwrap_if_constexpr(layout)

        return _semantic.memdesc_reinterpret(self, dtype, shape, layout)

    @builtin
    def _keep_alive(self, _semantic: GluonSemantic = None) -> None:
        # NOTE(review): despite the name, this lowers via
        # `_semantic.shared_dealloc` — confirm the intended lifetime semantics.
        return _semantic.shared_dealloc(self)
262
+
263
+
264
# Re-export the selected triton.language functions as Gluon builtins by
# wrapping each with the `builtin` decorator and binding it at module level.
for name in _IMPORT_FROM_TRITON:
    fn = getattr(tl_core, name)
    globals()[name] = builtin(fn)
267
+
268
+
269
@builtin
def arange(start, end, layout, _semantic=None):
    """Return a 1-D tensor of consecutive integers in [start, end) with the given layout."""
    return _semantic.arange(_unwrap_if_constexpr(start), _unwrap_if_constexpr(end),
                            _unwrap_if_constexpr(layout))
275
+
276
+
277
@builtin
def convert_layout(value, layout, _semantic=None):
    """Convert a tensor to an equivalent one with the given distributed layout."""
    return _semantic.convert_layout(value, _unwrap_if_constexpr(layout))
281
+
282
+
283
@builtin
def full(shape, value, dtype, layout, _semantic=None):
    """Create a tensor of the given shape, dtype and layout filled with `value`."""
    return _semantic.full(_unwrap_shape(shape), _unwrap_if_constexpr(value),
                          _unwrap_if_constexpr(dtype), _unwrap_if_constexpr(layout))
290
+
291
+
292
@builtin
def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None):
    """Allocate shared memory for `shape` elements of `element_ty` with `layout`,
    optionally initialized from `value`."""
    element_ty = _unwrap_if_constexpr(element_ty)
    # Unwrap the shape container first, then each dimension.
    dims = [_unwrap_if_constexpr(s) for s in _unwrap_if_constexpr(shape)]
    return _semantic.allocate_shared(element_ty, dims, _unwrap_if_constexpr(layout), value)
299
+
300
+
301
@builtin
def warp_specialize(args, default_partition, worker_partitions, worker_num_warps, worker_num_regs,
                    _semantic=None, _generator=None):
    """Dispatch `default_partition` plus the worker partitions, each worker
    running with its own warp count and register budget."""
    num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps]
    num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]
    return _semantic.warp_specialize(args, default_partition, worker_partitions,
                                     num_warps, num_regs, _generator)
308
+
309
+
310
@builtin
def thread_barrier(_semantic=None):
    """Insert a barrier synchronizing threads (lowers via `_semantic.debug_barrier`)."""
    return _semantic.debug_barrier()
@@ -0,0 +1,230 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional
3
+ from triton.language.core import _unwrap_if_constexpr, _unwrap_shape
4
+
5
# Public layout classes exported by this module.
__all__ = [
    "BlockedLayout",
    "SliceLayout",
    "DistributedLinearLayout",
    "NVMMASharedLayout",
    "SwizzledSharedLayout",
]
12
+
13
+
14
+ def _realize_cta_layout(rank, ctas_per_cga, cta_split_num, cta_order):
15
+ ctas_per_cga = ctas_per_cga or [1] * rank
16
+ cta_split_num = cta_split_num or [1] * rank
17
+ cta_order = cta_order or list(reversed(range(rank)))
18
+ return ctas_per_cga, cta_split_num, cta_order
19
+
20
+
21
class DistributedLayout:
    """Marker base class for layouts of distributed (register) tensors."""
    pass
23
+
24
+
25
@dataclass(frozen=True)
class BlockedLayout(DistributedLayout):
    """Blocked distributed layout: per-thread / per-warp / per-CTA tiling
    factors plus an iteration order, with optional CGA-level parameters."""
    size_per_thread: List[int]
    threads_per_warp: List[int]
    warps_per_cta: List[int]
    order: List[int]
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        # The dataclass is frozen, so normalization has to bypass the frozen
        # __setattr__ via super().
        for attr in ("size_per_thread", "threads_per_warp", "warps_per_cta", "order",
                     "ctas_per_cga", "cta_split_num", "cta_order"):
            super().__setattr__(attr, _unwrap_if_constexpr(getattr(self, attr)))

        rank = len(self.size_per_thread)
        assert len(self.threads_per_warp) == rank
        assert len(self.warps_per_cta) == rank
        assert len(self.order) == rank
        assert self.ctas_per_cga is None or len(self.ctas_per_cga) == rank
        assert self.cta_split_num is None or len(self.cta_split_num) == rank
        assert self.cta_order is None or len(self.cta_order) == rank

    def _to_ir(self, builder):
        rank = len(self.size_per_thread)
        ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(rank, self.ctas_per_cga,
                                                                     self.cta_split_num, self.cta_order)
        return builder.get_blocked_layout(self.size_per_thread, self.threads_per_warp,
                                          self.warps_per_cta, self.order, ctas_per_cga,
                                          cta_split_num, cta_order)

    def mangle(self) -> str:
        """Encode all layout parameters into a 'B'-delimited mangled string."""
        parts = (self.size_per_thread, self.threads_per_warp, self.warps_per_cta,
                 self.order, self.ctas_per_cga, self.cta_split_num, self.cta_order)
        # None encodes as the empty string, lists as '_'-joined integers.
        encoded = ("" if p is None else "_".join(map(str, p)) for p in parts)
        return "B" + "B".join(encoded) + "B"
81
+
82
+
83
@dataclass(frozen=True)
class SliceLayout(DistributedLayout):
    """Layout obtained by slicing away dimension `dim` of a parent layout."""
    dim: int
    parent: DistributedLayout

    def __post_init__(self):
        # Frozen dataclass: normalize fields through super().__setattr__.
        for attr in ("dim", "parent"):
            super().__setattr__(attr, _unwrap_if_constexpr(getattr(self, attr)))

    def _to_ir(self, builder):
        return builder.get_slice_layout(self.dim, self.parent._to_ir(builder))

    def mangle(self) -> str:
        return f"SL{self.dim}_{self.parent.mangle()}SL"
100
+
101
+
102
@dataclass(frozen=True)
class DistributedLinearLayout(DistributedLayout):
    """Linear layout given explicitly by basis vectors for registers, lanes,
    warps and blocks over a tensor of the given shape."""
    reg_bases: List[List[int]]
    lane_bases: List[List[int]]
    warp_bases: List[List[int]]
    block_bases: List[List[int]]
    shape: List[int]

    def __post_init__(self):
        # Frozen dataclass: normalize fields through super().__setattr__.
        for attr in ("reg_bases", "lane_bases", "warp_bases", "block_bases", "shape"):
            super().__setattr__(attr, _unwrap_shape(getattr(self, attr)))

        rank = len(self.shape)

        # Every basis vector must have one component per tensor dimension.
        for group in (self.reg_bases, self.lane_bases, self.warp_bases, self.block_bases):
            for basis in group:
                assert len(basis) == rank

    def _to_ir(self, builder):
        return builder.get_distributed_linear_layout(self.reg_bases, self.lane_bases,
                                                     self.warp_bases, self.block_bases, self.shape)

    def mangle(self):
        return f"DLL{self.reg_bases}_{self.lane_bases}_{self.warp_bases}_{self.block_bases}_{self.shape}DLL"
134
+
135
+
136
class SharedLayout:
    """Marker base class for shared-memory layouts."""
    pass
138
+
139
+
140
@dataclass(frozen=True)
class NVMMASharedLayout(SharedLayout):
    """Shared-memory layout described by swizzle byte-width and element
    bit-width, with optional transposition, fp4 padding and CGA parameters."""
    swizzle_byte_width: int
    element_bitwidth: int
    rank: int
    transposed: bool = False
    fp4_padded: bool = False
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        # Frozen dataclass: normalize fields through super().__setattr__.
        for attr in ("swizzle_byte_width", "element_bitwidth", "rank", "transposed",
                     "fp4_padded", "ctas_per_cga", "cta_split_num", "cta_order"):
            super().__setattr__(attr, _unwrap_if_constexpr(getattr(self, attr)))

        assert self.element_bitwidth in [8, 16, 32, 64]
        assert self.swizzle_byte_width in [0, 32, 64, 128]
        rank = self.rank
        assert self.ctas_per_cga is None or len(self.ctas_per_cga) == rank
        assert self.cta_split_num is None or len(self.cta_split_num) == rank
        assert self.cta_order is None or len(self.cta_order) == rank

    def _to_ir(self, builder):
        ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(self.rank, self.ctas_per_cga,
                                                                     self.cta_split_num, self.cta_order)
        return builder.get_nvmma_shared_layout(self.swizzle_byte_width, self.element_bitwidth,
                                               self.transposed, self.fp4_padded,
                                               ctas_per_cga, cta_split_num, cta_order)

    def mangle(self) -> str:
        return f"NVMMA_{self.swizzle_byte_width}_{self.element_bitwidth}_{self.transposed}_{self.fp4_padded}_NVMMA"
183
+
184
+
185
@dataclass(frozen=True, eq=True)
class SwizzledSharedLayout(SharedLayout):
    """Swizzled shared-memory layout parameterized by vector width and
    phase counts, with optional CGA-level parameters."""
    vec: int
    per_phase: int
    max_phase: int
    order: List[int]
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        # Frozen dataclass: normalize fields through super().__setattr__.
        for attr in ("vec", "per_phase", "max_phase", "order",
                     "ctas_per_cga", "cta_split_num", "cta_order"):
            super().__setattr__(attr, _unwrap_if_constexpr(getattr(self, attr)))

        rank = len(self.order)
        assert self.ctas_per_cga is None or len(self.ctas_per_cga) == rank
        assert self.cta_split_num is None or len(self.cta_split_num) == rank
        assert self.cta_order is None or len(self.cta_order) == rank

    def _to_ir(self, builder):
        rank = len(self.order)
        ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(rank, self.ctas_per_cga,
                                                                     self.cta_split_num, self.cta_order)
        # The extra _unwrap_if_constexpr calls are kept to mirror the original
        # code exactly (they are no-ops after __post_init__ normalization).
        return builder.get_swizzled_shared_layout(
            _unwrap_if_constexpr(self.vec),
            _unwrap_if_constexpr(self.per_phase),
            _unwrap_if_constexpr(self.max_phase),
            self.order,
            ctas_per_cga,
            cta_split_num,
            cta_order,
        )

    def mangle(self) -> str:

        def stringify(x):
            return "" if x is None else "_".join(map(str, x))

        return f"SSS_{self.vec}_{self.per_phase}_{self.max_phase}_{stringify(self.order)}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_SSS"
@@ -0,0 +1,12 @@
1
# flake8: noqa
import triton.language.math as tl_math
from ._core import builtin

# Math functions re-exported from triton.language.math, each wrapped as a
# Gluon builtin (see ._core.builtin) and bound at module level.
__all__ = [
    "umulhi", "exp", "exp2", "fma", "log", "log2", "cos", "rsqrt", "sin", "sqrt", "sqrt_rn", "abs", "fdiv", "div_rn",
    "erf", "floor", "ceil"
]

for name in __all__:
    fn = getattr(tl_math, name)
    globals()[name] = builtin(fn)