triton-windows 3.4.0.post20__cp313-cp313-win_amd64.whl → 3.5.0.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic.

Files changed (107):
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +8 -2
  3. triton/_filecheck.py +24 -14
  4. triton/_internal_testing.py +70 -4
  5. triton/_utils.py +3 -1
  6. triton/backends/amd/compiler.py +68 -60
  7. triton/backends/amd/driver.c +113 -44
  8. triton/backends/amd/driver.py +133 -57
  9. triton/backends/driver.py +13 -0
  10. triton/backends/nvidia/compiler.py +80 -22
  11. triton/backends/nvidia/driver.c +88 -15
  12. triton/backends/nvidia/driver.py +130 -123
  13. triton/compiler/__init__.py +5 -2
  14. triton/compiler/code_generator.py +270 -163
  15. triton/compiler/compiler.py +45 -62
  16. triton/experimental/gluon/__init__.py +3 -2
  17. triton/experimental/gluon/_runtime.py +9 -6
  18. triton/experimental/gluon/language/__init__.py +117 -16
  19. triton/experimental/gluon/language/_core.py +246 -68
  20. triton/experimental/gluon/language/_layouts.py +398 -45
  21. triton/experimental/gluon/language/_math.py +17 -9
  22. triton/experimental/gluon/language/_semantic.py +130 -37
  23. triton/experimental/gluon/language/_standard.py +55 -22
  24. triton/experimental/gluon/language/amd/__init__.py +4 -0
  25. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  26. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  27. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  28. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  29. triton/experimental/gluon/language/extra/__init__.py +3 -0
  30. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  31. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  32. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  33. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
  34. triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
  35. triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
  36. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
  37. triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
  38. triton/experimental/gluon/nvidia/hopper.py +6 -1
  39. triton/knobs.py +132 -67
  40. triton/language/__init__.py +16 -10
  41. triton/language/core.py +163 -83
  42. triton/language/extra/cuda/gdc.py +6 -6
  43. triton/language/extra/hip/__init__.py +3 -1
  44. triton/language/extra/hip/libdevice.py +7 -0
  45. triton/language/extra/hip/utils.py +35 -0
  46. triton/language/extra/libdevice.py +4 -0
  47. triton/language/semantic.py +76 -23
  48. triton/language/standard.py +14 -14
  49. triton/language/target_info.py +54 -0
  50. triton/runtime/_allocation.py +15 -3
  51. triton/runtime/_async_compile.py +55 -0
  52. triton/runtime/autotuner.py +4 -5
  53. triton/runtime/build.py +11 -9
  54. triton/runtime/cache.py +44 -1
  55. triton/runtime/driver.py +16 -41
  56. triton/runtime/interpreter.py +31 -23
  57. triton/runtime/jit.py +318 -157
  58. triton/runtime/tcc/include/_mingw.h +8 -10
  59. triton/runtime/tcc/include/assert.h +5 -0
  60. triton/runtime/tcc/include/errno.h +1 -1
  61. triton/runtime/tcc/include/float.h +21 -3
  62. triton/runtime/tcc/include/iso646.h +36 -0
  63. triton/runtime/tcc/include/limits.h +5 -0
  64. triton/runtime/tcc/include/malloc.h +2 -2
  65. triton/runtime/tcc/include/math.h +21 -261
  66. triton/runtime/tcc/include/stdalign.h +16 -0
  67. triton/runtime/tcc/include/stdarg.h +5 -70
  68. triton/runtime/tcc/include/stdatomic.h +171 -0
  69. triton/runtime/tcc/include/stddef.h +7 -19
  70. triton/runtime/tcc/include/stdlib.h +15 -4
  71. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  72. triton/runtime/tcc/include/sys/stat.h +2 -2
  73. triton/runtime/tcc/include/sys/types.h +5 -0
  74. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  75. triton/runtime/tcc/include/tccdefs.h +342 -0
  76. triton/runtime/tcc/include/tgmath.h +89 -0
  77. triton/runtime/tcc/include/uchar.h +33 -0
  78. triton/runtime/tcc/include/unistd.h +1 -0
  79. triton/runtime/tcc/include/winapi/qos.h +72 -0
  80. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  81. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  82. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  83. triton/runtime/tcc/include/winapi/windows.h +1 -1
  84. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  85. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  86. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  87. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  88. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  89. triton/runtime/tcc/lib/libtcc1.a +0 -0
  90. triton/runtime/tcc/lib/python314.def +1800 -0
  91. triton/runtime/tcc/lib/python314t.def +1809 -0
  92. triton/runtime/tcc/libtcc.dll +0 -0
  93. triton/runtime/tcc/tcc.exe +0 -0
  94. triton/tools/compile.py +62 -14
  95. triton/tools/extra/cuda/compile.c +1 -0
  96. triton/tools/extra/hip/compile.cpp +66 -0
  97. triton/tools/extra/hip/compile.h +13 -0
  98. triton/tools/ragged_tma.py +92 -0
  99. triton/tools/tensor_descriptor.py +7 -9
  100. triton/windows_utils.py +42 -79
  101. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
  102. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
  103. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  104. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
  105. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
  106. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
  107. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0

triton/experimental/gluon/language/amd/cdna4/async_copy.py (new file)
@@ -0,0 +1,151 @@
+ from ..._core import ir, builtin, _unwrap_if_constexpr
+ from ..._semantic import _check
+ from ..._layouts import BlockedLayout, SliceLayout
+ from ..cdna3 import _verify_buffer_ops
+
+ __all__ = [
+     "global_load_to_shared",
+     "buffer_load_to_shared",
+     "async_wait",
+     "load_shared_relaxed",
+ ]
+
+
+ @builtin
+ def global_load_to_shared(dest, ptr, mask=None, other=None, cache_modifier="", _semantic=None):
+     """
+     AMD global load to shared operation. This operation loads data directly
+     from global memory to shared memory without going through registers. It
+     happens asynchronously and requires a subsequent `async_wait` to ensure the
+     data is available in shared memory.
+     Compared to `buffer_load_to_shared`, it requires a tensor pointer which
+     supports 64-bit indexing range for each thread in a block, which gives more
+     flexibility, but at the cost of higher register pressure and no hardware
+     out-of-bound masking support. Prefer to use `buffer_load_to_shared` when
+     possible for better performance.
+
+     The underlying hardware instruction uses separate registers for global
+     memory address for each thread but the same register for local memory
+     address for the whole warp. Therefore, while using this operation
+     the following conditions must be met or lowering to LLVM will fail:
+
+     - For the `ptr` layout, size per thread * bits per element must be 128 or 32.
+       To get ideal performance, it is recommended to use 128 bits per element.
+     - Writes to `dest` must be coalesced.
+     - If `dest` is swizzled, it only can be swizzled within warp boundary.
+
+     Args:
+         dest (shared_memory_descriptor): Destination shared memory descriptor.
+         ptr (pointer tensor): Tensor of pointers to global memory to load from.
+         mask (tensor, optional): Mask tensor for predicated loads. Defaults to None.
+         other (tensor, optional): Tensor providing default values for masked elements. Defaults to None.
+         cache_modifier (str): Cache modifier specifier. Defaults to "".
+     """
+     _check(ptr.type.is_block(), lambda: "expected ptr to be a tensor")
+     _check(isinstance(ptr.type.layout, (BlockedLayout, SliceLayout)),
+            lambda: "expected ptr type layout to be BlockedLayout or SliceLayout")
+     _check(
+         dest.shape == ptr.shape, lambda:
+         f"expected dest shape to match pointer shape but got dest.shape = {dest.shape}, pointer.shape = {ptr.shape}")
+
+     mask = _unwrap_if_constexpr(mask)
+     if mask is not None:
+         ptr, mask = _semantic.broadcast_impl_value(ptr, mask)
+     other = _unwrap_if_constexpr(other)
+     if other is not None:
+         ptr, other = _semantic.broadcast_impl_value(ptr, other)
+
+     cache_modifier = _semantic._str_to_load_cache_modifier(cache_modifier)
+     mask_handle = mask.handle if mask is not None else ir.value()
+     other_handle = other.handle if other is not None else ir.value()
+     _semantic.builder.create_async_copy_global_to_local(dest.handle, ptr.handle, mask_handle, other_handle,
+                                                         cache_modifier, ir.EVICTION_POLICY.NORMAL, False)
+
+
+ @builtin
+ def buffer_load_to_shared(dest, ptr, offsets, mask=None, other=None, cache_modifier="", _semantic=None):
+     """
+     AMD buffer load to shared operation. Buffer load is similar to global load
+     but it accesses global memory via a scalar base pointer and a tensor of
+     32-bit offsets instead of a tensor of pointers. This operation loads data
+     directly from global memory to shared memory without going through
+     registers. It happens asynchronously and requires a subsequent `async_wait`
+     to ensure the data is available in shared memory.
+     Compared to `global_load_to_shared`, it has better performance and also
+     supports hardware out-of-bound masking. But it strictly requires a
+     32-bit offset instead of a 64-bit tensor pointer.
+
+     The underlying hardware instruction uses separate registers for global
+     memory address for each thread but the same register for local memory
+     address for the whole warp. Therefore, while using this operation
+     the following conditions must be met or lowering to LLVM will fail:
+
+     - For the `offsets` layout, size per thread * bits per element must be 128 or 32.
+       To get ideal performance, it is recommended to use 128 bits per element.
+     - Writes to `dest` must be coalesced.
+     - If `dest` is swizzled, it only can be swizzled within warp boundary.
+
+     Args:
+         dest (shared_memory_descriptor): Destination shared memory descriptor.
+         ptr (pointer to scalar): Global memory scalar base pointer to load from.
+         offsets (tensor): Offsets tensor for the load operation.
+         mask (tensor, optional): Mask tensor for predicated loads. Defaults to None.
+         other (tensor, optional): Tensor providing default values for masked elements. Defaults to None.
+         cache_modifier (str): Cache modifier specifier. Defaults to "".
+     """
+     _check(isinstance(offsets.type.layout, (BlockedLayout, SliceLayout)),
+            lambda: "expected offsets type layout to be BlockedLayout or SliceLayout")
+     _verify_buffer_ops(ptr, offsets, mask, other)
+
+     mask = _unwrap_if_constexpr(mask)
+     if mask is not None:
+         offsets, mask = _semantic.broadcast_impl_value(offsets, mask)
+     other = _unwrap_if_constexpr(other)
+     if other is not None:
+         offsets, other = _semantic.broadcast_impl_value(offsets, other)
+
+     mask = mask.handle if mask is not None else ir.value()
+     other = other.handle if other is not None else ir.value()
+     stride = ir.value()
+     cache_modifier = _semantic._str_to_load_cache_modifier(cache_modifier)
+
+     _semantic.builder.create_buffer_load_to_local(dest.handle, ptr.handle, offsets.handle, mask, other, stride,
+                                                   cache_modifier)
+
+
+ @builtin
+ def async_wait(num_outstanding=0, _semantic=None):
+     """
+     Wait for outstanding memory operations, this includes normal load like
+     `load` and `buffer_load`, as well as direct load to shared memory
+     like `global_load_to_shared` and `buffer_load_to_shared`.
+     It will block until the number of outstanding memory operations is less than
+     or equal to `num_outstanding`.
+
+     Args:
+         num_outstanding (int): The number of outstanding operations to wait for. Defaults to 0.
+     """
+     num_outstanding = _unwrap_if_constexpr(num_outstanding)
+     _semantic.builder.create_async_wait_group(num_outstanding)
+
+
+ @builtin
+ def load_shared_relaxed(smem, layout, _semantic=None):
+     """
+     Load a tensor from shared memory with extra hints for the underlying
+     compiler to avoid emitting unnecessary waits before loading from the target
+     shared memory.
+
+     Args:
+         smem (shared_memory_descriptor): Shared memory descriptor to load from.
+         layout (DistributedLayout): The destination layout of the tensor.
+
+     Returns:
+         tensor: A Gluon tensor containing the loaded data.
+     """
+     SYNCED_VIA_WAIT_ATTR_NAME = "ttg.amdgpu.syncedViaAsyncWait"
+
+     layout = _unwrap_if_constexpr(layout)
+     ret = _semantic.shared_load(smem, layout)
+     ret.handle.set_attr(SYNCED_VIA_WAIT_ATTR_NAME, _semantic.builder.get_bool_attr(True))
+     return ret
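
The new AMD CDNA4 module above pairs both load variants with `async_wait`. The following is a minimal, hypothetical usage sketch; everything outside this diff (the `gluon.jit` decorator, the `ttgl` helpers, the float32 dtype, and the layout arguments) is an assumption for illustration, not part of the release:

    # Hypothetical sketch: only the async_copy calls below come from this diff.
    from triton.experimental import gluon
    from triton.experimental.gluon import language as ttgl
    from triton.experimental.gluon.language.amd.cdna4 import async_copy

    @gluon.jit
    def _copy_tile(base_ptr, BLOCK: ttgl.constexpr, blocked: ttgl.constexpr, shared: ttgl.constexpr):
        offs = ttgl.arange(0, BLOCK, layout=blocked)                        # 32-bit offsets
        smem = ttgl.allocate_shared_memory(ttgl.float32, [BLOCK], shared)   # destination tile
        # Scalar base pointer plus 32-bit offsets: the preferred, buffer-based variant.
        async_copy.buffer_load_to_shared(smem, base_ptr, offs)
        # Block until no direct-to-shared loads are still outstanding.
        async_copy.async_wait(0)
        # The relaxed load tells the compiler the async_wait above already synchronized smem.
        tile = async_copy.load_shared_relaxed(smem, blocked)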

triton/experimental/gluon/language/extra/__init__.py (new file)
@@ -0,0 +1,3 @@
+ from triton.language.extra import libdevice
+
+ __all__ = ["libdevice"]

triton/experimental/gluon/language/nvidia/ampere/__init__.py (new file)
@@ -0,0 +1,3 @@
+ from . import async_copy, mbarrier
+
+ __all__ = ["async_copy", "mbarrier"]
@@ -0,0 +1,74 @@
1
+ from ..._semantic import _check
2
+ from ..._core import _unwrap_if_constexpr, builtin
3
+ from triton._C.libtriton import ir
4
+
5
+ __all__ = [
6
+ "async_copy_global_to_shared",
7
+ "mbarrier_arrive",
8
+ "commit_group",
9
+ "wait_group",
10
+ ]
11
+
12
+
13
+ @builtin
14
+ def async_copy_global_to_shared(smem, pointer, mask=None, cache_modifier="", eviction_policy="", volatile=False,
15
+ _semantic=None):
16
+ """
17
+ Asynchronously copy elements from global memory to shared memory.
18
+
19
+ Args:
20
+ smem (shared_memory_descriptor): Destination shared memory descriptor.
21
+ pointer (tensor): Source pointer tensor.
22
+ mask (tensor, optional): Mask tensor for predicated loads. Defaults to None.
23
+ cache_modifier (str): Cache modifier specifier. Defaults to "".
24
+ eviction_policy (str): Eviction policy specifier. Defaults to "".
25
+ volatile (bool): Whether the load is volatile. Defaults to False.
26
+ """
27
+ mask = _unwrap_if_constexpr(mask)
28
+ cache_modifier = _semantic._str_to_load_cache_modifier(cache_modifier)
29
+ eviction_policy = _semantic._str_to_eviction_policy(eviction_policy)
30
+ volatile = _unwrap_if_constexpr(volatile)
31
+ if mask is not None:
32
+ pointer, mask = _semantic.broadcast_impl_value(pointer, mask)
33
+ _check(
34
+ smem.shape == pointer.shape, lambda:
35
+ f"expected smem shape to match pointer shape but got smem.shape = {smem.shape}, pointer.shape = {pointer.shape}"
36
+ )
37
+ mask_handle = mask.handle if mask is not None else ir.value()
38
+ _semantic.builder.create_async_copy_global_to_local(smem.handle, pointer.handle, mask_handle, ir.value(),
39
+ cache_modifier, eviction_policy, volatile)
40
+
41
+
42
+ @builtin
43
+ def mbarrier_arrive(mbarrier, increment_count=True, _semantic=None):
44
+ """
45
+ Arrive on the mbarrier once all outstanding async copies are complete.
46
+
47
+ Args:
48
+ mbarrier (shared_memory_descriptor): Barrier object to arrive on.
49
+ increment_count (bool): Whether to increment the arrival count. Defaults to True.
50
+ """
51
+ increment_count = _unwrap_if_constexpr(increment_count)
52
+ _semantic.builder.create_async_copy_mbarrier_arrive(mbarrier.handle, increment_count)
53
+
54
+
55
+ @builtin
56
+ def commit_group(_semantic=None):
57
+ """
58
+ Commit the current asynchronous copy group.
59
+
60
+ This finalizes a set of asynchronous copy operations.
61
+ """
62
+ _semantic.builder.create_async_commit_group()
63
+
64
+
65
+ @builtin
66
+ def wait_group(num_outstanding=0, _semantic=None):
67
+ """
68
+ Wait for outstanding asynchronous copy group operations.
69
+
70
+ Args:
71
+ num_outstanding (int): Wait until `num_outstanding` or less async copy groups in-flight. Defaults to 0.
72
+ """
73
+ num_outstanding = _unwrap_if_constexpr(num_outstanding)
74
+ _semantic.builder.create_async_wait_group(num_outstanding)
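
The intended sequencing of the four Ampere entry points above is issue, commit, wait. A hedged sketch follows; the kernel scaffolding and `ttgl` helpers are assumptions, only the `async_copy` calls come from this diff:

    # Hypothetical sketch of a cp.async-style prefetch into shared memory.
    from triton.experimental import gluon
    from triton.experimental.gluon import language as ttgl
    from triton.experimental.gluon.language.nvidia.ampere import async_copy

    @gluon.jit
    def _prefetch_tile(ptr, BLOCK: ttgl.constexpr, blocked: ttgl.constexpr, shared: ttgl.constexpr):
        offs = ttgl.arange(0, BLOCK, layout=blocked)
        smem = ttgl.allocate_shared_memory(ttgl.float32, [BLOCK], shared)
        # Issue the asynchronous copy, close the current group, then drain it.
        async_copy.async_copy_global_to_shared(smem, ptr + offs)
        async_copy.commit_group()
        async_copy.wait_group(0)    # proceed once no copy groups remain in flight
        tile = smem.load(blocked)   # shared_memory_descriptor.load is assumed here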

triton/experimental/gluon/language/nvidia/ampere/mbarrier.py (new file)
@@ -0,0 +1,80 @@
+ from triton.experimental.gluon.language._layouts import SwizzledSharedLayout
+ from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
+
+ __all__ = ["arrive", "init", "invalidate", "MBarrierLayout", "wait"]
+
+
+ class MBarrierLayout(SwizzledSharedLayout):
+     """
+     Layout for mbarrier synchronization in Ampere and later architectures.
+
+     Args:
+         ctas_per_cga (int): CTAs per CGA grouping. Defaults to 1.
+         cta_split_num (int): CTA split factor. Defaults to 1.
+     """
+
+     def __init__(self, ctas_per_cga: int = 1, cta_split_num: int = 1):
+         super().__init__(
+             vec=1,
+             per_phase=1,
+             max_phase=1,
+             order=[0],
+             ctas_per_cga=[ctas_per_cga],
+             cta_split_num=[cta_split_num],
+             cta_order=[0],
+         )
+
+
+ @builtin
+ def init(mbarrier, count, _semantic=None):
+     """
+     Initialize an mbarrier with a specified count.
+
+     Args:
+         mbarrier (shared_memory_descriptor): The barrier object to initialize.
+         count (int): The initial count for the barrier.
+     """
+     count = _unwrap_if_constexpr(count)
+     _semantic.builder.create_mbarrier_init(mbarrier.handle, count)
+
+
+ @builtin
+ def invalidate(mbarrier, _semantic=None):
+     """
+     Invalidate an mbarrier, resetting its state.
+
+     Args:
+         mbarrier (shared_memory_descriptor): The barrier object to invalidate.
+     """
+     _semantic.builder.create_mbarrier_inval(mbarrier.handle)
+
+
+ @builtin
+ def wait(mbarrier, phase, pred=True, deps=(), _semantic=None):
+     """
+     Wait until the mbarrier object completes its current phase.
+
+     Args:
+         mbarrier (shared_memory_descriptor): The barrier object to wait on.
+         phase (int): The phase index to wait for.
+         pred (bool): Predicate. Operation is skipped if predicate is False. Defaults to True.
+         deps (Sequence[shared_memory_descriptor]): Dependent allocations barrier is waiting on. Used to track liveness of dependent allocations. Defaults to ().
+     """
+     phase = _semantic.to_tensor(phase)
+     pred = _semantic.to_tensor(pred)
+     deps = [x.handle for x in deps]
+     _semantic.builder.create_mbarrier_wait(mbarrier.handle, phase.handle, pred.handle, deps)
+
+
+ @builtin
+ def arrive(mbarrier, *, pred=True, _semantic=None):
+     """
+     Arrive on an mbarrier, signaling that a thread has reached the barrier.
+
+     Args:
+         mbarrier (shared_memory_descriptor): The barrier object to arrive on.
+         pred (bool): Predicate. Operation is skipped if predicate is False. Defaults to True.
+     """
+     count = 1
+     pred = _semantic.to_tensor(pred)
+     _semantic.builder.create_mbarrier_arrive(mbarrier.handle, count, pred.handle)
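
A minimal, hypothetical lifecycle for an mbarrier built on the helpers above; the allocation of the barrier itself (an int64 word in shared memory using `MBarrierLayout`) follows common Gluon usage and is an assumption, not something this diff prescribes:

    # Hypothetical sketch of init -> arrive -> wait -> invalidate.
    from triton.experimental import gluon
    from triton.experimental.gluon import language as ttgl
    from triton.experimental.gluon.language.nvidia.ampere import mbarrier

    @gluon.jit
    def _barrier_demo():
        bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
        mbarrier.init(bar, count=1)      # expect one arrival per phase
        mbarrier.arrive(bar)             # signal this thread's arrival
        mbarrier.wait(bar, phase=0)      # block until phase 0 completes
        mbarrier.invalidate(bar)         # reset before the allocation is reused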

triton/experimental/gluon/language/nvidia/blackwell/__init__.py
@@ -2,21 +2,26 @@ from __future__ import annotations
  from typing import Optional, Tuple, List, TYPE_CHECKING

  from dataclasses import dataclass
+ from triton.runtime.jit import constexpr_function
  from triton.experimental.gluon.language import _core as ttgl
  from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr
+ from triton.experimental.gluon.language._layouts import BlockedLayout, _get_shape_per_cta
  from triton.experimental.gluon.language._semantic import _check

  from . import tma
- from ..hopper import mbarrier, fence_async_shared
+ from ..hopper import fence_async_shared, mbarrier
+ from ..ampere import async_copy

+ from triton._C.libtriton import ir
  if TYPE_CHECKING:
      from triton._C.libtriton.gluon_ir import GluonOpBuilder
-     from triton._C.libtriton import gluon_ir as ir
      from ..._semantic import GluonSemantic

  __all__ = [
      "allocate_tensor_memory",
+     "async_copy",
      "fence_async_shared",
+     "get_tmem_32x32b_reg_layout",
      "mbarrier",
      "tensor_memory_descriptor",
      "TensorMemoryLayout",
@@ -26,6 +31,14 @@ __all__ = [

  @dataclass(frozen=True, eq=True)
  class TensorMemoryLayout:
+     """
+     Describes the layout for tensor memory in Blackwell architecture.
+
+     Args:
+         block (Tuple[int, int]): Tiling block dimensions (M/rows, N/cols).
+         unpacked (bool): For sub-32 bit elements, whether they are unpacked to 32 bits.
+         cta_split_num (Optional[Tuple[int, int]]): CTA split factors. Defaults to None.
+     """
      block: Tuple[int, int]
      unpacked: bool
      cta_split_num: Optional[Tuple[int, int]] = None
@@ -49,6 +62,74 @@ class TensorMemoryLayout:
          return f"TL{block_str}{unpacked_str}{cta_split_str}TL"


+ @dataclass(frozen=True, eq=True)
+ class TensorMemoryScalesLayout:
+     """
+     Describes the layout for tensor memory scales in Blackwell architecture.
+
+     Args:
+         cta_split_num (Optional[Tuple[int, int]]): CTA split factors. Defaults to None.
+     """
+     cta_split_num: Optional[Tuple[int, int]] = None
+
+     def __post_init__(self):
+         assert self.cta_split_num is None or len(self.cta_split_num) == 2
+
+     def _to_ir(self, builder):
+         cta_split_num = self.cta_split_num or [1, 1]
+         return builder.get_tensor_memory_scales_layout(cta_split_num, )
+
+     def mangle(self) -> str:
+         cta_split_str = f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else ""
+         return f"TLS{cta_split_str}TLS"
+
+
+ @constexpr_function
+ def _cdiv(x, div):
+     return (x + div - 1) // div
+
+
+ @constexpr_function
+ def get_tmem_32x32b_reg_layout(M, N, shape, num_warps, ctas_per_cga=None, cta_split_num=None, cta_order=None):
+     """Returns a BlockedLayout compatible with load/store on tensor memory with the 32x32b instruction variant.
+     """
+     assert len(shape) == 2, "expected a 2D tensor"
+     assert num_warps in [4, 8], "expected 4 or 8 warps"
+
+     shape_per_cta = _get_shape_per_cta(shape, cta_split_num)
+     blocks_per_tile = [shape_per_cta[0] // M, shape_per_cta[1] // N]
+     num_blocks = blocks_per_tile[0] * blocks_per_tile[1]
+
+     num_warp_groups = num_warps // 4
+     if M == 64:
+         threads_per_warp = [16, 2]
+         if num_blocks == 1:
+             size_per_thread = [1, _cdiv(N, num_warp_groups * 2)]
+             warps_per_cta = [4, num_warp_groups]
+         else:
+             size_per_thread = [1, _cdiv(N, 2)]
+             warps_per_cta = [4 * min(blocks_per_tile[0], num_warp_groups)]
+             warps_per_cta.append(_cdiv(num_warp_groups, warps_per_cta[0] // 4))
+     else:
+         if shape[0] > 128:
+             size_per_thread = [1, N]
+             threads_per_warp = [32, 1]
+             warps_per_cta = [4 * num_warp_groups, 1]
+         else:
+             size_per_thread = [1, _cdiv(N, num_warp_groups)]
+             threads_per_warp = [32, 1]
+             warps_per_cta = [4, num_warp_groups]
+     return BlockedLayout(
+         size_per_thread=size_per_thread,
+         threads_per_warp=threads_per_warp,
+         warps_per_cta=warps_per_cta,
+         order=[0, 1],
+         ctas_per_cga=ctas_per_cga,
+         cta_split_num=cta_split_num,
+         cta_order=cta_order,
+     )
+
+
  class tensor_memory_descriptor_type(base_type):

      def __init__(self, element_ty, shape, layout, alloc_shape):
@@ -56,7 +137,7 @@ class tensor_memory_descriptor_type(base_type):
          self.shape = shape
          self.layout = layout
          self.alloc_shape = alloc_shape
-         assert isinstance(layout, TensorMemoryLayout)
+         assert isinstance(layout, TensorMemoryLayout) or isinstance(layout, TensorMemoryScalesLayout)

      def to_ir(self, builder: GluonOpBuilder) -> None:
          return builder.get_tensor_mem_desc_ty(
@@ -89,6 +170,9 @@ class tensor_memory_descriptor_type(base_type):


  class tensor_memory_descriptor(base_value):
+     """
+     Represents a tensor memory descriptor handle for Tensor Core Gen5 operations.
+     """

      def __init__(self, handle, element_ty, shape, layout, alloc_shape):
          self.handle = handle
@@ -118,6 +202,15 @@

      @builtin
      def load(self, layout, _semantic: GluonSemantic) -> ttgl.tensor:
+         """
+         Load a tensor from tensor memory.
+
+         Args:
+             layout (DistributedLayout): Destination layout of the tensor.
+
+         Returns:
+             tensor: A distributed tensor containing the loaded data.
+         """
          layout = _unwrap_if_constexpr(layout)
          ret_ty = ttgl.distributed_type(self.dtype, self.shape, layout)
          builder = _semantic.builder
@@ -126,12 +219,31 @@

      @builtin
      def store(self, value, pred=True, _semantic: GluonSemantic = None) -> None:
+         """
+         Store a tensor into tensor memory.
+
+         Args:
+             value (tensor): The tensor to store.
+             pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+         """
          pred = _unwrap_if_constexpr(pred)
          pred = _semantic.to_tensor(pred)
+         assert value.shape == self.shape, f"source shape {value.shape} does not match destination shape {self.shape}"
+         assert value.dtype == self.dtype, f"source dtype {value.dtype} does not match destination dtype {self.dtype}"
          _semantic.builder.create_tmem_store(self.handle, value.handle, pred.handle)

      @builtin
      def slice(self, start, length, _semantic: GluonSemantic) -> None:
+         """
+         Create a slice of the tensor memory descriptor along the last dimension.
+
+         Args:
+             start (int): The starting index for subslice.
+             length (int): The length of the subslice.
+
+         Returns:
+             tensor_memory_descriptor: Descriptor for the subslice.
+         """
          start = _unwrap_if_constexpr(start)
          length = _unwrap_if_constexpr(length)
          _check(isinstance(start, int), lambda: "start must be a constant int")
@@ -147,18 +259,36 @@

      @builtin
      def index(self, index, _semantic: GluonSemantic = None) -> tensor_memory_descriptor:
+         """
+         Create a subview of tensor memory by indexing the first dimension.
+
+         Args:
+             index (tensor): The index tensor for the subview.
+
+         Returns:
+             tensor_memory_descriptor: Descriptor for the indexed subview.
+         """
          index = _semantic.to_tensor(index)
          builder = _semantic.builder
-         offsets = [builder.get_int32(0)] * self.rank
-         offsets[0] = index.handle
          shape = self.shape[1:]
          layout = self.layout
          ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape)
-         ret.handle = builder.create_memdesc_subview(ret.type.to_ir(builder), self.handle, offsets)
+         ret.handle = builder.create_memdesc_index(ret.type.to_ir(builder), self.handle, index.handle)
          return ret

      @builtin
      def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> tensor_memory_descriptor:
+         """
+         Reinterpret tensor memory descriptor with a new dtype, shape, and layout.
+
+         Args:
+             dtype (dtype): The new data type.
+             shape (Sequence[int]): The new shape.
+             layout (TensorMemoryLayout): The new layout.
+
+         Returns:
+             tensor_memory_descriptor: Descriptor with updated type and layout.
+         """
          dtype = _unwrap_if_constexpr(dtype)
          shape = [_unwrap_if_constexpr(s) for s in shape]
          layout = _unwrap_if_constexpr(layout)
@@ -170,6 +300,18 @@

  @builtin
  def allocate_tensor_memory(element_ty, shape, layout, value=None, _semantic=None):
+     """
+     Allocate tensor memory.
+
+     Args:
+         element_ty (dtype): The element data type.
+         shape (Sequence[int]): The descriptor shape.
+         layout (TensorMemoryLayout): The layout of the tensor memory.
+         value (tensor, optional): Initial tensor to copy. Defaults to None.
+
+     Returns:
+         tensor_memory_descriptor: Descriptor for the allocated memory.
+     """
      element_ty = _unwrap_if_constexpr(element_ty)
      shape = _unwrap_if_constexpr(shape)
      layout = _unwrap_if_constexpr(layout)
@@ -181,8 +323,38 @@ def allocate_tensor_memory(element_ty, shape, layout, value=None, _semantic=None
      return tensor_memory_descriptor(handle, element_ty, shape, layout, shape)


+ @builtin
+ def tcgen05_copy(src, dst, _semantic=None):
+     """
+     Start an asynchronous copy from shared memory to tensor memory.
+
+     WARNING: The current semantics of the instruction are not well defined and
+     the API will change in the future. Use at your own risk.
+
+     Args:
+         src (shared_memory_descriptor): Shared memory to copy from.
+         dst (tensor_memory_descriptor): Tensor memory to copy to.
+     """
+     assert isinstance(src, ttgl.shared_memory_descriptor), "source must be a shared memory descriptor"
+     assert isinstance(dst, tensor_memory_descriptor), "destination must be a tensor memory descriptor"
+     _semantic.builder.create_tmem_copy(src.handle, dst.handle)
+
+
  @builtin
  def tcgen05_mma(a, b, acc, *, use_acc=True, pred=True, mbarriers=None, mbarrier_preds=None, _semantic=None):
+     """
+     Emit a 5th generation TensorCore MMA instruction.
+     acc = a * b + (acc if use_acc else 0)
+
+     Args:
+         a (shared_memory_descriptor): Left hand side operand in shared memory.
+         b (shared_memory_descriptor or tensor_memory_descriptor): Right hand side operand in shared or tensor memory.
+         acc (tensor_memory_descriptor): Accumulator value in tensor memory (mutated).
+         use_acc (bool): Whether to use the initial value of the accumulator. Defaults to True.
+         pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+         mbarriers (Sequence[shared_memory_descriptor], optional): Barriers to signal when the operation is complete. If None, mma is synchronous. Defaults to None.
+         mbarrier_preds (Sequence[bool], optional): Predicates for barriers. Defaults to None.
+     """
      use_acc = _semantic.to_tensor(use_acc)
      pred = _semantic.to_tensor(pred)

@@ -194,9 +366,22 @@ def tcgen05_mma(a, b, acc, *, use_acc=True, pred=True, mbarriers=None, mbarrier_
          mbarriers = [bar.handle for bar in mbarriers]
          if mbarrier_preds is None:
              true = _semantic.to_tensor(True)
-             mbarrier_preds = [true] * len(mbarriers)
+             mbarrier_preds = [true.handle] * len(mbarriers)
          else:
              mbarrier_preds = _semantic._convert_to_ir_values(mbarrier_preds, require_i64=False)

      _semantic.builder.create_tcgen05_mma(a.handle, b.handle, acc.handle, use_acc.handle, pred.handle, mbarriers,
                                           mbarrier_preds)
+
+
+ @builtin
+ def tcgen05_commit(barrier, _semantic=None):
+     """
+     This instruction causes the provided mbarrier to be arrived-on with a count
+     of 1 when all async tcgen05 MMA and copy instructions previously issued by
+     the thread are complete.
+
+     Args:
+         barrier (shared_memory_descriptor): The barrier to track completion of tcgen05 MMA and copy instructions.
+     """
+     _semantic.builder.create_tcgen05_commit(barrier.handle)
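
Taken together, the tensor-memory additions above compose roughly as follows. This is a sketch under Blackwell assumptions only: the shared-memory operands, the float32/unpacked/size choices, and the barrier allocation are illustrative and not prescribed by this diff:

    # Hypothetical sketch: allocate tensor memory, run an async MMA, and track completion.
    from triton.experimental import gluon
    from triton.experimental.gluon import language as ttgl
    from triton.experimental.gluon.language.nvidia import blackwell
    from triton.experimental.gluon.language.nvidia.blackwell import mbarrier

    @gluon.jit
    def _mma_fragment(a_smem, b_smem):
        # 128x128 float32 accumulator tile in tensor memory (sizes are illustrative).
        acc_layout = blackwell.TensorMemoryLayout(block=(128, 128), unpacked=True)
        acc = blackwell.allocate_tensor_memory(ttgl.float32, [128, 128], acc_layout)
        bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
        mbarrier.init(bar, count=1)
        # Asynchronous MMA: completion is reported through the mbarrier rather than blocking here.
        blackwell.tcgen05_mma(a_smem, b_smem, acc, use_acc=False, mbarriers=[bar])
        blackwell.tcgen05_commit(bar)   # bar arrives once all prior tcgen05 ops complete
        mbarrier.wait(bar, phase=0)
        reg_layout = blackwell.get_tmem_32x32b_reg_layout(128, 128, [128, 128], num_warps=4)
        out = acc.load(reg_layout)      # bring the accumulator back into registers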

triton/experimental/gluon/language/nvidia/blackwell/tma.py
@@ -20,6 +20,17 @@ __all__ = [

  @builtin
  def async_gather(tensor_desc, x_offsets, y_offset, barrier, result, pred=True, _semantic=None):
+     """
+     Asynchronously gather elements from global memory to shared memory using TMA.
+
+     Args:
+         tensor_desc (tensor_descriptor): The tensor descriptor.
+         x_offsets (tensor): 1D tensor of X offsets.
+         y_offset (int): Scalar Y offset.
+         barrier (shared_memory_descriptor): Barrier that will be signaled when the operation is complete.
+         result (tensor_memory_descriptor): Result shared memory, must have NVMMASharedLayout.
+         pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+     """
      pred = _semantic.to_tensor(pred)
      y_offset = _semantic.to_tensor(y_offset)
      _semantic.builder.create_async_tma_gather(tensor_desc.handle, x_offsets.handle, y_offset.handle, barrier.handle,
@@ -28,5 +39,14 @@ def async_gather(tensor_desc, x_offsets, y_offset, barrier, result, pred=True, _

  @builtin
  def async_scatter(tensor_desc, x_offsets, y_offset, src, _semantic=None):
+     """
+     Asynchronously scatter elements from shared memory to global memory using TMA.
+
+     Args:
+         tensor_desc (tensor_descriptor): The tensor descriptor.
+         x_offsets (tensor): 1D tensor of X offsets.
+         y_offset (int): Scalar Y offset.
+         src (tensor_memory_descriptor): The source data, must be in NVMMASharedLayout.
+     """
      y_offset = _semantic.to_tensor(y_offset)
      _semantic.builder.create_async_tma_scatter(tensor_desc.handle, x_offsets.handle, y_offset.handle, src.handle)
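
For the two TMA entry points above, a hedged gather-then-scatter sketch; the tensor descriptor, the NVMMASharedLayout result buffer, and the 1D `x_offsets` tensor are assumed to be prepared elsewhere and are not part of this diff:

    # Hypothetical sketch: gather selected rows via TMA, wait on the barrier, write them back.
    from triton.experimental import gluon
    from triton.experimental.gluon import language as ttgl
    from triton.experimental.gluon.language.nvidia.blackwell import tma, mbarrier

    @gluon.jit
    def _gather_then_scatter(desc, x_offsets, smem):
        bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
        mbarrier.init(bar, count=1)
        # Gather the rows selected by x_offsets, starting at column 0, into smem.
        tma.async_gather(desc, x_offsets, 0, bar, smem)
        mbarrier.wait(bar, phase=0)     # the data is visible in smem only after the wait
        # Write the same rows back out through the descriptor.
        tma.async_scatter(desc, x_offsets, 0, smem)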