triton-windows 3.3.1.post21__cp310-cp310-win_amd64.whl → 3.4.0.post21__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (68):
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +143 -46
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +94 -94
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +296 -125
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +73 -9
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/testing.py +16 -12
  56. triton/tools/disasm.py +3 -4
  57. triton/tools/tensor_descriptor.py +36 -0
  58. triton/windows_utils.py +47 -83
  59. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/METADATA +7 -2
  60. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/RECORD +64 -41
  61. triton_windows-3.4.0.post21.dist-info/entry_points.txt +3 -0
  62. triton_windows-3.4.0.post21.dist-info/licenses/LICENSE +23 -0
  63. triton_windows-3.4.0.post21.dist-info/top_level.txt +1 -0
  64. triton/language/_utils.py +0 -21
  65. triton/language/extra/cuda/_experimental_tma.py +0 -106
  66. triton/tools/experimental_descriptor.py +0 -32
  67. triton_windows-3.3.1.post21.dist-info/top_level.txt +0 -14
  68. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/WHEEL +0 -0
@@ -0,0 +1,312 @@
1
+ from __future__ import annotations
2
+ from typing import TypeVar, List, TYPE_CHECKING, Tuple
3
+ from functools import wraps
4
+
5
+ if TYPE_CHECKING:
6
+ from triton._C.libtriton.gluon_ir import GluonOpBuilder
7
+ from ._semantic import GluonSemantic
8
+
9
+ from ._layouts import SharedLayout, DistributedLayout
10
+ from triton._C.libtriton import ir
11
+ import triton.language.core as tl_core
12
+ from triton.language.core import (
13
+ constexpr,
14
+ base_value,
15
+ base_type,
16
+ dtype,
17
+ block_type, # TODO: block type with layout info
18
+ pointer_type,
19
+ void,
20
+ int1,
21
+ int8,
22
+ int16,
23
+ int32,
24
+ int64,
25
+ uint8,
26
+ uint16,
27
+ uint32,
28
+ uint64,
29
+ float8e5,
30
+ float8e5b16,
31
+ float8e4nv,
32
+ float8e4b8,
33
+ float8e4b15,
34
+ float16,
35
+ bfloat16,
36
+ float32,
37
+ float64,
38
+ _unwrap_if_constexpr,
39
+ _unwrap_shape,
40
+ tensor,
41
+ tuple,
42
+ tuple_type,
43
+ )
44
+
45
# Names of `triton.language.core` functions that this module re-exports after
# wrapping them with `builtin` (done by the loop near the end of the module).
_IMPORT_FROM_TRITON: List[str] = [
    "expand_dims",
    "join",
    "load",
    "maximum",
    "minimum",
    "permute",
    "program_id",
    "reduce",
    "reshape",
    "split",
    "static_assert",
    "static_print",
    "store",
    "to_tensor",
    "where",
    "inline_asm_elementwise",
]

# Public names of this module. Every name is listed exactly once (the previous
# list repeated "float8e4b8" and "tuple_type").
__all__ = [
    "constexpr",
    "base_value",
    "base_type",
    "dtype",
    "block_type",
    "pointer_type",
    "tuple_type",
    "void",
    "int1",
    "int8",
    "int16",
    "int32",
    "int64",
    "uint8",
    "uint16",
    "uint32",
    "uint64",
    "float8e5",
    "float8e5b16",
    "float8e4nv",
    "float8e4b8",
    "float8e4b15",
    "float16",
    "bfloat16",
    "float32",
    "float64",
    "_unwrap_if_constexpr",
    "tensor",
    "tuple",
    "thread_barrier",
    "arange",
    "full",
    "convert_layout",
    "allocate_shared_memory",
    "shared_memory_descriptor",
    "warp_specialize",
    *_IMPORT_FROM_TRITON,
]

# Generic type variable used to annotate the `builtin` decorator.
T = TypeVar("T")

# TODO: split these
# Attribute name set to True on every function wrapped by `builtin`.
GLUON_BUILTIN = "__triton_builtin__"
110
+
111
+
112
class distributed_type(block_type):
    """A `block_type` that additionally records the layout of the block.

    `layout` must be a `DistributedLayout` instance (checked with an assert).
    """

    def __init__(self, element_ty: dtype, shape: List[int], layout):
        super().__init__(element_ty, shape)
        self.layout = layout
        # Human-readable name built from shape, element type and layout.
        self.name = f"<{self.shape}, {self.element_ty}, {self.layout}>"
        assert isinstance(layout, DistributedLayout)

    def to_ir(self, builder: ir.builder) -> ir.type:
        """Lower the element type and the layout, then ask `builder` for the distributed type."""
        return builder.get_distributed_ty(
            self.element_ty.to_ir(builder),
            self.shape,
            self.layout._to_ir(builder),
        )

    def mangle(self) -> str:
        """Return the mangled name: scalar, underscore-joined shape and layout mangling."""
        scalar_part = self.scalar.mangle()
        dims = "_".join(str(extent) for extent in self.shape)
        layout_part = self.layout.mangle()
        return f"{scalar_part}S{dims}SL{layout_part}L"

    def with_element_ty(self, scalar_ty: dtype) -> block_type:
        """Return a new `distributed_type` with the same shape and layout but element type `scalar_ty`."""
        return distributed_type(element_ty=scalar_ty, shape=self.shape, layout=self.layout)
133
+
134
+
135
def builtin(fn: T) -> T:
    """Mark a function as a builtin.

    The returned wrapper forwards all arguments to `fn`, but first requires a
    non-None `_semantic` keyword argument and raises `ValueError` otherwise.
    The wrapper carries the `GLUON_BUILTIN` attribute set to True.
    """
    assert callable(fn)

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # A missing key and an explicit None are treated the same way.
        if kwargs.get("_semantic") is None:
            raise ValueError("Did you forget to add @triton.gluon.jit ? "
                             "(`_semantic` argument must be provided outside of JIT functions.)")
        return fn(*args, **kwargs)

    setattr(wrapper, GLUON_BUILTIN, True)
    return wrapper
149
+
150
+
151
class shared_memory_descriptor_type(base_type):
    """Type of a `shared_memory_descriptor` value.

    Holds the element type, the shape, the layout (which must be a
    `SharedLayout`, checked with an assert) and the allocation shape.
    """

    def __init__(self, element_ty, shape, layout, alloc_shape):
        self.element_ty = element_ty
        self.shape = shape
        self.layout = layout
        self.alloc_shape = alloc_shape
        assert isinstance(layout, SharedLayout)

    def to_ir(self, builder: GluonOpBuilder) -> ir.type:
        """Return the IR type produced by `builder.get_shared_mem_desc_ty` for this descriptor type."""
        return builder.get_shared_mem_desc_ty(
            self.element_ty.to_ir(builder),
            self.shape,
            self.layout._to_ir(builder),
            self.alloc_shape,
        )

    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[shared_memory_descriptor, int]:
        """Build a descriptor from `handles[cursor]` and return it with the advanced cursor."""
        value = shared_memory_descriptor(handles[cursor], self.element_ty, self.shape, self.layout, self.alloc_shape)
        return value, cursor + 1

    def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None:
        """Append this type's single IR type to `out`."""
        out.append(self.to_ir(builder))

    def __str__(self) -> str:
        return f"shared_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}, {self.alloc_shape}>"

    def __eq__(self, other) -> bool:
        # The element type is part of the type's identity: it appears in
        # `__str__`, in `mangle` and in the lowered IR type, so two types that
        # differ only in element type must not compare equal.
        return (type(self) is type(other) and self.element_ty == other.element_ty and self.shape == other.shape
                and self.layout == other.layout and self.alloc_shape == other.alloc_shape)

    def __ne__(self, other) -> bool:
        return not (self == other)

    # Backward-compatible alias: the method used to be defined under this name,
    # which Python never calls for the `!=` operator.
    __neq__ = __ne__

    def mangle(self) -> str:
        """Return the mangled name built from element type, shape, layout and allocation shape."""
        shape_str = "_".join(str(s) for s in self.shape)
        return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{self.alloc_shape}ASMD"
188
+
189
+
190
class shared_memory_descriptor(base_value):
    """Value made of an IR handle plus its `shared_memory_descriptor_type`.

    The builtin methods unwrap constexpr arguments and delegate to the
    corresponding method of the `_semantic` object.
    """

    def __init__(self, handle, element_ty, shape, layout, alloc_shape):
        self.handle = handle
        self.type = shared_memory_descriptor_type(
            element_ty,
            shape,
            layout,
            alloc_shape,
        )

    def _flatten_ir(self, handles: List[ir.value]) -> None:
        """Append this value's handle to `handles`."""
        handles.append(self.handle)

    @property
    def dtype(self):
        """Element type stored in the descriptor type."""
        return self.type.element_ty

    @property
    def shape(self):
        """Shape stored in the descriptor type."""
        return self.type.shape

    @property
    def rank(self):
        """Number of entries in the shape."""
        return len(self.type.shape)

    @property
    def layout(self):
        """Layout stored in the descriptor type."""
        return self.type.layout

    def __str__(self) -> str:
        return str(self.type)

    @builtin
    def load(self, layout, _semantic: GluonSemantic) -> tensor:
        """Delegate to `_semantic.shared_load` with the unwrapped `layout`."""
        return _semantic.shared_load(self, _unwrap_if_constexpr(layout))

    @builtin
    def store(self, value, _semantic: GluonSemantic) -> None:
        """Delegate to `_semantic.shared_store` with `value`."""
        return _semantic.shared_store(self, value)

    @builtin
    def slice(self, start, length, dim=0, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """Delegate to `_semantic.memdesc_slice` with unwrapped `start`, `length` and `dim`."""
        return _semantic.memdesc_slice(
            self,
            _unwrap_if_constexpr(start),
            _unwrap_if_constexpr(length),
            _unwrap_if_constexpr(dim),
        )

    @builtin
    def index(self, index, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """Delegate to `_semantic.memdesc_index` with the unwrapped `index`."""
        return _semantic.memdesc_index(self, _unwrap_if_constexpr(index))

    @builtin
    def permute(self, order, _semantic: GluonSemantic) -> shared_memory_descriptor:
        """Delegate to `_semantic.memdesc_trans` with every entry of `order` unwrapped."""
        unwrapped_order = list(map(_unwrap_if_constexpr, order))
        return _semantic.memdesc_trans(self, unwrapped_order)

    @builtin
    def reshape(self, shape, layout, _semantic: GluonSemantic) -> shared_memory_descriptor:
        """Delegate to `_semantic.memdesc_reshape` with unwrapped `shape` entries and `layout`."""
        new_shape = list(map(_unwrap_if_constexpr, shape))
        new_layout = _unwrap_if_constexpr(layout)
        return _semantic.memdesc_reshape(self, new_shape, new_layout)

    @builtin
    def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """Delegate to `_semantic.memdesc_reinterpret` with unwrapped `dtype`, `shape` entries and `layout`."""
        new_dtype = _unwrap_if_constexpr(dtype)
        new_shape = list(map(_unwrap_if_constexpr, shape))
        new_layout = _unwrap_if_constexpr(layout)
        return _semantic.memdesc_reinterpret(self, new_dtype, new_shape, new_layout)

    @builtin
    def _keep_alive(self, _semantic: GluonSemantic = None) -> None:
        """Delegate to `_semantic.shared_dealloc` for this descriptor."""
        return _semantic.shared_dealloc(self)
262
+
263
+
264
# Re-export the selected `triton.language.core` functions from this module,
# each one wrapped with `builtin`, under its original name.
for name in _IMPORT_FROM_TRITON:
    fn = getattr(tl_core, name)
    globals().update({name: builtin(fn)})
267
+
268
+
269
@builtin
def arange(start, end, layout, _semantic=None):
    """Unwrap `start`, `end` and `layout` and delegate to `_semantic.arange`."""
    return _semantic.arange(
        _unwrap_if_constexpr(start),
        _unwrap_if_constexpr(end),
        _unwrap_if_constexpr(layout),
    )
275
+
276
+
277
@builtin
def convert_layout(value, layout, _semantic=None):
    """Unwrap `layout` and delegate to `_semantic.convert_layout` with `value`."""
    return _semantic.convert_layout(value, _unwrap_if_constexpr(layout))
281
+
282
+
283
@builtin
def full(shape, value, dtype, layout, _semantic=None):
    """Unwrap `shape`, `value`, `dtype` and `layout` and delegate to `_semantic.full`."""
    return _semantic.full(
        _unwrap_shape(shape),
        _unwrap_if_constexpr(value),
        _unwrap_if_constexpr(dtype),
        _unwrap_if_constexpr(layout),
    )
290
+
291
+
292
@builtin
def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None):
    """Unwrap the arguments and delegate to `_semantic.allocate_shared`.

    `shape` is unwrapped as a whole first and then entry by entry; `value` is
    forwarded unchanged.
    """
    unwrapped_ty = _unwrap_if_constexpr(element_ty)
    dims = [_unwrap_if_constexpr(extent) for extent in _unwrap_if_constexpr(shape)]
    unwrapped_layout = _unwrap_if_constexpr(layout)
    return _semantic.allocate_shared(unwrapped_ty, dims, unwrapped_layout, value)
299
+
300
+
301
@builtin
def warp_specialize(args, default_partition, worker_partitions, worker_num_warps, worker_num_regs,  #
                    _semantic=None, _generator=None):
    """Unwrap the per-worker warp and register counts and delegate to `_semantic.warp_specialize`.

    `args`, `default_partition`, `worker_partitions` and `_generator` are
    forwarded unchanged.
    """
    num_warps = list(map(_unwrap_if_constexpr, worker_num_warps))
    num_regs = list(map(_unwrap_if_constexpr, worker_num_regs))
    return _semantic.warp_specialize(
        args,
        default_partition,
        worker_partitions,
        num_warps,
        num_regs,
        _generator,
    )
308
+
309
+
310
@builtin
def thread_barrier(_semantic=None):
    """Return the result of `_semantic.debug_barrier()`."""
    return _semantic.debug_barrier()
@@ -0,0 +1,230 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional
3
+ from triton.language.core import _unwrap_if_constexpr, _unwrap_shape
4
+
5
# Layout classes that make up this module's public API. The marker base
# classes `DistributedLayout` and `SharedLayout` are not listed here.
__all__ = [
    "BlockedLayout",
    "SliceLayout",
    "DistributedLinearLayout",
    "NVMMASharedLayout",
    "SwizzledSharedLayout",
]
12
+
13
+
14
+ def _realize_cta_layout(rank, ctas_per_cga, cta_split_num, cta_order):
15
+ ctas_per_cga = ctas_per_cga or [1] * rank
16
+ cta_split_num = cta_split_num or [1] * rank
17
+ cta_order = cta_order or list(reversed(range(rank)))
18
+ return ctas_per_cga, cta_split_num, cta_order
19
+
20
+
21
class DistributedLayout:
    """Marker base class for the distributed layout dataclasses in this module."""
23
+
24
+
25
@dataclass(frozen=True)
class BlockedLayout(DistributedLayout):
    """Frozen description of a blocked layout.

    The rank is taken from `size_per_thread`; every other list field must have
    that many entries, and the optional CTA fields may be left as None.
    """
    size_per_thread: List[int]
    threads_per_warp: List[int]
    warps_per_cta: List[int]
    order: List[int]
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        # The dataclass is frozen, so fields are rewritten through the parent
        # class's __setattr__ after unwrapping constexpr values.
        for attr in ("size_per_thread", "threads_per_warp", "warps_per_cta", "order", "ctas_per_cga",
                     "cta_split_num", "cta_order"):
            super().__setattr__(attr, _unwrap_if_constexpr(getattr(self, attr)))

        rank = len(self.size_per_thread)
        for required in (self.threads_per_warp, self.warps_per_cta, self.order):
            assert len(required) == rank
        for optional in (self.ctas_per_cga, self.cta_split_num, self.cta_order):
            assert optional is None or len(optional) == rank

    def _to_ir(self, builder):
        """Resolve the CTA defaults and return `builder.get_blocked_layout(...)`."""
        ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(
            len(self.size_per_thread),
            self.ctas_per_cga,
            self.cta_split_num,
            self.cta_order,
        )
        return builder.get_blocked_layout(
            self.size_per_thread,
            self.threads_per_warp,
            self.warps_per_cta,
            self.order,
            ctas_per_cga,
            cta_split_num,
            cta_order,
        )

    def mangle(self) -> str:
        """Return the mangled name: the seven fields, each underscore-joined, separated and enclosed by "B"."""
        field_values = (
            self.size_per_thread,
            self.threads_per_warp,
            self.warps_per_cta,
            self.order,
            self.ctas_per_cga,
            self.cta_split_num,
            self.cta_order,
        )
        # A field left as None contributes an empty string.
        pieces = ["" if value is None else "_".join(map(str, value)) for value in field_values]
        return "B" + "B".join(pieces) + "B"
81
+
82
+
83
@dataclass(frozen=True)
class SliceLayout(DistributedLayout):
    """Frozen layout defined by a dimension index and a parent distributed layout."""
    dim: int
    parent: DistributedLayout

    def __post_init__(self):
        # Frozen dataclass: write the unwrapped values through the parent
        # class's __setattr__.
        for attr in ("dim", "parent"):
            super().__setattr__(attr, _unwrap_if_constexpr(getattr(self, attr)))

    def _to_ir(self, builder):
        """Lower the parent layout and return `builder.get_slice_layout(dim, parent_ir)`."""
        parent_ir = self.parent._to_ir(builder)
        return builder.get_slice_layout(self.dim, parent_ir)

    def mangle(self) -> str:
        """Return the mangled name: `dim` and the parent's mangling enclosed by "SL"."""
        parent_tag = self.parent.mangle()
        return f"SL{self.dim}_{parent_tag}SL"
100
+
101
+
102
@dataclass(frozen=True)
class DistributedLinearLayout(DistributedLayout):
    """Frozen layout given by four lists of bases plus a shape.

    Every basis in every list must have as many entries as `shape`.
    """
    reg_bases: List[List[int]]
    lane_bases: List[List[int]]
    warp_bases: List[List[int]]
    block_bases: List[List[int]]
    shape: List[int]

    def __post_init__(self):
        # Frozen dataclass: write the unwrapped values through the parent
        # class's __setattr__.
        for attr in ("reg_bases", "lane_bases", "warp_bases", "block_bases", "shape"):
            super().__setattr__(attr, _unwrap_shape(getattr(self, attr)))

        rank = len(self.shape)
        for bases in (self.reg_bases, self.lane_bases, self.warp_bases, self.block_bases):
            for basis in bases:
                assert len(basis) == rank

    def _to_ir(self, builder):
        """Return `builder.get_distributed_linear_layout(...)` for the stored bases and shape."""
        return builder.get_distributed_linear_layout(
            self.reg_bases,
            self.lane_bases,
            self.warp_bases,
            self.block_bases,
            self.shape,
        )

    def mangle(self):
        """Return the mangled name: the five fields joined by "_" and enclosed by "DLL"."""
        body = f"{self.reg_bases}_{self.lane_bases}_{self.warp_bases}_{self.block_bases}_{self.shape}"
        return f"DLL{body}DLL"
134
+
135
+
136
class SharedLayout:
    """Marker base class for the shared layout dataclasses in this module."""
138
+
139
+
140
@dataclass(frozen=True)
class NVMMASharedLayout(SharedLayout):
    """Frozen description of an NVMMA shared layout.

    `element_bitwidth` must be one of 8, 16, 32, 64 and `swizzle_byte_width`
    one of 0, 32, 64, 128. The optional CTA fields, when given, must have
    `rank` entries.
    """
    swizzle_byte_width: int
    element_bitwidth: int
    rank: int
    transposed: bool = False
    fp4_padded: bool = False
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        # Frozen dataclass: write the unwrapped values through the parent
        # class's __setattr__.
        for attr in ("swizzle_byte_width", "element_bitwidth", "rank", "transposed", "fp4_padded", "ctas_per_cga",
                     "cta_split_num", "cta_order"):
            super().__setattr__(attr, _unwrap_if_constexpr(getattr(self, attr)))

        assert self.element_bitwidth in (8, 16, 32, 64)
        assert self.swizzle_byte_width in (0, 32, 64, 128)
        for optional in (self.ctas_per_cga, self.cta_split_num, self.cta_order):
            assert optional is None or len(optional) == self.rank

    def _to_ir(self, builder):
        """Resolve the CTA defaults and return `builder.get_nvmma_shared_layout(...)`."""
        ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(
            self.rank,
            self.ctas_per_cga,
            self.cta_split_num,
            self.cta_order,
        )
        return builder.get_nvmma_shared_layout(
            self.swizzle_byte_width,
            self.element_bitwidth,
            self.transposed,
            self.fp4_padded,
            ctas_per_cga,
            cta_split_num,
            cta_order,
        )

    def mangle(self) -> str:
        """Return the mangled name built from the swizzle width, element bitwidth and the two flags."""
        # NOTE(review): `rank` and the CTA fields are not part of the mangled
        # name, so layouts differing only in those mangle identically -
        # confirm this is intended.
        parts = (self.swizzle_byte_width, self.element_bitwidth, self.transposed, self.fp4_padded)
        return "NVMMA_{}_{}_{}_{}_NVMMA".format(*parts)
183
+
184
+
185
@dataclass(frozen=True, eq=True)
class SwizzledSharedLayout(SharedLayout):
    """Frozen description of a swizzled shared layout.

    The rank is taken from `order`; the optional CTA fields, when given, must
    have that many entries.
    """
    vec: int
    per_phase: int
    max_phase: int
    order: List[int]
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        # Frozen dataclass: write the unwrapped values through the parent
        # class's __setattr__.
        for attr in ("vec", "per_phase", "max_phase", "order", "ctas_per_cga", "cta_split_num", "cta_order"):
            super().__setattr__(attr, _unwrap_if_constexpr(getattr(self, attr)))

        rank = len(self.order)
        for optional in (self.ctas_per_cga, self.cta_split_num, self.cta_order):
            assert optional is None or len(optional) == rank

    def _to_ir(self, builder):
        """Resolve the CTA defaults and return `builder.get_swizzled_shared_layout(...)`."""
        ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(
            len(self.order),
            self.ctas_per_cga,
            self.cta_split_num,
            self.cta_order,
        )
        vec = _unwrap_if_constexpr(self.vec)
        per_phase = _unwrap_if_constexpr(self.per_phase)
        max_phase = _unwrap_if_constexpr(self.max_phase)
        return builder.get_swizzled_shared_layout(
            vec,
            per_phase,
            max_phase,
            self.order,
            ctas_per_cga,
            cta_split_num,
            cta_order,
        )

    def mangle(self) -> str:
        """Return the mangled name: scalar fields and underscore-joined list fields enclosed by "SSS"."""
        # A list field left as None contributes an empty string.
        order, ctas_per_cga, cta_split_num, cta_order = [
            "" if seq is None else "_".join(map(str, seq))
            for seq in (self.order, self.ctas_per_cga, self.cta_split_num, self.cta_order)
        ]
        return (f"SSS_{self.vec}_{self.per_phase}_{self.max_phase}_{order}_{ctas_per_cga}_{cta_split_num}"
                f"_{cta_order}_SSS")
@@ -0,0 +1,12 @@
1
+ # flake8: noqa
2
+ import triton.language.math as tl_math
3
+ from ._core import builtin
4
+
5
# Names of the `triton.language.math` functions re-exported by this module.
__all__ = [
    "umulhi",
    "exp",
    "exp2",
    "fma",
    "log",
    "log2",
    "cos",
    "rsqrt",
    "sin",
    "sqrt",
    "sqrt_rn",
    "abs",
    "fdiv",
    "div_rn",
    "erf",
    "floor",
    "ceil",
]

# Wrap every listed function with `builtin` and publish it here under the
# same name.
for name in __all__:
    fn = getattr(tl_math, name)
    globals().update({name: builtin(fn)})