triton-windows 3.3.1.post19__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/experimental/gluon/language/nvidia/blackwell/__init__.py
@@ -0,0 +1,387 @@
+from __future__ import annotations
+from typing import Optional, Tuple, List, TYPE_CHECKING
+
+from dataclasses import dataclass
+from triton.runtime.jit import constexpr_function
+from triton.experimental.gluon.language import _core as ttgl
+from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr
+from triton.experimental.gluon.language._layouts import BlockedLayout, _get_shape_per_cta
+from triton.experimental.gluon.language._semantic import _check
+
+from . import tma
+from ..hopper import fence_async_shared, mbarrier
+from ..ampere import async_copy
+
+from triton._C.libtriton import ir
+if TYPE_CHECKING:
+    from triton._C.libtriton.gluon_ir import GluonOpBuilder
+    from ..._semantic import GluonSemantic
+
+__all__ = [
+    "allocate_tensor_memory",
+    "async_copy",
+    "fence_async_shared",
+    "get_tmem_32x32b_reg_layout",
+    "mbarrier",
+    "tensor_memory_descriptor",
+    "TensorMemoryLayout",
+    "tma",
+]
+
+
+@dataclass(frozen=True, eq=True)
+class TensorMemoryLayout:
+    """
+    Describes the layout for tensor memory in Blackwell architecture.
+
+    Args:
+        block (Tuple[int, int]): Tiling block dimensions (M/rows, N/cols).
+        unpacked (bool): For sub-32 bit elements, whether they are unpacked to 32 bits.
+        cta_split_num (Optional[Tuple[int, int]]): CTA split factors. Defaults to None.
+    """
+    block: Tuple[int, int]
+    unpacked: bool
+    cta_split_num: Optional[Tuple[int, int]] = None
+
+    def __post_init__(self):
+        assert len(self.block) == 2
+        assert self.cta_split_num is None or len(self.cta_split_num) == 2
+
+    def _to_ir(self, builder):
+        cta_split_num = self.cta_split_num or [1, 1]
+        return builder.get_tensor_memory_layout(
+            self.block,
+            self.unpacked,
+            cta_split_num,
+        )
+
+    def mangle(self) -> str:
+        block_str = f"{self.block[0]}x{self.block[1]}"
+        unpacked_str = "U" if self.unpacked else "P"
+        cta_split_str = f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else ""
+        return f"TL{block_str}{unpacked_str}{cta_split_str}TL"
+
+
+@dataclass(frozen=True, eq=True)
+class TensorMemoryScalesLayout:
+    """
+    Describes the layout for tensor memory scales in Blackwell architecture.
+
+    Args:
+        cta_split_num (Optional[Tuple[int, int]]): CTA split factors. Defaults to None.
+    """
+    cta_split_num: Optional[Tuple[int, int]] = None
+
+    def __post_init__(self):
+        assert self.cta_split_num is None or len(self.cta_split_num) == 2
+
+    def _to_ir(self, builder):
+        cta_split_num = self.cta_split_num or [1, 1]
+        return builder.get_tensor_memory_scales_layout(cta_split_num, )
+
+    def mangle(self) -> str:
+        cta_split_str = f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else ""
+        return f"TLS{cta_split_str}TLS"
+
+
+@constexpr_function
+def _cdiv(x, div):
+    return (x + div - 1) // div
+
+
+@constexpr_function
+def get_tmem_32x32b_reg_layout(M, N, shape, num_warps, ctas_per_cga=None, cta_split_num=None, cta_order=None):
+    """Returns a BlockedLayout compatible with load/store on tensor memory with the 32x32b instruction variant.
+    """
+    assert len(shape) == 2, "expected a 2D tensor"
+    assert num_warps in [4, 8], "expected 4 or 8 warps"
+
+    shape_per_cta = _get_shape_per_cta(shape, cta_split_num)
+    blocks_per_tile = [shape_per_cta[0] // M, shape_per_cta[1] // N]
+    num_blocks = blocks_per_tile[0] * blocks_per_tile[1]
+
+    num_warp_groups = num_warps // 4
+    if M == 64:
+        threads_per_warp = [16, 2]
+        if num_blocks == 1:
+            size_per_thread = [1, _cdiv(N, num_warp_groups * 2)]
+            warps_per_cta = [4, num_warp_groups]
+        else:
+            size_per_thread = [1, _cdiv(N, 2)]
+            warps_per_cta = [4 * min(blocks_per_tile[0], num_warp_groups)]
+            warps_per_cta.append(_cdiv(num_warp_groups, warps_per_cta[0] // 4))
+    else:
+        if shape[0] > 128:
+            size_per_thread = [1, N]
+            threads_per_warp = [32, 1]
+            warps_per_cta = [4 * num_warp_groups, 1]
+        else:
+            size_per_thread = [1, _cdiv(N, num_warp_groups)]
+            threads_per_warp = [32, 1]
+            warps_per_cta = [4, num_warp_groups]
+    return BlockedLayout(
+        size_per_thread=size_per_thread,
+        threads_per_warp=threads_per_warp,
+        warps_per_cta=warps_per_cta,
+        order=[0, 1],
+        ctas_per_cga=ctas_per_cga,
+        cta_split_num=cta_split_num,
+        cta_order=cta_order,
+    )
+
+
+class tensor_memory_descriptor_type(base_type):
+
+    def __init__(self, element_ty, shape, layout, alloc_shape):
+        self.element_ty = element_ty
+        self.shape = shape
+        self.layout = layout
+        self.alloc_shape = alloc_shape
+        assert isinstance(layout, TensorMemoryLayout) or isinstance(layout, TensorMemoryScalesLayout)
+
+    def to_ir(self, builder: GluonOpBuilder) -> None:
+        return builder.get_tensor_mem_desc_ty(
+            self.element_ty.to_ir(builder),
+            self.shape,
+            self.layout._to_ir(builder),
+            self.alloc_shape,
+        )
+
+    def _unflatten_ir(self, handles: List[ir.Value], cursor: int) -> Tuple[tensor_memory_descriptor, int]:
+        value = tensor_memory_descriptor(handles[cursor], self.element_ty, self.shape, self.layout, self.alloc_shape)
+        return value, cursor + 1
+
+    def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None:
+        out.append(self.to_ir(builder))
+
+    def __str__(self) -> str:
+        return f"tensor_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}>"
+
+    def __eq__(self, other) -> bool:
+        return (type(self) is type(other) and self.shape == other.shape and self.layout == other.layout
+                and self.alloc_shape == other.alloc_shape)
+
+    def __neq__(self, other) -> bool:
+        return not (self == other)
+
+    def mangle(self) -> str:
+        shape_str = "_".join([str(s) for s in self.shape])
+        return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{self.alloc_shape}ASMD"
+
+
+class tensor_memory_descriptor(base_value):
+    """
+    Represents a tensor memory descriptor handle for Tensor Core Gen5 operations.
+    """
+
+    def __init__(self, handle, element_ty, shape, layout, alloc_shape):
+        self.handle = handle
+        self.type = tensor_memory_descriptor_type(element_ty, shape, layout, alloc_shape)
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        handles.append(self.handle)
+
+    @property
+    def dtype(self):
+        return self.type.element_ty
+
+    @property
+    def shape(self):
+        return self.type.shape
+
+    @property
+    def rank(self):
+        return len(self.shape)
+
+    @property
+    def layout(self):
+        return self.type.layout
+
+    def __str__(self) -> str:
+        return str(self.type)
+
+    @builtin
+    def load(self, layout, _semantic: GluonSemantic) -> ttgl.tensor:
+        """
+        Load a tensor from tensor memory.
+
+        Args:
+            layout (DistributedLayout): Destination layout of the tensor.
+
+        Returns:
+            tensor: A distributed tensor containing the loaded data.
+        """
+        layout = _unwrap_if_constexpr(layout)
+        ret_ty = ttgl.distributed_type(self.dtype, self.shape, layout)
+        builder = _semantic.builder
+        handle = builder.create_tmem_load(ret_ty.to_ir(builder), self.handle)
+        return ttgl.tensor(handle, ret_ty)
+
+    @builtin
+    def store(self, value, pred=True, _semantic: GluonSemantic = None) -> None:
+        """
+        Store a tensor into tensor memory.
+
+        Args:
+            value (tensor): The tensor to store.
+            pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+        """
+        pred = _unwrap_if_constexpr(pred)
+        pred = _semantic.to_tensor(pred)
+        assert value.shape == self.shape, f"source shape {value.shape} does not match destination shape {self.shape}"
+        assert value.dtype == self.dtype, f"source dtype {value.dtype} does not match destination dtype {self.dtype}"
+        _semantic.builder.create_tmem_store(self.handle, value.handle, pred.handle)
+
+    @builtin
+    def slice(self, start, length, _semantic: GluonSemantic) -> None:
+        """
+        Create a slice of the tensor memory descriptor along the last dimension.
+
+        Args:
+            start (int): The starting index for subslice.
+            length (int): The length of the subslice.
+
+        Returns:
+            tensor_memory_descriptor: Descriptor for the subslice.
+        """
+        start = _unwrap_if_constexpr(start)
+        length = _unwrap_if_constexpr(length)
+        _check(isinstance(start, int), lambda: "start must be a constant int")
+        _check(isinstance(length, int), lambda: "length must be a constant int")
+        shape = self.shape[:-1] + [length]
+        layout = self.type.layout
+        layout = TensorMemoryLayout((layout.block[0], min(layout.block[1], length)), layout.unpacked,
+                                    layout.cta_split_num)
+        ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape)
+        builder = _semantic.builder
+        ret.handle = builder.create_tmem_subslice(ret.type.to_ir(builder), self.handle, start)
+        return ret
+
+    @builtin
+    def index(self, index, _semantic: GluonSemantic = None) -> tensor_memory_descriptor:
+        """
+        Create a subview of tensor memory by indexing the first dimension.
+
+        Args:
+            index (tensor): The index tensor for the subview.
+
+        Returns:
+            tensor_memory_descriptor: Descriptor for the indexed subview.
+        """
+        index = _semantic.to_tensor(index)
+        builder = _semantic.builder
+        shape = self.shape[1:]
+        layout = self.layout
+        ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape)
+        ret.handle = builder.create_memdesc_index(ret.type.to_ir(builder), self.handle, index.handle)
+        return ret
+
+    @builtin
+    def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> tensor_memory_descriptor:
+        """
+        Reinterpret tensor memory descriptor with a new dtype, shape, and layout.
+
+        Args:
+            dtype (dtype): The new data type.
+            shape (Sequence[int]): The new shape.
+            layout (TensorMemoryLayout): The new layout.
+
+        Returns:
+            tensor_memory_descriptor: Descriptor with updated type and layout.
+        """
+        dtype = _unwrap_if_constexpr(dtype)
+        shape = [_unwrap_if_constexpr(s) for s in shape]
+        layout = _unwrap_if_constexpr(layout)
+
+        ty = tensor_memory_descriptor_type(dtype, shape, layout, shape)
+        handle = _semantic.builder.create_memdesc_reinterpret(ty.to_ir(_semantic.builder), self.handle)
+        return tensor_memory_descriptor(handle, **ty.__dict__)
+
+
+@builtin
+def allocate_tensor_memory(element_ty, shape, layout, value=None, _semantic=None):
+    """
+    Allocate tensor memory.
+
+    Args:
+        element_ty (dtype): The element data type.
+        shape (Sequence[int]): The descriptor shape.
+        layout (TensorMemoryLayout): The layout of the tensor memory.
+        value (tensor, optional): Initial tensor to copy. Defaults to None.
+
+    Returns:
+        tensor_memory_descriptor: Descriptor for the allocated memory.
+    """
+    element_ty = _unwrap_if_constexpr(element_ty)
+    shape = _unwrap_if_constexpr(shape)
+    layout = _unwrap_if_constexpr(layout)
+    value = value.handle if value is not None else None
+
+    ty = tensor_memory_descriptor_type(element_ty, shape, layout, shape)
+    builder = _semantic.builder
+    handle = builder.create_tmem_alloc(ty.to_ir(builder), value)
+    return tensor_memory_descriptor(handle, element_ty, shape, layout, shape)
+
+
+@builtin
+def tcgen05_copy(src, dst, _semantic=None):
+    """
+    Start an asynchronous copy from shared memory to tensor memory.
+
+    WARNING: The current semantics of the instruction are not well defined and
+    the API will change in the future. Use at your own risk.
+
+    Args:
+        src (shared_memory_descriptor): Shared memory to copy from.
+        dst (tensor_memory_descriptor): Tensor memory to copy to.
+    """
+    assert isinstance(src, ttgl.shared_memory_descriptor), "source must be a shared memory descriptor"
+    assert isinstance(dst, tensor_memory_descriptor), "destination must be a tensor memory descriptor"
+    _semantic.builder.create_tmem_copy(src.handle, dst.handle)
+
+
+@builtin
+def tcgen05_mma(a, b, acc, *, use_acc=True, pred=True, mbarriers=None, mbarrier_preds=None, _semantic=None):
+    """
+    Emit a 5th generation TensorCore MMA instruction.
+    acc = a * b + (acc if use_acc else 0)
+
+    Args:
+        a (shared_memory_descriptor): Left hand side operand in shared memory.
+        b (shared_memory_descriptor or tensor_memory_descriptor): Right hand side operand in shared or tensor memory.
+        acc (tensor_memory_descriptor): Accumulator value in tensor memory (mutated).
+        use_acc (bool): Whether to use the initial value of the accumulator. Defaults to True.
+        pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+        mbarriers (Sequence[shared_memory_descriptor], optional): Barriers to signal when the operation is complete. If None, mma is synchronous. Defaults to None.
+        mbarrier_preds (Sequence[bool], optional): Predicates for barriers. Defaults to None.
+    """
+    use_acc = _semantic.to_tensor(use_acc)
+    pred = _semantic.to_tensor(pred)
+
+    if mbarriers is None:
+        assert mbarrier_preds is None
+        mbarriers = []
+        mbarrier_preds = []
+    else:
+        mbarriers = [bar.handle for bar in mbarriers]
+        if mbarrier_preds is None:
+            true = _semantic.to_tensor(True)
+            mbarrier_preds = [true.handle] * len(mbarriers)
+        else:
+            mbarrier_preds = _semantic._convert_to_ir_values(mbarrier_preds, require_i64=False)
+
+    _semantic.builder.create_tcgen05_mma(a.handle, b.handle, acc.handle, use_acc.handle, pred.handle, mbarriers,
+                                         mbarrier_preds)
+
+
+@builtin
+def tcgen05_commit(barrier, _semantic=None):
+    """
+    This instruction causes the provided mbarrier to be arrived-on with a count
+    of 1 when all async tcgen05 MMA and copy instructions previously issued by
+    the thread are complete.
+
+    Args:
+        barrier (shared_memory_descriptor): The barrier to track completion of tcgen05 MMA and copy instructions.
+    """
+    _semantic.builder.create_tcgen05_commit(barrier.handle)
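For orientation, here is a minimal, hypothetical sketch of how the tensor-memory API added above might be combined inside a Gluon kernel body. The shared-memory descriptors `a_smem` and `b_smem`, and the dtype name `ttgl.float32`, are assumptions about the surrounding kernel and the Gluon language module; they are not taken from this diff.

# Hypothetical Gluon kernel fragment (a sketch, not code from this package)
acc_layout = TensorMemoryLayout(block=(128, 128), unpacked=True)
acc = allocate_tensor_memory(ttgl.float32, [128, 128], acc_layout)          # ttgl.float32 is assumed
tcgen05_mma(a_smem, b_smem, acc, use_acc=False)                             # no mbarriers given, so the MMA is synchronous
reg_layout = get_tmem_32x32b_reg_layout(128, 128, [128, 128], num_warps=4)  # register layout compatible with tmem loads
out = acc.load(reg_layout)                                                  # read the accumulator back into a distributed tensor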
triton/experimental/gluon/language/nvidia/blackwell/tma.py
@@ -0,0 +1,52 @@
+from triton.experimental.gluon.language._core import builtin
+from triton.experimental.gluon.language.nvidia.hopper.tma import (
+    async_copy_global_to_shared,
+    async_copy_shared_to_global,
+    store_wait,
+    tensor_descriptor,
+    tensor_descriptor_type,
+)
+
+__all__ = [
+    "async_gather",
+    "async_scatter",
+    "async_copy_global_to_shared",
+    "async_copy_shared_to_global",
+    "store_wait",
+    "tensor_descriptor",
+    "tensor_descriptor_type",
+]
+
+
+@builtin
+def async_gather(tensor_desc, x_offsets, y_offset, barrier, result, pred=True, _semantic=None):
+    """
+    Asynchronously gather elements from global memory to shared memory using TMA.
+
+    Args:
+        tensor_desc (tensor_descriptor): The tensor descriptor.
+        x_offsets (tensor): 1D tensor of X offsets.
+        y_offset (int): Scalar Y offset.
+        barrier (shared_memory_descriptor): Barrier that will be signaled when the operation is complete.
+        result (tensor_memory_descriptor): Result shared memory, must have NVMMASharedLayout.
+        pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+    """
+    pred = _semantic.to_tensor(pred)
+    y_offset = _semantic.to_tensor(y_offset)
+    _semantic.builder.create_async_tma_gather(tensor_desc.handle, x_offsets.handle, y_offset.handle, barrier.handle,
+                                              result.handle, pred.handle)
+
+
+@builtin
+def async_scatter(tensor_desc, x_offsets, y_offset, src, _semantic=None):
+    """
+    Asynchronously scatter elements from shared memory to global memory using TMA.
+
+    Args:
+        tensor_desc (tensor_descriptor): The tensor descriptor.
+        x_offsets (tensor): 1D tensor of X offsets.
+        y_offset (int): Scalar Y offset.
+        src (tensor_memory_descriptor): The source data, must be in NVMMASharedLayout.
+    """
+    y_offset = _semantic.to_tensor(y_offset)
+    _semantic.builder.create_async_tma_scatter(tensor_desc.handle, x_offsets.handle, y_offset.handle, src.handle)
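A hypothetical sketch of how the gather/scatter entry points above might pair with an mbarrier. The identifiers `desc`, `row_idx`, `bar`, and `smem` are assumed to be created elsewhere in the kernel, and the `mbarrier.wait` signature is assumed from the ampere re-export rather than shown in this diff.

# Hypothetical Gluon kernel fragment (a sketch, not code from this package)
mbarrier.expect(bar, 32 * 128 * 2)              # 32 gathered fp16 rows of 128 elements, in bytes
async_gather(desc, row_idx, 0, bar, smem)       # desc: tensor_descriptor, smem: buffer with NVMMASharedLayout
mbarrier.wait(bar, 0)                           # assumed wait signature (re-exported from ampere.mbarrier)
async_scatter(desc, row_idx, 0, smem)           # write the gathered rows back to global memory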
triton/experimental/gluon/language/nvidia/hopper/__init__.py
@@ -0,0 +1,132 @@
+from __future__ import annotations
+from triton.compiler.code_generator import unflatten_ir_values
+from ..ampere import async_copy
+from . import mbarrier, tma
+from ... import _core
+
+from typing import List, Tuple, TYPE_CHECKING
+if TYPE_CHECKING:
+    from triton._C.libtriton import ir
+
+__all__ = ["async_copy", "fence_async_shared", "mbarrier", "tma", "warpgroup_mma", "warpgroup_mma_wait"]
+
+
+@_core.builtin
+def fence_async_shared(cluster=False, _semantic=None):
+    """
+    Issue a fence to complete asynchronous shared memory operations.
+
+    Args:
+        cluster (bool): Whether to fence across cluster. Defaults to False.
+    """
+    cluster = _core._unwrap_if_constexpr(cluster)
+    _semantic.builder.create_fence_async_shared(cluster)
+
+
+class warpgroup_mma_accumulator_type(_core.base_type):
+    tensor_type: _core.dtype
+
+    def __init__(self, tensor_type: _core.dtype):
+        self.tensor_type = tensor_type
+
+    def __str__(self) -> str:
+        return f"warpgroup_mma_accumulator<{self.tensor_type}>"
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[warpgroup_mma_accumulator, int]:
+        return warpgroup_mma_accumulator(handles[cursor], self.tensor_type), cursor + 1
+
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        self.tensor_type._flatten_ir_types(builder, out)
+
+    def __eq__(self, other) -> bool:
+        return type(self) is type(other) and self.tensor_type == other.tensor_type
+
+    def mangle(self) -> str:
+        return f"FT{self.tensor_type.mangle()}FT"
+
+
+class warpgroup_mma_accumulator(_core.base_value):
+    handle: ir.value
+    type: warpgroup_mma_accumulator_type
+
+    def __init__(self, handle, tensor_type: _core.dtype):
+        self.handle = handle
+        self.type = warpgroup_mma_accumulator_type(tensor_type)
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        handles.append(self.handle)
+
+
+@_core.builtin
+def warpgroup_mma_init(value, _semantic):
+    assert isinstance(value, _core.tensor)
+    return warpgroup_mma_accumulator(value.handle, value.type)
+
+
+@_core.builtin
+def warpgroup_mma(a, b, acc, *, use_acc=True, precision=None, max_num_imprecise_acc=None, is_async=False,
+                  _semantic=None):
+    """
+    Perform warpgroup MMA (Tensor Core) operations.
+    acc = a * b + (acc if use_acc else 0)
+
+    Args:
+        a (tensor or shared_memory_descriptor): Left hand side operand.
+        b (shared_memory_descriptor): Right hand side operand.
+        acc (tensor): Accumulator tensor.
+        use_acc (bool): Whether to use the initial value of the accumulator. Defaults to True.
+        precision (str, optional): Dot input precision. Defaults to builder default.
+        max_num_imprecise_acc (int): Max imprecise accumulations. Used for fp8 -> fp32 dot. Determines how many accumulation are done in limited precision. Defaults to None, which means no upcasting is done.
+        is_async (bool): Whether operation is asynchronous. Defaults to False.
+
+    Returns:
+        tensor or warpgroup_mma_accumulator: Returns the result if synchronous, or a token to load the value once computed if asynchronous.
+    """
+    use_acc = _semantic.to_tensor(use_acc)
+
+    if precision is None:
+        precision = _semantic.builder.options.default_dot_input_precision
+
+    precision = _semantic._str_to_dot_input_precision(precision)
+
+    K = a.type.shape[-1]
+    if max_num_imprecise_acc is None:
+        if a.dtype.is_fp8() and b.dtype.is_fp8():
+            max_num_imprecise_acc = _semantic.builder.options.max_num_imprecise_acc_default
+        else:
+            max_num_imprecise_acc = 0
+    else:
+        if a.dtype.is_fp8() and b.dtype.is_fp8() and max_num_imprecise_acc > K:
+            raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})")
+
+    max_num_imprecise_acc = _core._unwrap_if_constexpr(max_num_imprecise_acc)
+    is_async = _core._unwrap_if_constexpr(is_async)
+
+    handle = _semantic.builder.create_warpgroup_mma(a.handle, b.handle, acc.handle, use_acc.handle, precision,
+                                                    max_num_imprecise_acc, is_async)
+    tensor_ty = acc.type.tensor_type if isinstance(acc, warpgroup_mma_accumulator) else acc.type
+    if is_async:
+        return warpgroup_mma_accumulator(handle, tensor_ty)
+    else:
+        return _core.tensor(handle, tensor_ty)
+
+
+@_core.builtin
+def warpgroup_mma_wait(num_outstanding=0, deps=None, _semantic=None):
+    """
+    Wait until `num_outstanding` or less warpgroup MMA operations are in-flight.
+
+    Args:
+        num_outstanding (int): Number of outstanding warpgroup MMA operations to wait for. Defaults to 0.
+        deps (Sequence[tensor]): List of dependencies that need to be kept alive while the mma is unfinished.
+    """
+    if deps is None:
+        raise ValueError("warpgroup_mma_wait deps must be given")
+    deps_handles = [x.handle for x in deps] if deps is not None else []
+    num_outstanding = _core._unwrap_if_constexpr(num_outstanding)
+    results = _semantic.builder.create_warpgroup_mma_wait(deps_handles, num_outstanding)
+    result_types = [dep.type.tensor_type if isinstance(dep, warpgroup_mma_accumulator) else dep.type for dep in deps]
+    results = unflatten_ir_values(results, result_types)
+    if len(deps) == 1:
+        return next(results)
+    return tuple(results)
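A hypothetical sketch of the asynchronous warpgroup MMA flow defined above; `a_smem`, `b_smem`, and the initial register tensor `acc` are assumed operands, not part of this diff.

# Hypothetical Gluon kernel fragment (a sketch, not code from this package)
token = warpgroup_mma(a_smem, b_smem, acc, use_acc=False, is_async=True)  # async: returns a warpgroup_mma_accumulator token
acc = warpgroup_mma_wait(num_outstanding=0, deps=[token])                 # single dependency, so the resolved tensor is returned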
triton/experimental/gluon/language/nvidia/hopper/mbarrier.py
@@ -0,0 +1,34 @@
+from ..ampere.mbarrier import MBarrierLayout, init, invalidate, wait
+from ..._core import _unwrap_if_constexpr, builtin
+
+__all__ = ["arrive", "expect", "init", "invalidate", "MBarrierLayout", "wait"]
+
+
+@builtin
+def expect(mbarrier, bytes, pred=True, _semantic=None):
+    """
+    Expect a specific number of bytes being copied. When they are copied, the barrier is signaled.
+
+    Args:
+        mbarrier (shared_memory_descriptor): Barrier that will be signaled when the operation is complete.
+        bytes (int): Expected byte count.
+        pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+    """
+    bytes = _unwrap_if_constexpr(bytes)
+    pred = _semantic.to_tensor(pred)
+    _semantic.builder.create_mbarrier_expect(mbarrier.handle, bytes, pred.handle)
+
+
+@builtin
+def arrive(mbarrier, *, count=1, pred=True, _semantic=None):
+    """
+    Arrive at an mbarrier with a specified count.
+
+    Args:
+        mbarrier (shared_memory_descriptor): Barrier to be signalled.
+        count (int): Count to arrive with. Defaults to 1.
+        pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+    """
+    count = _unwrap_if_constexpr(count)
+    pred = _semantic.to_tensor(pred)
+    _semantic.builder.create_mbarrier_arrive(mbarrier.handle, count, pred.handle)
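A hypothetical sketch of the expect/arrive pairing above. The barriers `bar` and `done_bar` and the scalar predicate `is_leader` are assumed to have been set up elsewhere in the kernel (for example via the re-exported `init`); their creation is not shown in this diff.

# Hypothetical Gluon kernel fragment (a sketch, not code from this package)
mbarrier.expect(bar, 64 * 64 * 4, pred=is_leader)   # one leader announces a 64x64 fp32 tile (4 bytes per element)
# ... an asynchronous copy targeting `bar` would complete the expectation here ...
mbarrier.arrive(done_bar, count=1)                  # separately signal a second barrier for consumer warps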