triton-windows 3.4.0.post20__cp312-cp312-win_amd64.whl → 3.5.0.post21__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of triton-windows was flagged as potentially problematic by the registry.
Files changed (107)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +8 -2
  3. triton/_filecheck.py +24 -14
  4. triton/_internal_testing.py +70 -4
  5. triton/_utils.py +3 -1
  6. triton/backends/amd/compiler.py +68 -60
  7. triton/backends/amd/driver.c +113 -44
  8. triton/backends/amd/driver.py +133 -57
  9. triton/backends/driver.py +13 -0
  10. triton/backends/nvidia/compiler.py +80 -22
  11. triton/backends/nvidia/driver.c +88 -15
  12. triton/backends/nvidia/driver.py +130 -123
  13. triton/compiler/__init__.py +5 -2
  14. triton/compiler/code_generator.py +270 -163
  15. triton/compiler/compiler.py +45 -62
  16. triton/experimental/gluon/__init__.py +3 -2
  17. triton/experimental/gluon/_runtime.py +9 -6
  18. triton/experimental/gluon/language/__init__.py +117 -16
  19. triton/experimental/gluon/language/_core.py +246 -68
  20. triton/experimental/gluon/language/_layouts.py +398 -45
  21. triton/experimental/gluon/language/_math.py +17 -9
  22. triton/experimental/gluon/language/_semantic.py +130 -37
  23. triton/experimental/gluon/language/_standard.py +55 -22
  24. triton/experimental/gluon/language/amd/__init__.py +4 -0
  25. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  26. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  27. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  28. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  29. triton/experimental/gluon/language/extra/__init__.py +3 -0
  30. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  31. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  32. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  33. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
  34. triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
  35. triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
  36. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
  37. triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
  38. triton/experimental/gluon/nvidia/hopper.py +6 -1
  39. triton/knobs.py +132 -67
  40. triton/language/__init__.py +16 -10
  41. triton/language/core.py +163 -83
  42. triton/language/extra/cuda/gdc.py +6 -6
  43. triton/language/extra/hip/__init__.py +3 -1
  44. triton/language/extra/hip/libdevice.py +7 -0
  45. triton/language/extra/hip/utils.py +35 -0
  46. triton/language/extra/libdevice.py +4 -0
  47. triton/language/semantic.py +76 -23
  48. triton/language/standard.py +14 -14
  49. triton/language/target_info.py +54 -0
  50. triton/runtime/_allocation.py +15 -3
  51. triton/runtime/_async_compile.py +55 -0
  52. triton/runtime/autotuner.py +4 -5
  53. triton/runtime/build.py +11 -9
  54. triton/runtime/cache.py +44 -1
  55. triton/runtime/driver.py +16 -41
  56. triton/runtime/interpreter.py +31 -23
  57. triton/runtime/jit.py +318 -157
  58. triton/runtime/tcc/include/_mingw.h +8 -10
  59. triton/runtime/tcc/include/assert.h +5 -0
  60. triton/runtime/tcc/include/errno.h +1 -1
  61. triton/runtime/tcc/include/float.h +21 -3
  62. triton/runtime/tcc/include/iso646.h +36 -0
  63. triton/runtime/tcc/include/limits.h +5 -0
  64. triton/runtime/tcc/include/malloc.h +2 -2
  65. triton/runtime/tcc/include/math.h +21 -261
  66. triton/runtime/tcc/include/stdalign.h +16 -0
  67. triton/runtime/tcc/include/stdarg.h +5 -70
  68. triton/runtime/tcc/include/stdatomic.h +171 -0
  69. triton/runtime/tcc/include/stddef.h +7 -19
  70. triton/runtime/tcc/include/stdlib.h +15 -4
  71. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  72. triton/runtime/tcc/include/sys/stat.h +2 -2
  73. triton/runtime/tcc/include/sys/types.h +5 -0
  74. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  75. triton/runtime/tcc/include/tccdefs.h +342 -0
  76. triton/runtime/tcc/include/tgmath.h +89 -0
  77. triton/runtime/tcc/include/uchar.h +33 -0
  78. triton/runtime/tcc/include/unistd.h +1 -0
  79. triton/runtime/tcc/include/winapi/qos.h +72 -0
  80. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  81. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  82. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  83. triton/runtime/tcc/include/winapi/windows.h +1 -1
  84. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  85. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  86. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  87. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  88. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  89. triton/runtime/tcc/lib/libtcc1.a +0 -0
  90. triton/runtime/tcc/lib/python314.def +1800 -0
  91. triton/runtime/tcc/lib/python314t.def +1809 -0
  92. triton/runtime/tcc/libtcc.dll +0 -0
  93. triton/runtime/tcc/tcc.exe +0 -0
  94. triton/tools/compile.py +62 -14
  95. triton/tools/extra/cuda/compile.c +1 -0
  96. triton/tools/extra/hip/compile.cpp +66 -0
  97. triton/tools/extra/hip/compile.h +13 -0
  98. triton/tools/ragged_tma.py +92 -0
  99. triton/tools/tensor_descriptor.py +7 -9
  100. triton/windows_utils.py +42 -79
  101. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
  102. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
  103. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  104. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
  105. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
  106. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
  107. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/experimental/gluon/language/_layouts.py
@@ -1,29 +1,52 @@
 from dataclasses import dataclass
 from typing import List, Optional
-from triton.language.core import _unwrap_if_constexpr, _unwrap_shape
+from triton.language.core import _unwrap_if_constexpr, _unwrap_shape, constexpr_type
+from triton.runtime.jit import constexpr_function
 
-__all__ = [
-    "BlockedLayout",
-    "SliceLayout",
-    "DistributedLinearLayout",
-    "NVMMASharedLayout",
-    "SwizzledSharedLayout",
-]
 
-
-def _realize_cta_layout(rank, ctas_per_cga, cta_split_num, cta_order):
-    ctas_per_cga = ctas_per_cga or [1] * rank
-    cta_split_num = cta_split_num or [1] * rank
-    cta_order = cta_order or list(reversed(range(rank)))
-    return ctas_per_cga, cta_split_num, cta_order
+def _realize_cta_layout(layout, rank):
+    ctas_per_cga = layout.ctas_per_cga or [1] * rank
+    cta_split_num = layout.cta_split_num or [1] * rank
+    cta_order = layout.cta_order or list(reversed(range(rank)))
+    object.__setattr__(layout, "ctas_per_cga", ctas_per_cga)
+    object.__setattr__(layout, "cta_split_num", cta_split_num)
+    object.__setattr__(layout, "cta_order", cta_order)
 
 
 class DistributedLayout:
-    pass
+    """
+    Base class for distributed memory layouts in Gluon IR.
+    """
+
+    @property
+    def type(self):
+        return constexpr_type(self)
+
+
+@dataclass(frozen=True)
+class AutoLayout(DistributedLayout):
+
+    def _to_ir(self, builder):
+        return builder.get_auto_layout()
+
+    def mangle(self):
+        return "AL"
 
 
 @dataclass(frozen=True)
 class BlockedLayout(DistributedLayout):
+    """
+    Represents a blocked layout, partitioning a tensor across threads, warps, and CTAs.
+
+    Args:
+        size_per_thread (List[int]): Number of elements per thread per dimension.
+        threads_per_warp (List[int]): Number of threads per warp per dimension.
+        warps_per_cta (List[int]): Number of warps per CTA per dimension.
+        order (List[int]): The ordering of dimensions for partitioning.
+        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
+        cta_split_num (Optional[List[int]]): Split factors for CTAs.
+        cta_order (Optional[List[int]]): Ordering for CTAs.
+    """
     size_per_thread: List[int]
     threads_per_warp: List[int]
    warps_per_cta: List[int]
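
Note on the new `_realize_cta_layout(layout, rank)` signature: instead of returning realized copies of the CTA fields, it now writes the defaults back onto the frozen dataclass with `object.__setattr__`, which bypasses the immutability check. A minimal standalone sketch of that defaulting behavior (the `DemoLayout` class is hypothetical, not part of the package):

from dataclasses import dataclass
from typing import List, Optional

@dataclass(frozen=True)
class DemoLayout:
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

def realize(layout, rank):
    # Frozen dataclasses reject normal attribute assignment, so the
    # defaults are written back with object.__setattr__, as in the diff.
    object.__setattr__(layout, "ctas_per_cga", layout.ctas_per_cga or [1] * rank)
    object.__setattr__(layout, "cta_split_num", layout.cta_split_num or [1] * rank)
    object.__setattr__(layout, "cta_order", layout.cta_order or list(reversed(range(rank))))

demo = DemoLayout()
realize(demo, rank=2)
assert demo.ctas_per_cga == [1, 1]
assert demo.cta_order == [1, 0]  # default CTA order is reversed dimension order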
@@ -42,25 +65,23 @@ class BlockedLayout(DistributedLayout):
         super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))
 
         rank = len(self.size_per_thread)
+        _realize_cta_layout(self, rank)
         assert len(self.threads_per_warp) == rank
         assert len(self.warps_per_cta) == rank
         assert len(self.order) == rank
-        assert self.ctas_per_cga is None or len(self.ctas_per_cga) == rank
-        assert self.cta_split_num is None or len(self.cta_split_num) == rank
-        assert self.cta_order is None or len(self.cta_order) == rank
+        assert len(self.ctas_per_cga) == rank
+        assert len(self.cta_split_num) == rank
+        assert len(self.cta_order) == rank
 
     def _to_ir(self, builder):
-        rank = len(self.size_per_thread)
-        ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(rank, self.ctas_per_cga, self.cta_split_num,
-                                                                     self.cta_order)
         return builder.get_blocked_layout(
             self.size_per_thread,
             self.threads_per_warp,
             self.warps_per_cta,
             self.order,
-            ctas_per_cga,
-            cta_split_num,
-            cta_order,
+            self.ctas_per_cga,
+            self.cta_split_num,
+            self.cta_order,
         )
 
     def mangle(self) -> str:
@@ -79,9 +100,27 @@ class BlockedLayout(DistributedLayout):
         cta_order = stringify(self.cta_order)
         return f"B{size_per_thread}B{threads_per_warp}B{warps_per_cta}B{order}B{ctas_per_cga}B{cta_split_num}B{cta_order}B"
 
+    def __hash__(self):
+        return hash((
+            tuple(self.size_per_thread),
+            tuple(self.threads_per_warp),
+            tuple(self.warps_per_cta),
+            tuple(self.order),
+            tuple(self.ctas_per_cga) if self.ctas_per_cga else None,
+            tuple(self.cta_split_num) if self.cta_split_num else None,
+            tuple(self.cta_order) if self.cta_order else None,
+        ))
+
 
 @dataclass(frozen=True)
 class SliceLayout(DistributedLayout):
+    """
+    Represents a layout corresponding to slicing a distributed tensor along one dimension.
+
+    Args:
+        dim (int): The dimension index to slice.
+        parent (DistributedLayout): The parent layout before slicing.
+    """
     dim: int
     parent: DistributedLayout
 
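
The explicit `__hash__` above is needed because the hash a frozen dataclass would generate calls `hash()` on the list-typed fields and fails; converting them to tuples makes layout instances usable as cache keys. A hedged construction sketch (the argument values are illustrative, chosen only to satisfy the documented constraints):

layout = BlockedLayout(
    size_per_thread=[1, 8],    # each thread holds a 1x8 slice
    threads_per_warp=[8, 4],   # 8 * 4 = 32 threads per warp
    warps_per_cta=[4, 1],
    order=[1, 0],              # dimension 1 is fastest-varying
)
# The CTA fields were defaulted by _realize_cta_layout in __post_init__:
assert layout.ctas_per_cga == [1, 1]
kernel_cache = {layout: "compiled-kernel"}  # works thanks to the tuple-based __hash__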
@@ -98,9 +137,23 @@ class SliceLayout(DistributedLayout):
     def mangle(self) -> str:
         return f"SL{self.dim}_{self.parent.mangle()}SL"
 
+    def __hash__(self):
+        return hash((self.dim, self.parent))
+
 
 @dataclass(frozen=True)
 class DistributedLinearLayout(DistributedLayout):
+    """
+    Represents a linear distributed layout with explicit bases at register, lane, warp, and block levels.
+    See: https://arxiv.org/abs/2505.23819 for reference.
+
+    Args:
+        reg_bases (List[List[int]]): Bases for register-level distribution.
+        lane_bases (List[List[int]]): Bases for lane-level distribution.
+        warp_bases (List[List[int]]): Bases for warp-level distribution.
+        block_bases (List[List[int]]): Bases for block-level distribution.
+        shape (List[int]): The tensor global shape.
+    """
    reg_bases: List[List[int]]
    lane_bases: List[List[int]]
    warp_bases: List[List[int]]
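
On reading these bases: per the linked paper, each basis vector records where flipping one bit of the register/lane/warp/block index lands in the logical tensor, and contributions combine by XOR. A small illustrative decoder, written here as an assumption about the encoding rather than as package API:

def coordinate(bases, index, rank):
    # XOR-accumulate the basis vectors selected by the set bits of `index`.
    coord = [0] * rank
    for i, basis in enumerate(bases):
        if index >> i & 1:
            coord = [c ^ b for c, b in zip(coord, basis)]
    return coord

lane_bases = [[0, 1], [0, 2], [1, 0]]  # 8 lanes covering a 2x4 tile
assert coordinate(lane_bases, 0b101, 2) == [1, 1]  # bases 0 and 2 XOR-combined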
@@ -132,13 +185,128 @@ class DistributedLinearLayout(DistributedLayout):
     def mangle(self):
         return f"DLL{self.reg_bases}_{self.lane_bases}_{self.warp_bases}_{self.block_bases}_{self.shape}DLL"
 
+    def __hash__(self):
+        return hash((
+            tuple(map(tuple, self.reg_bases)),
+            tuple(map(tuple, self.lane_bases)),
+            tuple(map(tuple, self.warp_bases)),
+            tuple(map(tuple, self.block_bases)),
+            tuple(self.shape),
+        ))
+
+
+@dataclass(frozen=True)
+class DotOperandLayout(DistributedLayout):
+    """
+    Represents a layout for a dot operand.
+
+    Args:
+        operand_index (int): 0 for LHS and 1 for RHS of the dot operation.
+        parent (DistributedLayout): The parent layout, representing the MMA.
+        k_width (int): Number of elements per 32-bits.
+    """
+    operand_index: int
+    parent: DistributedLayout
+    k_width: int
+
+    def __post_init__(self):
+        super().__setattr__("operand_index", _unwrap_if_constexpr(self.operand_index))
+        super().__setattr__("parent", _unwrap_if_constexpr(self.parent))
+        super().__setattr__("k_width", _unwrap_if_constexpr(self.k_width))
+
+    def _to_ir(self, builder):
+        return builder.get_dot_operand_layout(self.operand_index, self.parent._to_ir(builder), self.k_width)
+
+    def mangle(self) -> str:
+        return f"DO{self.operand_index}_{self.parent.mangle()}_{self.k_width}DO"
+
+    def __hash__(self):
+        return hash((self.operand_index, self.parent, self.k_width))
+
+
+@dataclass(frozen=True, eq=True)
+class NVMMADistributedLayout(DistributedLayout):
+    """
+    Represents a layout for NVIDIA MMA (tensor core) operations.
+
+    Args:
+        version (List[int]): Version identifier for the MMA instruction.
+        warps_per_cta (List[int]): Number of warps per CTA.
+        instr_shape (List[int]): Instruction shape for MMA.
+        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
+        cta_split_num (Optional[List[int]]): Split factors for CTAs.
+        cta_order (Optional[List[int]]): CTA ordering.
+    """
+    version: List[int]
+    warps_per_cta: List[int]
+    instr_shape: List[int]
+    ctas_per_cga: Optional[List[int]] = None
+    cta_split_num: Optional[List[int]] = None
+    cta_order: Optional[List[int]] = None
+
+    def __post_init__(self):
+        super().__setattr__("version", _unwrap_if_constexpr(self.version))
+        super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta))
+        super().__setattr__("instr_shape", _unwrap_if_constexpr(self.instr_shape))
+        super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga))
+        super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num))
+        super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))
+
+        rank = len(self.warps_per_cta)
+        _realize_cta_layout(self, rank)
+        assert len(self.ctas_per_cga) == rank
+        assert len(self.cta_split_num) == rank
+        assert len(self.cta_order) == rank
+
+    def _to_ir(self, builder):
+        return builder.get_mma_layout(self.version, self.warps_per_cta, self.ctas_per_cga, self.cta_split_num,
+                                      self.cta_order, self.instr_shape)
+
+    def mangle(self) -> str:
+        return f"MMA_{self.version}_{self.warps_per_cta}_{self.instr_shape}_{self.ctas_per_cga}_{self.cta_split_num}_{self.cta_order}_MMA"
+
+    def __hash__(self):
+        return hash((tuple(self.version), tuple(self.warps_per_cta),
+                     tuple(self.instr_shape), tuple(self.ctas_per_cga) if self.ctas_per_cga else None,
+                     tuple(self.cta_split_num) if self.cta_split_num else None,
+                     tuple(self.cta_order) if self.cta_order else None))
+
 
 class SharedLayout:
-    pass
+    """
+    Base class for shared memory layouts in Gluon IR.
+    """
+
+    @property
+    def type(self):
+        return constexpr_type(self)
+
+
+@constexpr_function
+def _get_shape_per_cta(shape, cta_split_num):
+    shape_per_cta = shape
+    if cta_split_num is not None:
+        assert len(cta_split_num) == len(shape)
+        for dim in range(len(shape_per_cta)):
+            shape_per_cta[dim] /= cta_split_num[dim]
+    return shape_per_cta
 
 
 @dataclass(frozen=True)
 class NVMMASharedLayout(SharedLayout):
+    """
+    Represents a layout for shared memory suitable for NVIDIA MMA operations.
+
+    Args:
+        swizzle_byte_width (int): Width in bytes for swizzling.
+        element_bitwidth (int): Bitwidth of element type.
+        rank (int): Rank of the tensor.
+        transposed (bool): Whether the layout is transposed.
+        fp4_padded (bool): Whether FP4 padding is used.
+        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
+        cta_split_num (Optional[List[int]]): Split factors for CTAs.
+        cta_order (Optional[List[int]]): CTA ordering.
+    """
    swizzle_byte_width: int
    element_bitwidth: int
    rank: int
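
For a sense of how the two new distributed classes compose: `DotOperandLayout` wraps an MMA layout and records which operand it feeds plus `k_width`, the number of elements packed per 32 bits (so 2 for 16-bit elements, by the docstring's own arithmetic). A hedged sketch; the `version` and `instr_shape` values below are illustrative placeholders, not canonical choices:

mma = NVMMADistributedLayout(
    version=[3, 0],           # illustrative version, not a canonical value
    warps_per_cta=[4, 1],
    instr_shape=[16, 8, 16],  # illustrative instruction shape
)
lhs = DotOperandLayout(operand_index=0, parent=mma, k_width=2)  # 16-bit elems: 2 per 32 bits
rhs = DotOperandLayout(operand_index=1, parent=mma, k_width=2)
assert lhs.mangle() != rhs.mangle()  # the operand index is part of the mangled name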
@@ -161,29 +329,88 @@ class NVMMASharedLayout(SharedLayout):
         assert self.element_bitwidth in [8, 16, 32, 64]
         assert self.swizzle_byte_width in [0, 32, 64, 128]
         rank = self.rank
-        assert self.ctas_per_cga is None or len(self.ctas_per_cga) == rank
-        assert self.cta_split_num is None or len(self.cta_split_num) == rank
-        assert self.cta_order is None or len(self.cta_order) == rank
+        _realize_cta_layout(self, rank)
+        assert len(self.ctas_per_cga) == rank
+        assert len(self.cta_split_num) == rank
+        assert len(self.cta_order) == rank
 
     def _to_ir(self, builder):
-        ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(self.rank, self.ctas_per_cga, self.cta_split_num,
-                                                                     self.cta_order)
         return builder.get_nvmma_shared_layout(
             self.swizzle_byte_width,
             self.element_bitwidth,
             self.transposed,
             self.fp4_padded,
-            ctas_per_cga,
-            cta_split_num,
-            cta_order,
+            self.ctas_per_cga,
+            self.cta_split_num,
+            self.cta_order,
+        )
+
+    @staticmethod
+    @constexpr_function
+    def get_default_for(block_shape, dtype, transposed=False, fp4_padded=False, ctas_per_cga=None, cta_split_num=None,
+                        cta_order=None):
+        """Returns an NVMMASharedLayout with default swizzling for a given shape.
+
+        This picks the largest swizzle pattern compatible with the shape, which
+        allows emitting the fewest TMA or MMA messages.
+        """
+        packing_factor = 2 if fp4_padded else 1
+        shape_per_cta = _get_shape_per_cta(block_shape, cta_split_num)
+        rank = len(block_shape)
+        if transposed:
+            shape_per_cta = shape_per_cta[1:] + shape_per_cta[:1]
+        contig_dim_size = shape_per_cta[-1] * packing_factor
+        contig_dim_bytes = contig_dim_size * dtype.primitive_bitwidth // 8
+        if contig_dim_bytes >= 128 and contig_dim_bytes % 128 == 0:
+            swizzle_byte_width = 128
+        elif contig_dim_bytes >= 64 and contig_dim_bytes % 64 == 0:
+            swizzle_byte_width = 64
+        elif contig_dim_bytes >= 32 and contig_dim_bytes % 32 == 0:
+            swizzle_byte_width = 32
+        else:
+            swizzle_byte_width = 0
+
+        flatten_outer_dim = 1
+        for size in shape_per_cta[:-1]:
+            flatten_outer_dim *= size
+        if len(block_shape) < 2 or flatten_outer_dim < 8:
+            swizzle_byte_width = 0
+
+        return NVMMASharedLayout(
+            swizzle_byte_width=swizzle_byte_width,
+            element_bitwidth=dtype.primitive_bitwidth,
+            rank=rank,
+            transposed=transposed,
+            fp4_padded=fp4_padded,
+            ctas_per_cga=ctas_per_cga,
+            cta_split_num=cta_split_num,
+            cta_order=cta_order,
         )
 
     def mangle(self) -> str:
         return f"NVMMA_{self.swizzle_byte_width}_{self.element_bitwidth}_{self.transposed}_{self.fp4_padded}_NVMMA"
 
+    def __hash__(self):
+        return hash((self.swizzle_byte_width, self.element_bitwidth, self.rank, self.transposed, self.fp4_padded,
+                     tuple(self.ctas_per_cga) if self.ctas_per_cga else None,
+                     tuple(self.cta_split_num) if self.cta_split_num else None,
+                     tuple(self.cta_order) if self.cta_order else None))
+
 
 @dataclass(frozen=True, eq=True)
 class SwizzledSharedLayout(SharedLayout):
+    """
+    Represents a generic swizzled shared memory layout.
+
+    Args:
+        vec (int): Vector width for swizzling.
+        per_phase (int): Elements per swizzle phase.
+        max_phase (int): Maximum number of swizzle phases.
+        order (List[int]): Dimension ordering for swizzling.
+        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
+        cta_split_num (Optional[List[int]]): Split factors for CTAs.
+        cta_order (Optional[List[int]]): CTA ordering.
+    """
    vec: int
    per_phase: int
    max_phase: int
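
To make the swizzle selection in `get_default_for` concrete: a (128, 64) fp16 block has a contiguous dimension of 64 elements x 2 bytes = 128 bytes, so it gets the 128-byte swizzle, while a (128, 8) fp16 block yields only 16 bytes and falls through to no swizzling. A standalone restatement of just the contiguous-dimension arithmetic (it deliberately omits the flattened-outer-dim check above):

def default_swizzle_bytes(block_shape, elem_bits, packing_factor=1):
    # Bytes along the fastest-varying dimension of one CTA's tile.
    contig_bytes = block_shape[-1] * packing_factor * elem_bits // 8
    for width in (128, 64, 32):  # prefer the largest compatible pattern
        if contig_bytes >= width and contig_bytes % width == 0:
            return width
    return 0

assert default_swizzle_bytes([128, 64], 16) == 128  # 64 * 2 bytes = 128 B
assert default_swizzle_bytes([128, 8], 16) == 0     # 16 B: too small to swizzle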
@@ -202,22 +429,20 @@ class SwizzledSharedLayout(SharedLayout):
         super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))
 
         rank = len(self.order)
-        assert self.ctas_per_cga is None or len(self.ctas_per_cga) == rank
-        assert self.cta_split_num is None or len(self.cta_split_num) == rank
-        assert self.cta_order is None or len(self.cta_order) == rank
+        _realize_cta_layout(self, rank)
+        assert len(self.ctas_per_cga) == rank
+        assert len(self.cta_split_num) == rank
+        assert len(self.cta_order) == rank
 
     def _to_ir(self, builder):
-        rank = len(self.order)
-        ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(rank, self.ctas_per_cga, self.cta_split_num,
-                                                                     self.cta_order)
         return builder.get_swizzled_shared_layout(
-            _unwrap_if_constexpr(self.vec),
-            _unwrap_if_constexpr(self.per_phase),
-            _unwrap_if_constexpr(self.max_phase),
+            self.vec,
+            self.per_phase,
+            self.max_phase,
             self.order,
-            ctas_per_cga,
-            cta_split_num,
-            cta_order,
+            self.ctas_per_cga,
+            self.cta_split_num,
+            self.cta_order,
         )
 
     def mangle(self) -> str:
@@ -228,3 +453,131 @@ class SwizzledSharedLayout(SharedLayout):
             return "_".join(map(str, x))
 
         return f"SSS_{self.vec}_{self.per_phase}_{self.max_phase}_{stringify(self.order)}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_SSS"
+
+    def __hash__(self):
+        return hash((self.vec, self.per_phase, self.max_phase,
+                     tuple(self.order), tuple(self.ctas_per_cga) if self.ctas_per_cga else None,
+                     tuple(self.cta_split_num) if self.cta_split_num else None,
+                     tuple(self.cta_order) if self.cta_order else None))
+
+
+@dataclass(frozen=True, eq=True)
+class PaddedSharedLayout(SharedLayout):
+    """
+    Represents a layout for access to shared memory. Compared to SwizzledSharedLayout,
+    it uses padding to avoid shared memory bank conflicts. After every `interval` tensor
+    elements, the corresponding number of padding elements is inserted.
+    If a position corresponds to multiple intervals, the padding amounts are summed.
+
+    In the following example of a tensor,
+    `eM` represents an original element and `pN` represents a padded element.
+
+    Before padding, the shared memory looks like:
+    [e0, e1,
+     e2, e3,
+     e4, e5,
+     e6, e7,
+     ...]
+
+    After padding with interval-padding list [[2, 1], [4, 2]],
+    the shared memory will be
+    [e0, e1, p0,
+     e2, e3, p1, p2, p3,
+     e4, e5, p4,
+     e6, e7, p5, p6, p7,
+     ...]
+
+    Args:
+        interval_padding_pairs (List[List[int]]): List of [interval, padding] pairs; both interval and padding must be powers of 2.
+        order (List[int]): Order of logical tensor dimensions; fastest-varying first.
+        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
+        cta_split_num (Optional[List[int]]): Split factors for CTAs.
+        cta_order (Optional[List[int]]): CTA ordering.
+    """
+    interval_padding_pairs: List[List[int]]
+    order: List[int]
+    ctas_per_cga: Optional[List[int]] = None
+    cta_split_num: Optional[List[int]] = None
+    cta_order: Optional[List[int]] = None
+
+    def __post_init__(self):
+        super().__setattr__("interval_padding_pairs", _unwrap_shape(self.interval_padding_pairs))
+        super().__setattr__("order", _unwrap_if_constexpr(self.order))
+        super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga))
+        super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num))
+        super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))
+
+        self.verify()
+
+    def _to_ir(self, builder):
+        intervals, paddings = zip(*self.interval_padding_pairs)
+        return builder.get_padded_shared_layout(intervals, paddings, self.order, self.ctas_per_cga, self.cta_split_num,
+                                                self.cta_order)
+
+    def mangle(self) -> str:
+
+        def stringify(x):
+            if x is None:
+                return ""
+            return "_".join(map(str, x))
+
+        return f"PaddedShared_{stringify(self.interval_padding_pairs)}_{stringify(self.order)}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_PaddedShared"
+
+    def verify(self):
+        pairs = self.interval_padding_pairs
+        assert len(pairs) > 0, "PaddedSharedLayout interval_padding_pairs must have at least one interval-padding pair"
+        assert all(len(pair) == 2 for pair in pairs)
+        intervals, paddings = zip(*pairs)
+
+        unique_intervals = list(set(intervals))
+        assert len(unique_intervals) == len(intervals)
+
+        is_power_of_2 = lambda n: n > 0 and n & (n - 1) == 0
+        assert all(is_power_of_2(n) for n in intervals), "PaddedSharedLayout interval values must all be power of two"
+        assert all(is_power_of_2(n) for n in paddings), "PaddedSharedLayout padding values must all be power of two"
+
+        rank = len(self.order)
+        assert rank > 0, "PaddedSharedLayout order must not be empty"
+        _realize_cta_layout(self, rank)
+
+        assert len(self.ctas_per_cga) == rank
+        assert len(self.cta_split_num) == rank
+        assert len(self.cta_order) == rank
+
+    def __hash__(self):
+        return hash((tuple(map(tuple, self.interval_padding_pairs)),
+                     tuple(self.order), tuple(self.ctas_per_cga) if self.ctas_per_cga else None,
+                     tuple(self.cta_split_num) if self.cta_split_num else None,
+                     tuple(self.cta_order) if self.cta_order else None))
+
+
+# Python impl of LinearEncodingAttr::basesPerDim
+def bases_per_dim(bases, rank, skip_broadcast=True):
+    result = [1] * rank
+
+    if not bases:
+        return result
+
+    non_zero_idx = None
+
+    for basis in bases:
+        # Find the first non-zero index in the current basis
+        idx = next((i for i, v in enumerate(basis) if v != 0), None)
+        if idx is not None:
+            non_zero_idx = idx
+            result[idx] *= 2
+        elif not skip_broadcast:
+            # If no non-zero found and we're not skipping broadcasts, use the last found non-zero index
+            assert non_zero_idx is not None
+            result[non_zero_idx] *= 2
+
+    return result
+
+
+def warps_per_cta(layout, shape):
+    if isinstance(layout, DistributedLinearLayout):
+        return bases_per_dim(layout.warp_bases, len(shape))
+    elif isinstance(layout, (SliceLayout, DotOperandLayout)):
+        return warps_per_cta(layout.parent, shape)
+    else:
+        return layout.warps_per_cta
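
`bases_per_dim` doubles the reported extent of a dimension once for each basis whose first non-zero entry falls in that dimension, so three warp bases over a rank-2 shape can describe a [2, 4] warp grid. A quick check against the function above:

warp_bases = [[0, 1], [0, 2], [1, 0]]
assert bases_per_dim(warp_bases, rank=2) == [2, 4]  # two bases hit dim 1, one hits dim 0
assert bases_per_dim([], rank=2) == [1, 1]          # no bases: a single warp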
triton/experimental/gluon/language/_math.py
@@ -1,12 +1,20 @@
-# flake8: noqa
 import triton.language.math as tl_math
 from ._core import builtin
 
-__all__ = [
-    "umulhi", "exp", "exp2", "fma", "log", "log2", "cos", "rsqrt", "sin", "sqrt", "sqrt_rn", "abs", "fdiv", "div_rn",
-    "erf", "floor", "ceil"
-]
-
-for name in __all__:
-    fn = getattr(tl_math, name)
-    globals()[name] = builtin(fn)
+umulhi = builtin(tl_math.umulhi)
+exp = builtin(tl_math.exp)
+exp2 = builtin(tl_math.exp2)
+fma = builtin(tl_math.fma)
+log = builtin(tl_math.log)
+log2 = builtin(tl_math.log2)
+cos = builtin(tl_math.cos)
+rsqrt = builtin(tl_math.rsqrt)
+sin = builtin(tl_math.sin)
+sqrt = builtin(tl_math.sqrt)
+sqrt_rn = builtin(tl_math.sqrt_rn)
+abs = builtin(tl_math.abs)
+fdiv = builtin(tl_math.fdiv)
+div_rn = builtin(tl_math.div_rn)
+erf = builtin(tl_math.erf)
+floor = builtin(tl_math.floor)
+ceil = builtin(tl_math.ceil)
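
This rewrite trades the `for name in __all__: globals()[name] = ...` loop for one explicit assignment per export; runtime behavior is unchanged, but each name is now resolvable by IDEs, type checkers, and linters (which is also why the `# flake8: noqa` escape hatch could be dropped). A minimal illustration of the difference, using the standard `math` module as a stand-in:

import math as _m

# Dynamic re-export: works at runtime, but invisible to static analysis.
for _name in ("sin", "cos"):
    globals()[_name] = getattr(_m, _name)

# Explicit re-export: same runtime result, statically resolvable.
sin = _m.sin
cos = _m.cos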