triton-windows 3.3.1.post19__cp312-cp312-win_amd64.whl → 3.4.0.post20__cp312-cp312-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (166) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +149 -47
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +92 -93
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +303 -128
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +76 -12
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/testing.py +16 -12
  56. triton/tools/disasm.py +3 -4
  57. triton/tools/tensor_descriptor.py +36 -0
  58. triton/windows_utils.py +14 -6
  59. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
  60. triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
  61. triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
  62. triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
  63. triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
  64. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  65. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  66. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  67. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  68. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  69. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  70. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  71. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  72. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  73. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  74. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  75. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  76. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  77. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  78. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  79. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  80. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  81. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  82. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  83. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  84. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  85. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  86. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  87. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  88. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  89. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  90. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  91. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  92. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  93. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  94. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  95. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  96. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  97. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  98. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  99. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  100. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  101. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  102. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  103. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  104. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  105. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  106. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  107. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  108. triton/backends/amd/include/hip/device_functions.h +0 -38
  109. triton/backends/amd/include/hip/driver_types.h +0 -468
  110. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  111. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  112. triton/backends/amd/include/hip/hip_common.h +0 -100
  113. triton/backends/amd/include/hip/hip_complex.h +0 -38
  114. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  115. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  116. triton/backends/amd/include/hip/hip_ext.h +0 -161
  117. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  118. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  119. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  120. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  121. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  122. triton/backends/amd/include/hip/hip_profile.h +0 -27
  123. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  124. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  125. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  126. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  127. triton/backends/amd/include/hip/hip_version.h +0 -17
  128. triton/backends/amd/include/hip/hiprtc.h +0 -421
  129. triton/backends/amd/include/hip/library_types.h +0 -78
  130. triton/backends/amd/include/hip/math_functions.h +0 -42
  131. triton/backends/amd/include/hip/surface_types.h +0 -63
  132. triton/backends/amd/include/hip/texture_types.h +0 -194
  133. triton/backends/amd/include/hsa/Brig.h +0 -1131
  134. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  135. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  136. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  137. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  138. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  139. triton/backends/amd/include/hsa/hsa.h +0 -5738
  140. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  141. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  142. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  143. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  144. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  145. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  146. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  147. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  148. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  149. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  150. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  151. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  152. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  153. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  154. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  155. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  156. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  157. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  158. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  159. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  160. triton/backends/amd/include/roctracer/roctx.h +0 -229
  161. triton/language/_utils.py +0 -21
  162. triton/language/extra/cuda/_experimental_tma.py +0 -106
  163. triton/tools/experimental_descriptor.py +0 -32
  164. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  165. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  166. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +0 -0
@@ -0,0 +1,312 @@
1
+ from __future__ import annotations
2
+ from typing import TypeVar, List, TYPE_CHECKING, Tuple
3
+ from functools import wraps
4
+
5
+ if TYPE_CHECKING:
6
+ from triton._C.libtriton.gluon_ir import GluonOpBuilder
7
+ from ._semantic import GluonSemantic
8
+
9
+ from ._layouts import SharedLayout, DistributedLayout
10
+ from triton._C.libtriton import ir
11
+ import triton.language.core as tl_core
12
+ from triton.language.core import (
13
+ constexpr,
14
+ base_value,
15
+ base_type,
16
+ dtype,
17
+ block_type, # TODO: block type with layout info
18
+ pointer_type,
19
+ void,
20
+ int1,
21
+ int8,
22
+ int16,
23
+ int32,
24
+ int64,
25
+ uint8,
26
+ uint16,
27
+ uint32,
28
+ uint64,
29
+ float8e5,
30
+ float8e5b16,
31
+ float8e4nv,
32
+ float8e4b8,
33
+ float8e4b15,
34
+ float16,
35
+ bfloat16,
36
+ float32,
37
+ float64,
38
+ _unwrap_if_constexpr,
39
+ _unwrap_shape,
40
+ tensor,
41
+ tuple,
42
+ tuple_type,
43
+ )
44
+
45
# Names looked up on ``triton.language.core`` and re-exported from this module
# after being wrapped with ``builtin`` (see the loop further down in this file).
_IMPORT_FROM_TRITON: List[str] = [
    "expand_dims",
    "join",
    "load",
    "maximum",
    "minimum",
    "permute",
    "program_id",
    "reduce",
    "reshape",
    "split",
    "static_assert",
    "static_print",
    "store",
    "to_tensor",
    "where",
    "inline_asm_elementwise",
]

# Public API of the module. Each name appears exactly once: the original list
# repeated "float8e4b8" and "tuple_type".
__all__ = [
    "constexpr",
    "base_value",
    "base_type",
    "dtype",
    "block_type",
    "pointer_type",
    "tuple_type",
    "void",
    "int1",
    "int8",
    "int16",
    "int32",
    "int64",
    "uint8",
    "uint16",
    "uint32",
    "uint64",
    "float8e5",
    "float8e5b16",
    "float8e4nv",
    "float8e4b8",
    "float8e4b15",
    "float16",
    "bfloat16",
    "float32",
    "float64",
    "_unwrap_if_constexpr",
    "tensor",
    "tuple",
    "thread_barrier",
    "arange",
    "full",
    "convert_layout",
    "allocate_shared_memory",
    "shared_memory_descriptor",
    "warp_specialize",
    *_IMPORT_FROM_TRITON,
]

# Generic type variable used to annotate the ``builtin`` decorator.
T = TypeVar("T")

# TODO: split these
# Attribute name set to True on every function wrapped by ``builtin``.
GLUON_BUILTIN = "__triton_builtin__"
110
+
111
+
112
class distributed_type(block_type):
    """Block type that additionally records a distributed layout.

    The layout is stored on the instance, shown in ``name``, lowered in
    ``to_ir`` and included in the mangled name.
    """

    def __init__(self, element_ty: dtype, shape: List[int], layout):
        """Initialise the underlying block type, then attach ``layout``.

        ``layout`` is asserted to be a ``DistributedLayout`` only after the
        attributes have been assigned.
        """
        super().__init__(element_ty, shape)
        self.layout = layout
        self.name = f"<{self.shape}, {self.element_ty}, {self.layout}>"
        assert isinstance(layout, DistributedLayout)

    def to_ir(self, builder: ir.builder) -> ir.type:
        """Lower the element type and the layout, then ask ``builder`` for the distributed type."""
        lowered_element = self.element_ty.to_ir(builder)
        lowered_layout = self.layout._to_ir(builder)
        return builder.get_distributed_ty(lowered_element, self.shape, lowered_layout)

    def mangle(self) -> str:
        """Return ``<scalar>S<dims joined by '_'>SL<layout>L``."""
        scalar_part = self.scalar.mangle()
        dims_part = "_".join(str(dim) for dim in self.shape)
        layout_part = self.layout.mangle()
        return f"{scalar_part}S{dims_part}SL{layout_part}L"

    def with_element_ty(self, scalar_ty: dtype) -> block_type:
        """Return a new ``distributed_type`` with the same shape and layout but element type ``scalar_ty``."""
        return distributed_type(scalar_ty, self.shape, self.layout)
133
+
134
+
135
def builtin(fn: T) -> T:
    """Mark a function as a builtin.

    The returned wrapper forwards all arguments to ``fn`` unchanged, but
    raises ``ValueError`` when the ``_semantic`` keyword argument is absent or
    ``None``. The wrapper carries the ``GLUON_BUILTIN`` attribute set to True.
    """
    assert callable(fn)

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # A missing key and an explicit None are rejected alike.
        if kwargs.get("_semantic") is None:
            raise ValueError("Did you forget to add @triton.gluon.jit ? "
                             "(`_semantic` argument must be provided outside of JIT functions.)")
        return fn(*args, **kwargs)

    setattr(wrapper, GLUON_BUILTIN, True)
    return wrapper
149
+
150
+
151
class shared_memory_descriptor_type(base_type):
    """Type of a ``shared_memory_descriptor`` value.

    Records the element type, the shape, the shared layout and the allocation
    shape. All four are used when lowering to IR, when printing and when
    mangling, and all four take part in equality.
    """

    def __init__(self, element_ty, shape, layout, alloc_shape):
        """Store the four components; ``layout`` is asserted to be a ``SharedLayout``."""
        self.element_ty = element_ty
        self.shape = shape
        self.layout = layout
        self.alloc_shape = alloc_shape
        assert isinstance(layout, SharedLayout)

    def to_ir(self, builder: GluonOpBuilder) -> ir.type:
        """Return the IR type produced by ``builder.get_shared_mem_desc_ty`` for this descriptor."""
        return builder.get_shared_mem_desc_ty(
            self.element_ty.to_ir(builder),
            self.shape,
            self.layout._to_ir(builder),
            self.alloc_shape,
        )

    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[shared_memory_descriptor, int]:
        """Rebuild a descriptor value from ``handles[cursor]`` and return it with the advanced cursor."""
        value = shared_memory_descriptor(handles[cursor], self.element_ty, self.shape, self.layout, self.alloc_shape)
        return value, cursor + 1

    def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None:
        """Append this type's single IR type to ``out``."""
        out.append(self.to_ir(builder))

    def __str__(self) -> str:
        return f"shared_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}, {self.alloc_shape}>"

    def __eq__(self, other) -> bool:
        """Compare by exact class and by every component, including the element type.

        The element type is part of ``__str__``, ``mangle`` and ``to_ir``, so two
        descriptor types that differ only in element type must not compare equal.
        """
        if type(self) is not type(other):
            return False
        return (self.element_ty == other.element_ty and self.shape == other.shape and self.layout == other.layout
                and self.alloc_shape == other.alloc_shape)

    def __ne__(self, other) -> bool:
        # ``__ne__`` is the hook Python consults for ``!=``; defined explicitly so
        # inequality always mirrors ``__eq__`` above.
        return not (self == other)

    def __neq__(self, other) -> bool:
        # Kept for backward compatibility: the original class exposed this
        # (non-special) method name.
        return not (self == other)

    def mangle(self) -> str:
        """Return the mangled name built from element type, shape, layout and allocation shape."""
        shape_str = "_".join(str(s) for s in self.shape)
        return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{self.alloc_shape}ASMD"
188
+
189
+
190
class shared_memory_descriptor(base_value):
    """Value pairing an IR handle with its ``shared_memory_descriptor_type``.

    Every ``@builtin`` method forwards to the matching ``_semantic`` routine
    after unwrapping constexpr arguments.
    """

    def __init__(self, handle, element_ty, shape, layout, alloc_shape):
        """Keep ``handle`` and build the descriptor type from the remaining arguments."""
        self.handle = handle
        self.type = shared_memory_descriptor_type(element_ty, shape, layout, alloc_shape)

    def _flatten_ir(self, handles: List[ir.value]) -> None:
        """Append this value's single IR handle to ``handles``."""
        handles.append(self.handle)

    @property
    def dtype(self):
        """Element type stored on the descriptor type."""
        return self.type.element_ty

    @property
    def shape(self):
        """Shape stored on the descriptor type."""
        return self.type.shape

    @property
    def rank(self):
        """Number of entries in ``shape``."""
        return len(self.shape)

    @property
    def layout(self):
        """Layout stored on the descriptor type."""
        return self.type.layout

    def __str__(self) -> str:
        return str(self.type)

    @builtin
    def load(self, layout, _semantic: GluonSemantic) -> tensor:
        """Forward to ``_semantic.shared_load`` with the unwrapped ``layout``."""
        return _semantic.shared_load(self, _unwrap_if_constexpr(layout))

    @builtin
    def store(self, value, _semantic: GluonSemantic) -> None:
        """Forward ``value`` to ``_semantic.shared_store``."""
        return _semantic.shared_store(self, value)

    @builtin
    def slice(self, start, length, dim=0, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """Forward to ``_semantic.memdesc_slice`` with unwrapped ``start``, ``length`` and ``dim``."""
        raw_start = _unwrap_if_constexpr(start)
        raw_length = _unwrap_if_constexpr(length)
        raw_dim = _unwrap_if_constexpr(dim)
        return _semantic.memdesc_slice(self, raw_start, raw_length, raw_dim)

    @builtin
    def index(self, index, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """Forward to ``_semantic.memdesc_index`` with the unwrapped ``index``."""
        return _semantic.memdesc_index(self, _unwrap_if_constexpr(index))

    @builtin
    def permute(self, order, _semantic: GluonSemantic) -> shared_memory_descriptor:
        """Forward to ``_semantic.memdesc_trans`` with every entry of ``order`` unwrapped."""
        raw_order = list(map(_unwrap_if_constexpr, order))
        return _semantic.memdesc_trans(self, raw_order)

    @builtin
    def reshape(self, shape, layout, _semantic: GluonSemantic) -> shared_memory_descriptor:
        """Forward to ``_semantic.memdesc_reshape`` with unwrapped shape entries and layout."""
        raw_shape = list(map(_unwrap_if_constexpr, shape))
        raw_layout = _unwrap_if_constexpr(layout)
        return _semantic.memdesc_reshape(self, raw_shape, raw_layout)

    @builtin
    def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """Forward to ``_semantic.memdesc_reinterpret`` with unwrapped dtype, shape entries and layout."""
        raw_dtype = _unwrap_if_constexpr(dtype)
        raw_shape = list(map(_unwrap_if_constexpr, shape))
        raw_layout = _unwrap_if_constexpr(layout)
        return _semantic.memdesc_reinterpret(self, raw_dtype, raw_shape, raw_layout)

    @builtin
    def _keep_alive(self, _semantic: GluonSemantic = None) -> None:
        """Forward to ``_semantic.shared_dealloc``.

        NOTE(review): presumably emitting the dealloc here is what keeps the
        allocation live up to this point -- confirm against ``_semantic``.
        """
        return _semantic.shared_dealloc(self)
262
+
263
+
264
# Re-export each listed ``triton.language.core`` function under the same name,
# wrapped so that it requires a ``_semantic`` keyword argument.
# The loop variables ``name`` and ``fn`` are left as module globals, as before.
for name in _IMPORT_FROM_TRITON:
    fn = getattr(tl_core, name)
    globals().__setitem__(name, builtin(fn))
267
+
268
+
269
@builtin
def arange(start, end, layout, _semantic=None):
    """Unwrap ``start``, ``end`` and ``layout`` and forward them to ``_semantic.arange``."""
    raw_start, raw_end, raw_layout = (_unwrap_if_constexpr(arg) for arg in (start, end, layout))
    return _semantic.arange(raw_start, raw_end, raw_layout)
275
+
276
+
277
@builtin
def convert_layout(value, layout, _semantic=None):
    """Forward ``value`` and the unwrapped ``layout`` to ``_semantic.convert_layout``."""
    return _semantic.convert_layout(value, _unwrap_if_constexpr(layout))
281
+
282
+
283
@builtin
def full(shape, value, dtype, layout, _semantic=None):
    """Unwrap every argument and forward them to ``_semantic.full``.

    ``shape`` goes through ``_unwrap_shape``; the other arguments through
    ``_unwrap_if_constexpr``.
    """
    raw_shape = _unwrap_shape(shape)
    raw_value, raw_dtype, raw_layout = (_unwrap_if_constexpr(arg) for arg in (value, dtype, layout))
    return _semantic.full(raw_shape, raw_value, raw_dtype, raw_layout)
290
+
291
+
292
@builtin
def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None):
    """Unwrap the arguments and forward them to ``_semantic.allocate_shared``.

    ``shape`` is unwrapped twice: once as a whole and once per entry.
    ``value`` is passed through untouched.
    """
    raw_element_ty = _unwrap_if_constexpr(element_ty)
    raw_shape = [_unwrap_if_constexpr(dim) for dim in _unwrap_if_constexpr(shape)]
    raw_layout = _unwrap_if_constexpr(layout)
    return _semantic.allocate_shared(raw_element_ty, raw_shape, raw_layout, value)
299
+
300
+
301
@builtin
def warp_specialize(args, default_partition, worker_partitions, worker_num_warps, worker_num_regs,  #
                    _semantic=None, _generator=None):
    """Forward to ``_semantic.warp_specialize``.

    Each entry of ``worker_num_warps`` and ``worker_num_regs`` is unwrapped;
    all other arguments, including ``_generator``, are passed through as given.
    """
    raw_num_warps = list(map(_unwrap_if_constexpr, worker_num_warps))
    raw_num_regs = list(map(_unwrap_if_constexpr, worker_num_regs))
    return _semantic.warp_specialize(
        args,
        default_partition,
        worker_partitions,
        raw_num_warps,
        raw_num_regs,
        _generator,
    )
308
+
309
+
310
@builtin
def thread_barrier(_semantic=None):
    """Return the result of ``_semantic.debug_barrier()``."""
    barrier = _semantic.debug_barrier()
    return barrier
@@ -0,0 +1,230 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional
3
+ from triton.language.core import _unwrap_if_constexpr, _unwrap_shape
4
+
5
+ __all__ = [
6
+ "BlockedLayout",
7
+ "SliceLayout",
8
+ "DistributedLinearLayout",
9
+ "NVMMASharedLayout",
10
+ "SwizzledSharedLayout",
11
+ ]
12
+
13
+
14
def _realize_cta_layout(rank, ctas_per_cga, cta_split_num, cta_order):
    """Fill in defaults for the three CTA parameters.

    Any falsy argument (``None`` or an empty sequence) is replaced:
    ``ctas_per_cga`` and ``cta_split_num`` by ``[1] * rank``, and
    ``cta_order`` by ``[rank - 1, ..., 1, 0]``. Returns the three values as a
    tuple in the order given.
    """
    if not ctas_per_cga:
        ctas_per_cga = [1] * rank
    if not cta_split_num:
        cta_split_num = [1] * rank
    if not cta_order:
        cta_order = list(range(rank))[::-1]
    return ctas_per_cga, cta_split_num, cta_order
19
+
20
+
21
class DistributedLayout:
    """Marker base class with no members of its own; the distributed layouts in this module derive from it."""
23
+
24
+
25
@dataclass(frozen=True)
class BlockedLayout(DistributedLayout):
    """Frozen description of a blocked layout.

    All fields are unwrapped from constexpr on construction. The four required
    lists must have the same length; the three optional CTA lists must either
    be ``None`` or have that length too.
    """
    size_per_thread: List[int]
    threads_per_warp: List[int]
    warps_per_cta: List[int]
    order: List[int]
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        # The dataclass is frozen, so fields are rewritten through the
        # base-class ``__setattr__`` rather than by plain assignment.
        for field_name in ("size_per_thread", "threads_per_warp", "warps_per_cta", "order", "ctas_per_cga",
                           "cta_split_num", "cta_order"):
            super().__setattr__(field_name, _unwrap_if_constexpr(getattr(self, field_name)))

        rank = len(self.size_per_thread)
        for required in (self.threads_per_warp, self.warps_per_cta, self.order):
            assert len(required) == rank
        for optional in (self.ctas_per_cga, self.cta_split_num, self.cta_order):
            assert optional is None or len(optional) == rank

    def _to_ir(self, builder):
        """Return ``builder.get_blocked_layout(...)`` with the CTA parameters defaulted where unset."""
        resolved_cta = _realize_cta_layout(len(self.size_per_thread), self.ctas_per_cga, self.cta_split_num,
                                           self.cta_order)
        return builder.get_blocked_layout(
            self.size_per_thread,
            self.threads_per_warp,
            self.warps_per_cta,
            self.order,
            *resolved_cta,
        )

    def mangle(self) -> str:
        """Return the seven fields, each joined by ``_``, separated and surrounded by ``B``."""

        def stringify(x):
            # Unset optional fields contribute an empty segment.
            return "" if x is None else "_".join(map(str, x))

        segments = [
            stringify(field) for field in (
                self.size_per_thread,
                self.threads_per_warp,
                self.warps_per_cta,
                self.order,
                self.ctas_per_cga,
                self.cta_split_num,
                self.cta_order,
            )
        ]
        return "B" + "B".join(segments) + "B"
81
+
82
+
83
@dataclass(frozen=True)
class SliceLayout(DistributedLayout):
    """Frozen layout defined by a dimension index ``dim`` and a ``parent`` layout."""
    dim: int
    parent: DistributedLayout

    def __post_init__(self):
        # Frozen dataclass: rewrite fields through the base-class ``__setattr__``.
        for field_name in ("dim", "parent"):
            super().__setattr__(field_name, _unwrap_if_constexpr(getattr(self, field_name)))

    def _to_ir(self, builder):
        """Return ``builder.get_slice_layout(dim, <lowered parent>)``."""
        lowered_parent = self.parent._to_ir(builder)
        return builder.get_slice_layout(self.dim, lowered_parent)

    def mangle(self) -> str:
        """Return ``SL<dim>_<parent mangled>SL``."""
        parent_part = self.parent.mangle()
        return f"SL{self.dim}_{parent_part}SL"
100
+
101
+
102
@dataclass(frozen=True)
class DistributedLinearLayout(DistributedLayout):
    """Frozen layout given by four lists of basis vectors and a shape.

    Every basis vector must have as many entries as ``shape``.
    """
    reg_bases: List[List[int]]
    lane_bases: List[List[int]]
    warp_bases: List[List[int]]
    block_bases: List[List[int]]
    shape: List[int]

    def __post_init__(self):
        # Frozen dataclass: rewrite fields through the base-class ``__setattr__``.
        for field_name in ("reg_bases", "lane_bases", "warp_bases", "block_bases", "shape"):
            super().__setattr__(field_name, _unwrap_shape(getattr(self, field_name)))

        rank = len(self.shape)

        for bases in (self.reg_bases, self.lane_bases, self.warp_bases, self.block_bases):
            for basis in bases:
                assert len(basis) == rank

    def _to_ir(self, builder):
        """Return ``builder.get_distributed_linear_layout(...)`` for the four basis lists and the shape."""
        return builder.get_distributed_linear_layout(
            self.reg_bases,
            self.lane_bases,
            self.warp_bases,
            self.block_bases,
            self.shape,
        )

    def mangle(self):
        """Return the five fields formatted as-is, joined by ``_`` and wrapped in ``DLL`` markers."""
        body = f"{self.reg_bases}_{self.lane_bases}_{self.warp_bases}_{self.block_bases}_{self.shape}"
        return f"DLL{body}DLL"
134
+
135
+
136
class SharedLayout:
    """Marker base class with no members of its own; the shared layouts in this module derive from it."""
138
+
139
+
140
@dataclass(frozen=True)
class NVMMASharedLayout(SharedLayout):
    """Frozen description of an NVMMA shared layout.

    ``element_bitwidth`` must be one of 8, 16, 32, 64 and
    ``swizzle_byte_width`` one of 0, 32, 64, 128. The optional CTA lists must
    be ``None`` or have ``rank`` entries.
    """
    swizzle_byte_width: int
    element_bitwidth: int
    rank: int
    transposed: bool = False
    fp4_padded: bool = False
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        # Frozen dataclass: rewrite fields through the base-class ``__setattr__``.
        for field_name in ("swizzle_byte_width", "element_bitwidth", "rank", "transposed", "fp4_padded",
                           "ctas_per_cga", "cta_split_num", "cta_order"):
            super().__setattr__(field_name, _unwrap_if_constexpr(getattr(self, field_name)))

        assert self.element_bitwidth in (8, 16, 32, 64)
        assert self.swizzle_byte_width in (0, 32, 64, 128)
        expected_len = self.rank
        for optional in (self.ctas_per_cga, self.cta_split_num, self.cta_order):
            assert optional is None or len(optional) == expected_len

    def _to_ir(self, builder):
        """Return ``builder.get_nvmma_shared_layout(...)`` with the CTA parameters defaulted where unset."""
        resolved_cta = _realize_cta_layout(self.rank, self.ctas_per_cga, self.cta_split_num, self.cta_order)
        return builder.get_nvmma_shared_layout(
            self.swizzle_byte_width,
            self.element_bitwidth,
            self.transposed,
            self.fp4_padded,
            *resolved_cta,
        )

    def mangle(self) -> str:
        """Return the mangled name built from swizzle width, element bitwidth, ``transposed`` and ``fp4_padded``.

        NOTE(review): ``rank`` and the CTA fields are not part of the mangled
        name -- confirm this is intended.
        """
        core = f"{self.swizzle_byte_width}_{self.element_bitwidth}_{self.transposed}_{self.fp4_padded}"
        return f"NVMMA_{core}_NVMMA"
183
+
184
+
185
@dataclass(frozen=True, eq=True)
class SwizzledSharedLayout(SharedLayout):
    """Frozen description of a swizzled shared layout.

    The length of ``order`` defines the rank; the optional CTA lists must be
    ``None`` or have that many entries.
    """
    vec: int
    per_phase: int
    max_phase: int
    order: List[int]
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        # Frozen dataclass: rewrite fields through the base-class ``__setattr__``.
        for field_name in ("vec", "per_phase", "max_phase", "order", "ctas_per_cga", "cta_split_num", "cta_order"):
            super().__setattr__(field_name, _unwrap_if_constexpr(getattr(self, field_name)))

        rank = len(self.order)
        for optional in (self.ctas_per_cga, self.cta_split_num, self.cta_order):
            assert optional is None or len(optional) == rank

    def _to_ir(self, builder):
        """Return ``builder.get_swizzled_shared_layout(...)`` with the CTA parameters defaulted where unset."""
        resolved_cta = _realize_cta_layout(len(self.order), self.ctas_per_cga, self.cta_split_num, self.cta_order)
        # The three scalar fields are unwrapped again here, as in the original code.
        return builder.get_swizzled_shared_layout(
            _unwrap_if_constexpr(self.vec),
            _unwrap_if_constexpr(self.per_phase),
            _unwrap_if_constexpr(self.max_phase),
            self.order,
            *resolved_cta,
        )

    def mangle(self) -> str:
        """Return all seven fields joined by ``_`` between ``SSS`` markers; list fields are ``_``-joined."""

        def stringify(x):
            # Unset optional fields contribute an empty segment.
            return "" if x is None else "_".join(map(str, x))

        scalar_parts = [f"{self.vec}", f"{self.per_phase}", f"{self.max_phase}"]
        list_parts = [
            stringify(field) for field in (self.order, self.ctas_per_cga, self.cta_split_num, self.cta_order)
        ]
        return "SSS_" + "_".join(scalar_parts + list_parts) + "_SSS"
@@ -0,0 +1,12 @@
1
+ # flake8: noqa
2
+ import triton.language.math as tl_math
3
+ from ._core import builtin
4
+
5
# Public names of this module. Each is looked up on ``triton.language.math``
# below and re-exported wrapped with ``builtin``.
__all__ = [
    "umulhi",
    "exp",
    "exp2",
    "fma",
    "log",
    "log2",
    "cos",
    "rsqrt",
    "sin",
    "sqrt",
    "sqrt_rn",
    "abs",
    "fdiv",
    "div_rn",
    "erf",
    "floor",
    "ceil",
]

# The loop variables ``name`` and ``fn`` are left as module globals, as before.
for name in __all__:
    fn = getattr(tl_math, name)
    globals().__setitem__(name, builtin(fn))