triton-windows 3.3.1.post19__cp310-cp310-win_amd64.whl → 3.4.0.post20__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +4 -1
- triton/_filecheck.py +87 -0
- triton/_internal_testing.py +26 -15
- triton/_utils.py +110 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +112 -78
- triton/backends/amd/driver.c +5 -2
- triton/backends/amd/driver.py +149 -47
- triton/backends/compiler.py +7 -21
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +92 -93
- triton/backends/nvidia/driver.c +90 -98
- triton/backends/nvidia/driver.py +303 -128
- triton/compiler/code_generator.py +212 -111
- triton/compiler/compiler.py +110 -25
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +4 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +99 -0
- triton/experimental/gluon/language/__init__.py +18 -0
- triton/experimental/gluon/language/_core.py +312 -0
- triton/experimental/gluon/language/_layouts.py +230 -0
- triton/experimental/gluon/language/_math.py +12 -0
- triton/experimental/gluon/language/_semantic.py +287 -0
- triton/experimental/gluon/language/_standard.py +47 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +40 -0
- triton/knobs.py +481 -0
- triton/language/__init__.py +39 -14
- triton/language/core.py +794 -537
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/libdevice.py +113 -104
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1706 -1770
- triton/language/standard.py +116 -51
- triton/runtime/autotuner.py +117 -59
- triton/runtime/build.py +76 -12
- triton/runtime/cache.py +18 -47
- triton/runtime/driver.py +32 -29
- triton/runtime/interpreter.py +72 -35
- triton/runtime/jit.py +146 -110
- triton/testing.py +16 -12
- triton/tools/disasm.py +3 -4
- triton/tools/tensor_descriptor.py +36 -0
- triton/windows_utils.py +14 -6
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
- triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
- triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
- triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import TypeVar, List, TYPE_CHECKING, Tuple
|
|
3
|
+
from functools import wraps
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from triton._C.libtriton.gluon_ir import GluonOpBuilder
|
|
7
|
+
from ._semantic import GluonSemantic
|
|
8
|
+
|
|
9
|
+
from ._layouts import SharedLayout, DistributedLayout
|
|
10
|
+
from triton._C.libtriton import ir
|
|
11
|
+
import triton.language.core as tl_core
|
|
12
|
+
from triton.language.core import (
|
|
13
|
+
constexpr,
|
|
14
|
+
base_value,
|
|
15
|
+
base_type,
|
|
16
|
+
dtype,
|
|
17
|
+
block_type, # TODO: block type with layout info
|
|
18
|
+
pointer_type,
|
|
19
|
+
void,
|
|
20
|
+
int1,
|
|
21
|
+
int8,
|
|
22
|
+
int16,
|
|
23
|
+
int32,
|
|
24
|
+
int64,
|
|
25
|
+
uint8,
|
|
26
|
+
uint16,
|
|
27
|
+
uint32,
|
|
28
|
+
uint64,
|
|
29
|
+
float8e5,
|
|
30
|
+
float8e5b16,
|
|
31
|
+
float8e4nv,
|
|
32
|
+
float8e4b8,
|
|
33
|
+
float8e4b15,
|
|
34
|
+
float16,
|
|
35
|
+
bfloat16,
|
|
36
|
+
float32,
|
|
37
|
+
float64,
|
|
38
|
+
_unwrap_if_constexpr,
|
|
39
|
+
_unwrap_shape,
|
|
40
|
+
tensor,
|
|
41
|
+
tuple,
|
|
42
|
+
tuple_type,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
# Names taken from `triton.language.core` and re-exported from this module.
# The loop after the `shared_memory_descriptor` class wraps each of them with
# `builtin` and installs the wrapper in this module's globals.
_IMPORT_FROM_TRITON: List[str] = [
    "expand_dims",
    "join",
    "load",
    "maximum",
    "minimum",
    "permute",
    "program_id",
    "reduce",
    "reshape",
    "split",
    "static_assert",
    "static_print",
    "store",
    "to_tensor",
    "where",
    "inline_asm_elementwise",
]

# Public API of the module.  Each name is listed exactly once; the previous
# list repeated "float8e4b8" and "tuple_type".
__all__ = [
    "constexpr",
    "base_value",
    "base_type",
    "dtype",
    "block_type",
    "pointer_type",
    "tuple_type",
    "void",
    "int1",
    "int8",
    "int16",
    "int32",
    "int64",
    "uint8",
    "uint16",
    "uint32",
    "uint64",
    "float8e5",
    "float8e5b16",
    "float8e4nv",
    "float8e4b8",
    "float8e4b15",
    "float16",
    "bfloat16",
    "float32",
    "float64",
    "_unwrap_if_constexpr",
    "tensor",
    "tuple",
    "thread_barrier",
    "arange",
    "full",
    "convert_layout",
    "allocate_shared_memory",
    "shared_memory_descriptor",
    "warp_specialize",
    *_IMPORT_FROM_TRITON,
]

# Generic type variable used to annotate the `builtin` decorator.
T = TypeVar("T")

# TODO: split these
# Attribute name that `builtin` sets to True on every wrapper it returns.
GLUON_BUILTIN = "__triton_builtin__"
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class distributed_type(block_type):
    """A `block_type` that also carries a `DistributedLayout`."""

    def __init__(self, element_ty: dtype, shape: List[int], layout):
        """Initialise the block type, then attach `layout` and rebuild `name`.

        Raises AssertionError when `layout` is not a `DistributedLayout`.
        """
        super().__init__(element_ty, shape)
        self.layout = layout
        # The name is rebuilt after the base initialiser so that it also
        # contains the layout.
        self.name = f"<{self.shape}, {self.element_ty}, {self.layout}>"
        assert isinstance(layout, DistributedLayout)

    def to_ir(self, builder: ir.builder) -> ir.type:
        """Lower element type and layout, then ask `builder` for the IR type."""
        ir_element_ty = self.element_ty.to_ir(builder)
        ir_layout = self.layout._to_ir(builder)
        return builder.get_distributed_ty(ir_element_ty, self.shape, ir_layout)

    def mangle(self) -> str:
        """Return `<scalar>S<d0_d1_...>SL<layout>L` for this type."""
        return f"{self.scalar.mangle()}S{'_'.join(str(dim) for dim in self.shape)}SL{self.layout.mangle()}L"

    def with_element_ty(self, scalar_ty: dtype) -> block_type:
        """Return a `distributed_type` with the same shape and layout but a new element type."""
        same_shape, same_layout = self.shape, self.layout
        return distributed_type(scalar_ty, same_shape, same_layout)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def builtin(fn: T) -> T:
    """Wrap `fn` so it refuses to run without a `_semantic` keyword argument.

    The returned wrapper carries the `GLUON_BUILTIN` attribute set to True.
    It raises ValueError when `_semantic` is absent from the keyword
    arguments or is None; otherwise it forwards the call to `fn` unchanged.
    """
    assert callable(fn)

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # A missing key and an explicit None are treated the same way.
        if kwargs.get("_semantic") is None:
            raise ValueError("Did you forget to add @triton.gluon.jit ? "
                             "(`_semantic` argument must be provided outside of JIT functions.)")
        return fn(*args, **kwargs)

    setattr(wrapper, GLUON_BUILTIN, True)
    return wrapper
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class shared_memory_descriptor_type(base_type):
    """Type object for `shared_memory_descriptor` values.

    Stores the element type, shape, shared layout and allocation shape that
    are passed to `builder.get_shared_mem_desc_ty` when lowering.
    """

    def __init__(self, element_ty, shape, layout, alloc_shape):
        """Store the type parameters.

        Raises AssertionError when `layout` is not a `SharedLayout`.
        """
        self.element_ty = element_ty
        self.shape = shape
        self.layout = layout
        self.alloc_shape = alloc_shape
        assert isinstance(layout, SharedLayout)

    def to_ir(self, builder: GluonOpBuilder) -> ir.type:
        """Return the IR type produced by `builder.get_shared_mem_desc_ty`."""
        return builder.get_shared_mem_desc_ty(
            self.element_ty.to_ir(builder),
            self.shape,
            self.layout._to_ir(builder),
            self.alloc_shape,
        )

    def _unflatten_ir(self, handles: List[ir.Value], cursor: int) -> Tuple[shared_memory_descriptor, int]:
        """Rebuild a descriptor from `handles[cursor]`; return it with the advanced cursor."""
        value = shared_memory_descriptor(handles[cursor], self.element_ty, self.shape, self.layout, self.alloc_shape)
        return value, cursor + 1

    def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None:
        """Append this type's single IR type to `out`."""
        out.append(self.to_ir(builder))

    def __str__(self) -> str:
        return f"shared_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}, {self.alloc_shape}>"

    def __eq__(self, other) -> bool:
        """Two descriptor types are equal when every stored parameter matches.

        The element type is part of the comparison: descriptors that differ
        only in dtype are different types.
        """
        return (type(self) is type(other) and self.element_ty == other.element_ty and self.shape == other.shape
                and self.layout == other.layout and self.alloc_shape == other.alloc_shape)

    def __ne__(self, other) -> bool:
        """Inverse of `__eq__` (this is the name Python looks up for `!=`)."""
        return not (self == other)

    # `__neq__` is not a special method; it is kept only as an alias so that
    # any code calling it by name keeps working.
    __neq__ = __ne__

    def mangle(self) -> str:
        """Return the mangled name built from element type, shape, layout and alloc shape."""
        shape_str = "_".join([str(s) for s in self.shape])
        return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{self.alloc_shape}ASMD"
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class shared_memory_descriptor(base_value):
    """Value wrapping an IR handle together with its `shared_memory_descriptor_type`.

    Every `@builtin` method forwards to the matching `_semantic` method after
    unwrapping constexpr arguments.  All of them accept `_semantic` as a
    keyword with a `None` default; the `builtin` wrapper rejects calls where
    it is missing or None.
    """

    def __init__(self, handle, element_ty, shape, layout, alloc_shape):
        self.handle = handle
        self.type = shared_memory_descriptor_type(element_ty, shape, layout, alloc_shape)

    def _flatten_ir(self, handles: List[ir.value]) -> None:
        """Append this value's single IR handle to `handles`."""
        handles.append(self.handle)

    @property
    def dtype(self):
        """Element type stored on `self.type`."""
        return self.type.element_ty

    @property
    def shape(self):
        """Shape stored on `self.type`."""
        return self.type.shape

    @property
    def rank(self):
        """Number of entries in `shape`."""
        return len(self.shape)

    @property
    def layout(self):
        """Layout stored on `self.type`."""
        return self.type.layout

    def __str__(self) -> str:
        return str(self.type)

    @builtin
    def load(self, layout, _semantic: GluonSemantic = None) -> tensor:
        """Forward to `_semantic.shared_load` with the unwrapped `layout`."""
        layout = _unwrap_if_constexpr(layout)
        return _semantic.shared_load(self, layout)

    @builtin
    def store(self, value, _semantic: GluonSemantic = None) -> None:
        """Forward to `_semantic.shared_store` with `value`."""
        return _semantic.shared_store(self, value)

    @builtin
    def slice(self, start, length, dim=0, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """Forward to `_semantic.memdesc_slice` with unwrapped `start`, `length` and `dim`."""
        start = _unwrap_if_constexpr(start)
        length = _unwrap_if_constexpr(length)
        dim = _unwrap_if_constexpr(dim)
        return _semantic.memdesc_slice(self, start, length, dim)

    @builtin
    def index(self, index, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """Forward to `_semantic.memdesc_index` with the unwrapped `index`."""
        index = _unwrap_if_constexpr(index)
        return _semantic.memdesc_index(self, index)

    @builtin
    def permute(self, order, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """Forward to `_semantic.memdesc_trans` with each entry of `order` unwrapped."""
        order = [_unwrap_if_constexpr(o) for o in order]
        return _semantic.memdesc_trans(self, order)

    @builtin
    def reshape(self, shape, layout, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """Forward to `_semantic.memdesc_reshape` with unwrapped `shape` entries and `layout`."""
        shape = [_unwrap_if_constexpr(s) for s in shape]
        layout = _unwrap_if_constexpr(layout)

        return _semantic.memdesc_reshape(self, shape, layout)

    @builtin
    def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """Forward to `_semantic.memdesc_reinterpret` with unwrapped `dtype`, `shape` entries and `layout`."""
        dtype = _unwrap_if_constexpr(dtype)
        shape = [_unwrap_if_constexpr(s) for s in shape]
        layout = _unwrap_if_constexpr(layout)

        return _semantic.memdesc_reinterpret(self, dtype, shape, layout)

    @builtin
    def _keep_alive(self, _semantic: GluonSemantic = None) -> None:
        """Forward to `_semantic.shared_dealloc`.

        NOTE(review): despite the method name, the call made here is
        `shared_dealloc`; presumably placing the dealloc at this point is
        what keeps the buffer live until then -- confirm against `_semantic`.
        """
        return _semantic.shared_dealloc(self)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
# Re-export each listed `triton.language.core` function from this module,
# wrapped with `builtin`.
for name in _IMPORT_FROM_TRITON:
    fn = getattr(tl_core, name)
    globals().update({name: builtin(fn)})
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
@builtin
def arange(start, end, layout, _semantic=None):
    """Forward to `_semantic.arange` after unwrapping constexpr `start`, `end` and `layout`."""
    return _semantic.arange(
        _unwrap_if_constexpr(start),
        _unwrap_if_constexpr(end),
        _unwrap_if_constexpr(layout),
    )
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
@builtin
def convert_layout(value, layout, _semantic=None):
    """Forward `value` and the unwrapped `layout` to `_semantic.convert_layout`."""
    target_layout = _unwrap_if_constexpr(layout)
    return _semantic.convert_layout(value, target_layout)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
@builtin
def full(shape, value, dtype, layout, _semantic=None):
    """Forward to `_semantic.full` after unwrapping the shape and the constexpr arguments."""
    return _semantic.full(
        _unwrap_shape(shape),
        _unwrap_if_constexpr(value),
        _unwrap_if_constexpr(dtype),
        _unwrap_if_constexpr(layout),
    )
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
@builtin
def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None):
    """Forward to `_semantic.allocate_shared` with unwrapped arguments.

    `shape` is unwrapped as a whole first and then entry by entry; `value`
    is passed through untouched.
    """
    element_ty = _unwrap_if_constexpr(element_ty)
    dims = [_unwrap_if_constexpr(dim) for dim in _unwrap_if_constexpr(shape)]
    layout = _unwrap_if_constexpr(layout)
    return _semantic.allocate_shared(element_ty, dims, layout, value)
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
@builtin
def warp_specialize(args, default_partition, worker_partitions, worker_num_warps, worker_num_regs,  #
                    _semantic=None, _generator=None):
    """Forward to `_semantic.warp_specialize`.

    Each entry of `worker_num_warps` and `worker_num_regs` is unwrapped from
    constexpr first; every other argument, including `_generator`, is passed
    through unchanged.
    """
    num_warps = list(map(_unwrap_if_constexpr, worker_num_warps))
    num_regs = list(map(_unwrap_if_constexpr, worker_num_regs))
    return _semantic.warp_specialize(args, default_partition, worker_partitions, num_warps, num_regs, _generator)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
@builtin
def thread_barrier(_semantic=None):
    """Return the result of `_semantic.debug_barrier()`."""
    barrier = _semantic.debug_barrier()
    return barrier
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import List, Optional
|
|
3
|
+
from triton.language.core import _unwrap_if_constexpr, _unwrap_shape
|
|
4
|
+
|
|
5
|
+
__all__ = [
|
|
6
|
+
"BlockedLayout",
|
|
7
|
+
"SliceLayout",
|
|
8
|
+
"DistributedLinearLayout",
|
|
9
|
+
"NVMMASharedLayout",
|
|
10
|
+
"SwizzledSharedLayout",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _realize_cta_layout(rank, ctas_per_cga, cta_split_num, cta_order):
|
|
15
|
+
ctas_per_cga = ctas_per_cga or [1] * rank
|
|
16
|
+
cta_split_num = cta_split_num or [1] * rank
|
|
17
|
+
cta_order = cta_order or list(reversed(range(rank)))
|
|
18
|
+
return ctas_per_cga, cta_split_num, cta_order
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class DistributedLayout:
    """Marker base class for the distributed layouts defined in this module."""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(frozen=True)
class BlockedLayout(DistributedLayout):
    """Frozen description of a blocked layout, lowered via `builder.get_blocked_layout`.

    Every field is unwrapped from constexpr on construction.  The four
    required lists must have equal length; the optional CTA lists, when
    given, must have that length too.
    """
    size_per_thread: List[int]
    threads_per_warp: List[int]
    warps_per_cta: List[int]
    order: List[int]
    # Optional CTA fields; `_to_ir` fills in defaults through `_realize_cta_layout`.
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        """Unwrap every field and check that the list lengths agree."""
        # The dataclass is frozen, so assignments go through the parent's
        # `__setattr__` rather than the generated one.
        for field_name in ("size_per_thread", "threads_per_warp", "warps_per_cta", "order", "ctas_per_cga",
                           "cta_split_num", "cta_order"):
            super().__setattr__(field_name, _unwrap_if_constexpr(getattr(self, field_name)))

        rank = len(self.size_per_thread)
        for required in (self.threads_per_warp, self.warps_per_cta, self.order):
            assert len(required) == rank
        for optional in (self.ctas_per_cga, self.cta_split_num, self.cta_order):
            assert optional is None or len(optional) == rank

    def _to_ir(self, builder):
        """Return `builder.get_blocked_layout(...)` with CTA defaults filled in."""
        cta_fields = _realize_cta_layout(len(self.size_per_thread), self.ctas_per_cga, self.cta_split_num,
                                         self.cta_order)
        return builder.get_blocked_layout(
            self.size_per_thread,
            self.threads_per_warp,
            self.warps_per_cta,
            self.order,
            *cta_fields,
        )

    def mangle(self) -> str:
        """Return the seven fields, each joined with `_`, delimited by `B`."""

        def stringify(x):
            return "" if x is None else "_".join(map(str, x))

        parts = [
            stringify(self.size_per_thread),
            stringify(self.threads_per_warp),
            stringify(self.warps_per_cta),
            stringify(self.order),
            stringify(self.ctas_per_cga),
            stringify(self.cta_split_num),
            stringify(self.cta_order),
        ]
        return "B" + "B".join(parts) + "B"
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@dataclass(frozen=True)
class SliceLayout(DistributedLayout):
    """Frozen pair of a dimension and a parent layout, lowered via `builder.get_slice_layout`."""
    dim: int
    parent: DistributedLayout

    def __post_init__(self):
        """Unwrap both fields from constexpr."""
        # Frozen dataclass: assign through the parent's `__setattr__`.
        for field_name in ("dim", "parent"):
            super().__setattr__(field_name, _unwrap_if_constexpr(getattr(self, field_name)))

    def _to_ir(self, builder):
        """Return `builder.get_slice_layout(dim, <lowered parent>)`."""
        return builder.get_slice_layout(self.dim, self.parent._to_ir(builder))

    def mangle(self) -> str:
        """Return `SL<dim>_<parent mangle>SL`."""
        dim_part = f"{self.dim}"
        parent_part = self.parent.mangle()
        return "SL" + dim_part + "_" + parent_part + "SL"
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@dataclass(frozen=True)
class DistributedLinearLayout(DistributedLayout):
    """Frozen set of basis lists plus a shape, lowered via `builder.get_distributed_linear_layout`.

    Every basis in each of the four basis lists must have `len(shape)` entries.
    """
    reg_bases: List[List[int]]
    lane_bases: List[List[int]]
    warp_bases: List[List[int]]
    block_bases: List[List[int]]
    shape: List[int]

    def __post_init__(self):
        """Pass every field through `_unwrap_shape` and check each basis length."""
        # Frozen dataclass: assign through the parent's `__setattr__`.
        for field_name in ("reg_bases", "lane_bases", "warp_bases", "block_bases", "shape"):
            super().__setattr__(field_name, _unwrap_shape(getattr(self, field_name)))

        rank = len(self.shape)
        for bases in (self.reg_bases, self.lane_bases, self.warp_bases, self.block_bases):
            for basis in bases:
                assert len(basis) == rank

    def _to_ir(self, builder):
        """Return `builder.get_distributed_linear_layout(...)` for the stored fields."""
        ir_args = (self.reg_bases, self.lane_bases, self.warp_bases, self.block_bases, self.shape)
        return builder.get_distributed_linear_layout(*ir_args)

    def mangle(self) -> str:
        """Return the five fields formatted and joined with `_`, wrapped in `DLL`."""
        fields = (self.reg_bases, self.lane_bases, self.warp_bases, self.block_bases, self.shape)
        return "DLL" + "_".join(f"{field}" for field in fields) + "DLL"
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class SharedLayout:
    """Marker base class for the shared-memory layouts defined in this module."""
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
@dataclass(frozen=True)
class NVMMASharedLayout(SharedLayout):
    """Frozen description of an NVMMA shared layout, lowered via `builder.get_nvmma_shared_layout`.

    Every field is unwrapped from constexpr on construction.
    `element_bitwidth` must be one of 8, 16, 32, 64 and `swizzle_byte_width`
    one of 0, 32, 64, 128.  The optional CTA lists, when given, must have
    `rank` entries.
    """
    swizzle_byte_width: int
    element_bitwidth: int
    rank: int
    transposed: bool = False
    fp4_padded: bool = False
    # Optional CTA fields; `_to_ir` fills in defaults through `_realize_cta_layout`.
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        """Unwrap every field and validate widths and CTA list lengths."""
        # Frozen dataclass: assign through the parent's `__setattr__`.
        super().__setattr__("swizzle_byte_width", _unwrap_if_constexpr(self.swizzle_byte_width))
        super().__setattr__("element_bitwidth", _unwrap_if_constexpr(self.element_bitwidth))
        super().__setattr__("rank", _unwrap_if_constexpr(self.rank))
        super().__setattr__("transposed", _unwrap_if_constexpr(self.transposed))
        super().__setattr__("fp4_padded", _unwrap_if_constexpr(self.fp4_padded))
        super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga))
        super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num))
        super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))

        assert self.element_bitwidth in [8, 16, 32, 64]
        assert self.swizzle_byte_width in [0, 32, 64, 128]
        rank = self.rank
        assert self.ctas_per_cga is None or len(self.ctas_per_cga) == rank
        assert self.cta_split_num is None or len(self.cta_split_num) == rank
        assert self.cta_order is None or len(self.cta_order) == rank

    def _to_ir(self, builder):
        """Return `builder.get_nvmma_shared_layout(...)` with CTA defaults filled in."""
        ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(self.rank, self.ctas_per_cga, self.cta_split_num,
                                                                     self.cta_order)
        return builder.get_nvmma_shared_layout(
            self.swizzle_byte_width,
            self.element_bitwidth,
            self.transposed,
            self.fp4_padded,
            ctas_per_cga,
            cta_split_num,
            cta_order,
        )

    def mangle(self) -> str:
        """Return a mangled name covering every field of the layout.

        `rank` and the three CTA lists are included so that two layouts
        differing only in those fields do not share a mangled name, matching
        what `SwizzledSharedLayout.mangle` does for its CTA fields.
        """

        def stringify(x):
            if x is None:
                return ""
            return "_".join(map(str, x))

        return (f"NVMMA_{self.swizzle_byte_width}_{self.element_bitwidth}_{self.transposed}_{self.fp4_padded}"
                f"_{self.rank}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}"
                f"_{stringify(self.cta_order)}_NVMMA")
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@dataclass(frozen=True, eq=True)
class SwizzledSharedLayout(SharedLayout):
    """Frozen description of a swizzled shared layout, lowered via `builder.get_swizzled_shared_layout`.

    Every field is unwrapped from constexpr on construction.  The optional
    CTA lists, when given, must have the same length as `order`.
    """
    vec: int
    per_phase: int
    max_phase: int
    order: List[int]
    # Optional CTA fields; `_to_ir` fills in defaults through `_realize_cta_layout`.
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        """Unwrap every field and check the CTA list lengths against `order`."""
        # Frozen dataclass: assign through the parent's `__setattr__`.
        for field_name in ("vec", "per_phase", "max_phase", "order", "ctas_per_cga", "cta_split_num", "cta_order"):
            super().__setattr__(field_name, _unwrap_if_constexpr(getattr(self, field_name)))

        rank = len(self.order)
        for optional in (self.ctas_per_cga, self.cta_split_num, self.cta_order):
            assert optional is None or len(optional) == rank

    def _to_ir(self, builder):
        """Return `builder.get_swizzled_shared_layout(...)` with CTA defaults filled in."""
        cta_fields = _realize_cta_layout(len(self.order), self.ctas_per_cga, self.cta_split_num, self.cta_order)
        return builder.get_swizzled_shared_layout(
            _unwrap_if_constexpr(self.vec),
            _unwrap_if_constexpr(self.per_phase),
            _unwrap_if_constexpr(self.max_phase),
            self.order,
            *cta_fields,
        )

    def mangle(self) -> str:
        """Return all seven fields joined with `_` between `SSS_` and `_SSS`."""

        def stringify(x):
            return "" if x is None else "_".join(map(str, x))

        parts = [
            f"{self.vec}",
            f"{self.per_phase}",
            f"{self.max_phase}",
            stringify(self.order),
            stringify(self.ctas_per_cga),
            stringify(self.cta_split_num),
            stringify(self.cta_order),
        ]
        return "SSS_" + "_".join(parts) + "_SSS"
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
# flake8: noqa
|
|
2
|
+
import triton.language.math as tl_math
|
|
3
|
+
from ._core import builtin
|
|
4
|
+
|
|
5
|
+
# Names looked up on `triton.language.math` and re-exported from this module.
__all__ = [
    "umulhi",
    "exp",
    "exp2",
    "fma",
    "log",
    "log2",
    "cos",
    "rsqrt",
    "sin",
    "sqrt",
    "sqrt_rn",
    "abs",
    "fdiv",
    "div_rn",
    "erf",
    "floor",
    "ceil",
]

# Wrap each function with `builtin` and install it in this module's globals.
for name in __all__:
    fn = getattr(tl_math, name)
    globals().update({name: builtin(fn)})
|