triton-windows 3.3.1.post19__cp312-cp312-win_amd64.whl → 3.5.0.post21__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows has been flagged as potentially problematic; see the registry's advisory page for details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,583 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import List, Optional
|
|
3
|
+
from triton.language.core import _unwrap_if_constexpr, _unwrap_shape, constexpr_type
|
|
4
|
+
from triton.runtime.jit import constexpr_function
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _realize_cta_layout(layout, rank):
|
|
8
|
+
ctas_per_cga = layout.ctas_per_cga or [1] * rank
|
|
9
|
+
cta_split_num = layout.cta_split_num or [1] * rank
|
|
10
|
+
cta_order = layout.cta_order or list(reversed(range(rank)))
|
|
11
|
+
object.__setattr__(layout, "ctas_per_cga", ctas_per_cga)
|
|
12
|
+
object.__setattr__(layout, "cta_split_num", cta_split_num)
|
|
13
|
+
object.__setattr__(layout, "cta_order", cta_order)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DistributedLayout:
    """Base class for distributed memory layouts in Gluon IR.

    A distributed layout describes how a tensor's elements are partitioned
    across the execution hierarchy (registers, lanes, warps, CTAs).
    """

    @property
    def type(self):
        # Layouts are compile-time constants, so a layout's "type" is a
        # constexpr wrapping the layout object itself.
        return constexpr_type(self)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass(frozen=True)
class AutoLayout(DistributedLayout):
    """Placeholder layout requesting that the compiler pick a concrete layout."""

    def _to_ir(self, builder):
        # No parameters to forward; the builder supplies the auto encoding.
        return builder.get_auto_layout()

    def mangle(self):
        # Short stable tag used when mangling kernel names.
        return "AL"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass(frozen=True)
class BlockedLayout(DistributedLayout):
    """Blocked layout partitioning a tensor across threads, warps, and CTAs.

    Args:
        size_per_thread (List[int]): Number of elements per thread per dimension.
        threads_per_warp (List[int]): Number of threads per warp per dimension.
        warps_per_cta (List[int]): Number of warps per CTA per dimension.
        order (List[int]): The ordering of dimensions for partitioning.
        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
        cta_split_num (Optional[List[int]]): Split factors for CTAs.
        cta_order (Optional[List[int]]): Ordering for CTAs.
    """
    size_per_thread: List[int]
    threads_per_warp: List[int]
    warps_per_cta: List[int]
    order: List[int]
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        # Strip constexpr wrappers so plain Python values are stored on the
        # frozen dataclass (hence the object-level __setattr__).
        for name in ("size_per_thread", "threads_per_warp", "warps_per_cta",
                     "order", "ctas_per_cga", "cta_split_num", "cta_order"):
            super().__setattr__(name, _unwrap_if_constexpr(getattr(self, name)))

        rank = len(self.size_per_thread)
        # Populate default CTA fields before validating ranks.
        _realize_cta_layout(self, rank)
        # Every per-dimension field must agree on the tensor rank.
        for name in ("threads_per_warp", "warps_per_cta", "order",
                     "ctas_per_cga", "cta_split_num", "cta_order"):
            assert len(getattr(self, name)) == rank

    def _to_ir(self, builder):
        return builder.get_blocked_layout(
            self.size_per_thread,
            self.threads_per_warp,
            self.warps_per_cta,
            self.order,
            self.ctas_per_cga,
            self.cta_split_num,
            self.cta_order,
        )

    def mangle(self) -> str:

        def stringify(x):
            return "" if x is None else "_".join(map(str, x))

        fields = (self.size_per_thread, self.threads_per_warp, self.warps_per_cta,
                  self.order, self.ctas_per_cga, self.cta_split_num, self.cta_order)
        # "B"-delimited concatenation of every field, bracketed by "B".
        return "B" + "B".join(stringify(f) for f in fields) + "B"

    def __hash__(self):

        def opt(x):
            # Lists are unhashable; unset optional fields hash as None.
            return tuple(x) if x else None

        return hash((
            tuple(self.size_per_thread),
            tuple(self.threads_per_warp),
            tuple(self.warps_per_cta),
            tuple(self.order),
            opt(self.ctas_per_cga),
            opt(self.cta_split_num),
            opt(self.cta_order),
        ))
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@dataclass(frozen=True)
class SliceLayout(DistributedLayout):
    """Layout of a tensor produced by slicing a distributed tensor along one dimension.

    Args:
        dim (int): The dimension index to slice.
        parent (DistributedLayout): The parent layout before slicing.
    """
    dim: int
    parent: DistributedLayout

    def __post_init__(self):
        # Store unwrapped (non-constexpr) values on the frozen dataclass.
        for name in ("dim", "parent"):
            super().__setattr__(name, _unwrap_if_constexpr(getattr(self, name)))

    def _to_ir(self, builder):
        parent_ir = self.parent._to_ir(builder)
        return builder.get_slice_layout(self.dim, parent_ir)

    def mangle(self) -> str:
        # Parent mangles recursively, bracketed by "SL".
        return "SL" + str(self.dim) + "_" + self.parent.mangle() + "SL"

    def __hash__(self):
        return hash((self.dim, self.parent))
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
@dataclass(frozen=True)
class DistributedLinearLayout(DistributedLayout):
    """Linear distributed layout with explicit bases per hardware level.

    Bases are given at the register, lane, warp, and block levels.
    See: https://arxiv.org/abs/2505.23819 for reference.

    Args:
        reg_bases (List[List[int]]): Bases for register-level distribution.
        lane_bases (List[List[int]]): Bases for lane-level distribution.
        warp_bases (List[List[int]]): Bases for warp-level distribution.
        block_bases (List[List[int]]): Bases for block-level distribution.
        shape (List[int]): The tensor global shape.
    """
    reg_bases: List[List[int]]
    lane_bases: List[List[int]]
    warp_bases: List[List[int]]
    block_bases: List[List[int]]
    shape: List[int]

    def __post_init__(self):
        # Normalize every field through _unwrap_shape onto the frozen dataclass.
        for name in ("reg_bases", "lane_bases", "warp_bases", "block_bases", "shape"):
            super().__setattr__(name, _unwrap_shape(getattr(self, name)))

        rank = len(self.shape)
        # Each basis vector must carry one coordinate per tensor dimension.
        for bases in (self.reg_bases, self.lane_bases, self.warp_bases, self.block_bases):
            for basis in bases:
                assert len(basis) == rank

    def _to_ir(self, builder):
        return builder.get_distributed_linear_layout(self.reg_bases, self.lane_bases, self.warp_bases,
                                                     self.block_bases, self.shape)

    def mangle(self):
        return f"DLL{self.reg_bases}_{self.lane_bases}_{self.warp_bases}_{self.block_bases}_{self.shape}DLL"

    def __hash__(self):

        def freeze(bases):
            # Nested lists are unhashable; convert to nested tuples.
            return tuple(map(tuple, bases))

        return hash((freeze(self.reg_bases), freeze(self.lane_bases), freeze(self.warp_bases),
                     freeze(self.block_bases), tuple(self.shape)))
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
@dataclass(frozen=True)
class DotOperandLayout(DistributedLayout):
    """Layout of one operand of a dot operation.

    Args:
        operand_index (int): 0 for LHS and 1 for RHS of the dot operation.
        parent (DistributedLayout): The parent layout, representing the MMA.
        k_width (int): Number of elements per 32-bits.
    """
    operand_index: int
    parent: DistributedLayout
    k_width: int

    def __post_init__(self):
        # Store unwrapped (non-constexpr) values on the frozen dataclass.
        for name in ("operand_index", "parent", "k_width"):
            super().__setattr__(name, _unwrap_if_constexpr(getattr(self, name)))

    def _to_ir(self, builder):
        parent_ir = self.parent._to_ir(builder)
        return builder.get_dot_operand_layout(self.operand_index, parent_ir, self.k_width)

    def mangle(self) -> str:
        # Parent mangles recursively, bracketed by "DO".
        return "DO" + f"{self.operand_index}_{self.parent.mangle()}_{self.k_width}" + "DO"

    def __hash__(self):
        return hash((self.operand_index, self.parent, self.k_width))
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
@dataclass(frozen=True, eq=True)
class NVMMADistributedLayout(DistributedLayout):
    """Layout for NVIDIA MMA (tensor core) operations.

    Args:
        version (List[int]): Version identifier for the MMA instruction.
        warps_per_cta (List[int]): Number of warps per CTA.
        instr_shape (List[int]): Instruction shape for MMA.
        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
        cta_split_num (Optional[List[int]]): Split factors for CTAs.
        cta_order (Optional[List[int]]): CTA ordering.
    """
    version: List[int]
    warps_per_cta: List[int]
    instr_shape: List[int]
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        # Strip constexpr wrappers so plain values live on the frozen dataclass.
        for name in ("version", "warps_per_cta", "instr_shape",
                     "ctas_per_cga", "cta_split_num", "cta_order"):
            super().__setattr__(name, _unwrap_if_constexpr(getattr(self, name)))

        rank = len(self.warps_per_cta)
        # Populate default CTA fields, then validate that all match the rank.
        _realize_cta_layout(self, rank)
        for name in ("ctas_per_cga", "cta_split_num", "cta_order"):
            assert len(getattr(self, name)) == rank

    def _to_ir(self, builder):
        return builder.get_mma_layout(self.version, self.warps_per_cta, self.ctas_per_cga,
                                      self.cta_split_num, self.cta_order, self.instr_shape)

    def mangle(self) -> str:
        return (f"MMA_{self.version}_{self.warps_per_cta}_{self.instr_shape}_"
                f"{self.ctas_per_cga}_{self.cta_split_num}_{self.cta_order}_MMA")

    def __hash__(self):

        def opt(x):
            # Lists are unhashable; unset optional fields hash as None.
            return tuple(x) if x else None

        return hash((tuple(self.version), tuple(self.warps_per_cta), tuple(self.instr_shape),
                     opt(self.ctas_per_cga), opt(self.cta_split_num), opt(self.cta_order)))
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
class SharedLayout:
    """Base class for shared memory layouts in Gluon IR.

    A shared layout describes how tensor elements are arranged within
    shared memory.
    """

    @property
    def type(self):
        # Layouts are compile-time constants, so a layout's "type" is a
        # constexpr wrapping the layout object itself.
        return constexpr_type(self)
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
@constexpr_function
def _get_shape_per_cta(shape, cta_split_num):
    """Return the per-CTA shape obtained by dividing `shape` by `cta_split_num`.

    Args:
        shape (List[int]): Global tensor shape.
        cta_split_num (Optional[List[int]]): Per-dimension CTA split factors;
            when None, the shape is returned unchanged.

    Returns:
        List[int]: A new list with each dimension divided by its split factor.
    """
    # Copy so the caller's list is never mutated (the previous version
    # aliased `shape` and modified it in place).
    shape_per_cta = list(shape)
    if cta_split_num is not None:
        assert len(cta_split_num) == len(shape_per_cta)
        for dim in range(len(shape_per_cta)):
            # Integer division: shape dimensions are integral, and `/=`
            # would silently turn them into floats.
            shape_per_cta[dim] //= cta_split_num[dim]
    return shape_per_cta
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
@dataclass(frozen=True)
class NVMMASharedLayout(SharedLayout):
    """
    Represents a layout for shared memory suitable for NVIDIA MMA operations.

    Args:
        swizzle_byte_width (int): Width in bytes for swizzling.
        element_bitwidth (int): Bitwidth of element type.
        rank (int): Rank of the tensor.
        transposed (bool): Whether the layout is transposed.
        fp4_padded (bool): Whether FP4 padding is used.
        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
        cta_split_num (Optional[List[int]]): Split factors for CTAs.
        cta_order (Optional[List[int]]): CTA ordering.
    """
    swizzle_byte_width: int
    element_bitwidth: int
    rank: int
    transposed: bool = False
    fp4_padded: bool = False
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        # Strip constexpr wrappers so plain Python values are stored on the
        # frozen dataclass (hence the object-level __setattr__ via super()).
        super().__setattr__("swizzle_byte_width", _unwrap_if_constexpr(self.swizzle_byte_width))
        super().__setattr__("element_bitwidth", _unwrap_if_constexpr(self.element_bitwidth))
        super().__setattr__("rank", _unwrap_if_constexpr(self.rank))
        super().__setattr__("transposed", _unwrap_if_constexpr(self.transposed))
        super().__setattr__("fp4_padded", _unwrap_if_constexpr(self.fp4_padded))
        super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga))
        super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num))
        super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))

        # Only these discrete values are accepted.
        assert self.element_bitwidth in [8, 16, 32, 64]
        assert self.swizzle_byte_width in [0, 32, 64, 128]
        rank = self.rank
        # Populate default CTA fields, then check they all match the rank.
        _realize_cta_layout(self, rank)
        assert len(self.ctas_per_cga) == rank
        assert len(self.cta_split_num) == rank
        assert len(self.cta_order) == rank

    def _to_ir(self, builder):
        # Lower to the builder's NVMMA shared-memory encoding.  Note that
        # `rank` is not passed; it is implied by the CTA field lengths.
        return builder.get_nvmma_shared_layout(
            self.swizzle_byte_width,
            self.element_bitwidth,
            self.transposed,
            self.fp4_padded,
            self.ctas_per_cga,
            self.cta_split_num,
            self.cta_order,
        )

    @staticmethod
    @constexpr_function
    def get_default_for(block_shape, dtype, transposed=False, fp4_padded=False, ctas_per_cga=None, cta_split_num=None,
                        cta_order=None):
        """Returns an NVMMASharedLayout with default swizzling for a given shape.

        This picks the largest swizzle pattern compatible with the shape, which
        allows emitting the fewest TMA or MMA messages.
        """
        # fp4_padded doubles the effective contiguous extent.
        packing_factor = 2 if fp4_padded else 1
        shape_per_cta = _get_shape_per_cta(block_shape, cta_split_num)
        rank = len(block_shape)
        if transposed:
            # Rotate left by one so the first dimension becomes last
            # (the contiguous dimension is always taken as the last).
            shape_per_cta = shape_per_cta[1:] + shape_per_cta[:1]
        contig_dim_size = shape_per_cta[-1] * packing_factor
        contig_dim_bytes = contig_dim_size * dtype.primitive_bitwidth // 8
        # Pick the widest swizzle (128/64/32 bytes) that evenly divides the
        # contiguous dimension's byte size; otherwise disable swizzling.
        if contig_dim_bytes >= 128 and contig_dim_bytes % 128 == 0:
            swizzle_byte_width = 128
        elif contig_dim_bytes >= 64 and contig_dim_bytes % 64 == 0:
            swizzle_byte_width = 64
        elif contig_dim_bytes >= 32 and contig_dim_bytes % 32 == 0:
            swizzle_byte_width = 32
        else:
            swizzle_byte_width = 0

        # Swizzling is disabled for tiles of rank < 2 or with fewer than 8
        # rows when the outer dimensions are flattened together.
        flatten_outer_dim = 1
        for size in shape_per_cta[:-1]:
            flatten_outer_dim *= size
        if len(block_shape) < 2 or flatten_outer_dim < 8:
            swizzle_byte_width = 0

        return NVMMASharedLayout(
            swizzle_byte_width=swizzle_byte_width,
            element_bitwidth=dtype.primitive_bitwidth,
            rank=rank,
            transposed=transposed,
            fp4_padded=fp4_padded,
            ctas_per_cga=ctas_per_cga,
            cta_split_num=cta_split_num,
            cta_order=cta_order,
        )

    def mangle(self) -> str:
        # NOTE: rank and the CTA fields are not part of the mangled name.
        return f"NVMMA_{self.swizzle_byte_width}_{self.element_bitwidth}_{self.transposed}_{self.fp4_padded}_NVMMA"

    def __hash__(self):
        # Lists are unhashable; optional CTA fields hash as tuples (or None
        # when falsy).
        return hash((self.swizzle_byte_width, self.element_bitwidth, self.rank, self.transposed, self.fp4_padded,
                     tuple(self.ctas_per_cga) if self.ctas_per_cga else None,
                     tuple(self.cta_split_num) if self.cta_split_num else None,
                     tuple(self.cta_order) if self.cta_order else None))
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
@dataclass(frozen=True, eq=True)
class SwizzledSharedLayout(SharedLayout):
    """Generic swizzled shared-memory layout.

    Args:
        vec (int): Vector width for swizzling.
        per_phase (int): Elements per swizzle phase.
        max_phase (int): Maximum number of swizzle phases.
        order (List[int]): Dimension ordering for swizzling.
        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
        cta_split_num (Optional[List[int]]): Split factors for CTAs.
        cta_order (Optional[List[int]]): CTA ordering.
    """
    vec: int
    per_phase: int
    max_phase: int
    order: List[int]
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        # Strip constexpr wrappers so plain values live on the frozen dataclass.
        for name in ("vec", "per_phase", "max_phase", "order",
                     "ctas_per_cga", "cta_split_num", "cta_order"):
            super().__setattr__(name, _unwrap_if_constexpr(getattr(self, name)))

        rank = len(self.order)
        # Populate default CTA fields, then validate that all match the rank.
        _realize_cta_layout(self, rank)
        for name in ("ctas_per_cga", "cta_split_num", "cta_order"):
            assert len(getattr(self, name)) == rank

    def _to_ir(self, builder):
        return builder.get_swizzled_shared_layout(
            self.vec,
            self.per_phase,
            self.max_phase,
            self.order,
            self.ctas_per_cga,
            self.cta_split_num,
            self.cta_order,
        )

    def mangle(self) -> str:

        def stringify(x):
            return "" if x is None else "_".join(map(str, x))

        pieces = ["SSS", str(self.vec), str(self.per_phase), str(self.max_phase),
                  stringify(self.order), stringify(self.ctas_per_cga),
                  stringify(self.cta_split_num), stringify(self.cta_order), "SSS"]
        return "_".join(pieces)

    def __hash__(self):

        def opt(x):
            # Lists are unhashable; unset optional fields hash as None.
            return tuple(x) if x else None

        return hash((self.vec, self.per_phase, self.max_phase, tuple(self.order),
                     opt(self.ctas_per_cga), opt(self.cta_split_num), opt(self.cta_order)))
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
@dataclass(frozen=True, eq=True)
class PaddedSharedLayout(SharedLayout):
    """
    A shared-memory access layout that, unlike SwizzledSharedLayout, avoids
    bank conflicts with padding: after every ``interval`` tensor elements the
    corresponding ``padding`` number of elements is inserted.  When a position
    matches several intervals, their padding amounts are summed.

    For example, with the tensor
        [e0, e1,
         e2, e3,
         e4, e5,
         e6, e7,
         ...]
    where ``eM`` are original elements and ``pN`` padded ones, the
    interval-padding list [[2, 1], [4, 2]] gives the shared-memory image
        [e0, e1, p0,
         e2, e3, p1, p2, p3,
         e4, e5, p4,
         e6, e7, p5, p6, p7,
         ...]

    Args:
        interval_padding_pairs (List[int]): List of [interval, padding] pairs; both values must be powers of 2.
        order (List[int]): Order of logical tensor dimensions; fastest-varying first.
        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
        cta_split_num (Optional[List[int]]): Split factors for CTAs.
        cta_order (Optional[List[int]]): CTA ordering.
    """
    interval_padding_pairs: List[List[int]]
    order: List[int]
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        # Frozen dataclass: normalize fields through super().__setattr__.
        super().__setattr__("interval_padding_pairs", _unwrap_shape(self.interval_padding_pairs))
        for name in ("order", "ctas_per_cga", "cta_split_num", "cta_order"):
            super().__setattr__(name, _unwrap_if_constexpr(getattr(self, name)))

        self.verify()

    def _to_ir(self, builder):
        # Lower to the backend's padded shared-memory encoding.
        intervals, paddings = zip(*self.interval_padding_pairs)
        return builder.get_padded_shared_layout(intervals, paddings, self.order,
                                                self.ctas_per_cga,
                                                self.cta_split_num,
                                                self.cta_order)

    def mangle(self) -> str:
        """Encode the layout parameters as a stable string for cache-key mangling."""

        def stringify(x):
            return "" if x is None else "_".join(map(str, x))

        return f"PaddedShared_{stringify(self.interval_padding_pairs)}_{stringify(self.order)}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_PaddedShared"

    def verify(self):
        """Validate the interval/padding pairs and CTA fields (AssertionError on failure)."""
        pairs = self.interval_padding_pairs
        assert len(pairs) > 0, "PaddedSharedLayout interval_padding_pairs must have at least one interval-padding pair"
        assert all(len(pair) == 2 for pair in pairs)
        intervals, paddings = zip(*pairs)

        # Each interval may appear at most once.
        assert len(set(intervals)) == len(intervals)

        def is_power_of_2(n):
            return n > 0 and n & (n - 1) == 0

        assert all(is_power_of_2(n) for n in intervals), "PaddedSharedLayout interval values must all be power of two"
        assert all(is_power_of_2(n) for n in paddings), "PaddedSharedLayout padding values must all be power of two"

        rank = len(self.order)
        assert rank > 0, "PaddedSharedLayout order must not be empty"
        # Fill in default CTA fields, then check that each has one entry per dim.
        _realize_cta_layout(self, rank)

        for cta_field in (self.ctas_per_cga, self.cta_split_num, self.cta_order):
            assert len(cta_field) == rank

    def __hash__(self):
        """Hash over all fields, with nested lists frozen to tuples."""

        def freeze(xs):
            return tuple(xs) if xs else None

        return hash((tuple(map(tuple, self.interval_padding_pairs)),
                     tuple(self.order), freeze(self.ctas_per_cga),
                     freeze(self.cta_split_num), freeze(self.cta_order)))
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
# Python impl of LinearEncodingAttr::basesPerDim
def bases_per_dim(bases, rank, skip_broadcast=True):
    """Return, per dimension, the size contributed by a list of linear-layout bases.

    Each basis doubles the size of the dimension of its first non-zero
    component.  An all-zero basis is a broadcast: it is skipped when
    skip_broadcast is True; otherwise it doubles the most recently touched
    dimension (and must not occur before any non-zero basis).
    """
    sizes = [1] * rank
    if not bases:
        return sizes

    last_dim = None
    for basis in bases:
        # First non-zero component decides which dimension this basis grows.
        dim = None
        for i, component in enumerate(basis):
            if component != 0:
                dim = i
                break
        if dim is None:
            # Broadcast basis: either ignore it or reuse the last dimension.
            if skip_broadcast:
                continue
            assert last_dim is not None
            dim = last_dim
        last_dim = dim
        sizes[dim] *= 2

    return sizes
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
def warps_per_cta(layout, shape):
    """Return the per-dimension warp counts of ``layout`` for a tensor of ``shape``."""
    # Linear layouts derive the warp counts from their warp bases.
    if isinstance(layout, DistributedLinearLayout):
        return bases_per_dim(layout.warp_bases, len(shape))
    # Slice/dot-operand layouts delegate to their parent layout.
    if isinstance(layout, (SliceLayout, DotOperandLayout)):
        return warps_per_cta(layout.parent, shape)
    # Other distributed layouts carry the counts directly.
    return layout.warps_per_cta
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import triton.language.math as tl_math

from ._core import builtin

# Gluon re-exports of Triton's elementwise math functions, wrapped with
# ``builtin`` so they can be called from gluon kernels.

# Multiplication / division helpers.
umulhi = builtin(tl_math.umulhi)
fma = builtin(tl_math.fma)
fdiv = builtin(tl_math.fdiv)
div_rn = builtin(tl_math.div_rn)

# Exponentials and logarithms.
exp = builtin(tl_math.exp)
exp2 = builtin(tl_math.exp2)
log = builtin(tl_math.log)
log2 = builtin(tl_math.log2)

# Trigonometry and the error function.
sin = builtin(tl_math.sin)
cos = builtin(tl_math.cos)
erf = builtin(tl_math.erf)

# Roots.
sqrt = builtin(tl_math.sqrt)
sqrt_rn = builtin(tl_math.sqrt_rn)
rsqrt = builtin(tl_math.rsqrt)

# Absolute value and rounding.  NOTE: ``abs`` intentionally shadows the
# Python builtin within this module, mirroring triton.language.math.
abs = builtin(tl_math.abs)
floor = builtin(tl_math.floor)
ceil = builtin(tl_math.ceil)
|