triton-windows 3.3.0.post19__cp312-cp312-win_amd64.whl → 3.4.0.post20__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +4 -1
- triton/_filecheck.py +87 -0
- triton/_internal_testing.py +26 -15
- triton/_utils.py +110 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +112 -78
- triton/backends/amd/driver.c +5 -2
- triton/backends/amd/driver.py +149 -47
- triton/backends/compiler.py +7 -21
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +92 -93
- triton/backends/nvidia/driver.c +90 -98
- triton/backends/nvidia/driver.py +303 -128
- triton/compiler/code_generator.py +212 -111
- triton/compiler/compiler.py +110 -25
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +4 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +99 -0
- triton/experimental/gluon/language/__init__.py +18 -0
- triton/experimental/gluon/language/_core.py +312 -0
- triton/experimental/gluon/language/_layouts.py +230 -0
- triton/experimental/gluon/language/_math.py +12 -0
- triton/experimental/gluon/language/_semantic.py +287 -0
- triton/experimental/gluon/language/_standard.py +47 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +40 -0
- triton/knobs.py +481 -0
- triton/language/__init__.py +39 -14
- triton/language/core.py +794 -537
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/libdevice.py +113 -104
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1706 -1770
- triton/language/standard.py +116 -51
- triton/runtime/autotuner.py +117 -59
- triton/runtime/build.py +76 -12
- triton/runtime/cache.py +18 -47
- triton/runtime/driver.py +32 -29
- triton/runtime/interpreter.py +72 -35
- triton/runtime/jit.py +146 -110
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/testing.py +16 -12
- triton/tools/disasm.py +3 -4
- triton/tools/tensor_descriptor.py +36 -0
- triton/windows_utils.py +14 -6
- {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
- triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
- {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +1 -1
- triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
- triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.0.post19.dist-info/RECORD +0 -253
- triton_windows-3.3.0.post19.dist-info/top_level.txt +0 -14
triton/experimental/gluon/language/_semantic.py
@@ -0,0 +1,287 @@
+from typing import Sequence, List, TypeVar, Tuple, Callable
+from triton.language.semantic import TritonSemantic
+from . import _core as ttgl
+from ._layouts import SliceLayout
+from triton._C.libtriton.gluon_ir import GluonOpBuilder
+from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values
+
+TensorTy = TypeVar("TensorTy")
+
+
+def _check(cond: bool, msg_fn: Callable[[], str], category=ValueError):
+    if not cond:
+        raise category(msg_fn())
+
+
+class GluonSemantic(TritonSemantic[TensorTy]):
+    tensor = ttgl.tensor
+    lang = ttgl
+
+    builder: GluonOpBuilder
+
+    def __init__(self, builder: GluonOpBuilder):
+        self.builder = builder
+
+    def _wrap_tensor_infer_layout(self, tensor):
+        ty = ttgl.distributed_type(tensor.type.scalar, tensor.shape,
+                                   self.builder.get_gluon_layout_from_tensor(tensor.handle))
+        return self.tensor(tensor.handle, ty)
+
+    def _broadcast_shapes(self, lhs_shape: List[int], rhs_shape: List[int]):
+        if len(lhs_shape) != len(rhs_shape):
+            raise ValueError(f"Cannot broadcast, rank mismatch: {lhs_shape}, {rhs_shape}")
+
+        ret_shape = []
+        for i, left in enumerate(lhs_shape):
+            right = rhs_shape[i]
+            if left == 1:
+                ret_shape.append(right)
+            elif (right == 1) or (right == left):
+                ret_shape.append(left)
+            else:
+                raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
+                                 "at index " + str(i) + ": " + str(left) + " and " + str(right))
+        return ret_shape
+
+    def expand_dims(self, input: TensorTy, axis: int) -> TensorTy:
+        dst_shape = [ttgl._unwrap_if_constexpr(x) for x in input.shape]
+        dst_shape.insert(axis, 1)
+
+        if axis < 0:
+            axis += len(input.shape)
+
+        _check(isinstance(input.type, ttgl.distributed_type),
+               lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}")
+        layout = input.type.layout
+        _check(isinstance(layout, SliceLayout),
+               lambda: f"expected expand_dims input to have a SliceLayout, but got: {layout}")
+        _check(layout.dim == axis,
+               lambda: f"expected expand_dims input layout to be sliced in axis {axis} but got {layout.dim}")
+
+        ret_ty = ttgl.distributed_type(input.type.scalar, dst_shape, layout.parent)
+        handle = self.builder.create_expand_dims(input.handle, axis, ret_ty.to_ir(self.builder))
+        return self.tensor(handle, ret_ty)
+
+    def join(self, a: TensorTy, b: TensorTy) -> TensorTy:
+        a, b = self.broadcast_impl_value(a, b)
+        _check(a.shape != [], "Cannot join scalars in gluon")
+        value = super().join(a, b)
+        return self._wrap_tensor_infer_layout(value)
+
+    def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]:
+        lhs, rhs = super().split(a)
+        return self._wrap_tensor_infer_layout(lhs), self._wrap_tensor_infer_layout(rhs)
+
+    def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy:
+        value = super().permute(input, dims)
+        return self._wrap_tensor_infer_layout(value)
+
+    def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy:
+        _check(isinstance(input.type, ttgl.distributed_type),
+               lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}")
+        src_shape = input.type.get_block_shapes()
+        _check(len(src_shape) == len(shape), lambda: f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
+        if shape == src_shape:
+            return input
+        for i, item in enumerate(src_shape):
+            if shape[i] != item and item != 1:
+                raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
+                                 f" must match the existing size ({item}) at non-singleton dimension"
+                                 f" {i}: {src_shape}, {shape}")
+        ret_ty = ttgl.distributed_type(input.type.scalar, shape, input.type.layout)
+        handle = self.builder.create_broadcast(input.handle, ret_ty.to_ir(self.builder))
+        return self.tensor(handle, ret_ty)
+
+    def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy:
+        lhs_ty = lhs.type
+        rhs_ty = rhs.type
+
+        if not lhs_ty.is_block() or not rhs_ty.is_block():
+            return super().broadcast_impl_value(lhs, rhs)
+
+        _check(isinstance(lhs_ty, ttgl.distributed_type),
+               lambda: f"expected broadcast left input to be a distributed_type but got: {lhs_ty!r}")
+        _check(isinstance(rhs_ty, ttgl.distributed_type),
+               lambda: f"expected broadcast right input to be a distributed_type but got: {rhs_ty!r}")
+
+        lhs_shape = lhs_ty.get_block_shapes()
+        rhs_shape = rhs_ty.get_block_shapes()
+        ret_shape = self._broadcast_shapes(lhs_shape, rhs_shape)
+        if lhs_ty.layout != rhs_ty.layout:
+            raise ValueError(f"Layout mismatch in broadcast: {lhs_ty.layout} vs {rhs_ty.layout}")
+
+        lhs = self.broadcast_impl_shape(lhs, ret_shape)
+        rhs = self.broadcast_impl_shape(rhs, ret_shape)
+        return lhs, rhs
+
+    def arange(self, start, end, layout):
+        shape = [end - start]
+        ret_ty = ttgl.distributed_type(ttgl.int32, shape, layout)
+        return super().arange(start, end, ret_ty=ret_ty)
+
+    def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool):
+        _check(not can_reorder, "can_reorder is not supported in gluon")
+        value = super().reshape(input, dst_shape, can_reorder)
+        return self._wrap_tensor_infer_layout(value)
+
+    def splat(self, value, shape, layout):
+        ret_ty = ttgl.distributed_type(value.dtype, shape, layout)
+        handle = self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle)
+        return ttgl.tensor(handle, ret_ty)
+
+    def full(self, shape, value, dtype, layout):
+        scalar = self.make_scalar(value, dtype)
+        return self.splat(scalar, shape, layout)
+
+    def convert_layout(self, value, layout):
+        ty = value.type
+        _check(isinstance(ty, ttgl.distributed_type),
+               lambda: f"expected convert_layout input to be a distributed_type but got: {ty!r}")
+        ret_ty = ttgl.distributed_type(ty.element_ty, ty.shape, layout)
+        handle = self.builder.create_convert_layout(ret_ty.to_ir(self.builder), value.handle)
+        return ttgl.tensor(handle, ret_ty)
+
+    def allocate_shared(self, element_ty, shape, layout, value):
+        ty = ttgl.shared_memory_descriptor_type(element_ty, shape, layout, shape)
+        if value is not None:
+            handle = self.builder.create_local_alloc(ty.to_ir(self.builder), value.handle)
+        else:
+            handle = self.builder.create_local_alloc(ty.to_ir(self.builder))
+        return ttgl.shared_memory_descriptor(handle, element_ty, shape, layout, shape)
+
+    def shared_load(self, mem_desc, layout):
+        ret_ty = ttgl.distributed_type(mem_desc.dtype, mem_desc.shape, layout)
+        handle = self.builder.create_local_load(ret_ty.to_ir(self.builder), mem_desc.handle)
+        return ttgl.tensor(handle, ret_ty)
+
+    def shared_store(self, mem_desc, value):
+        self.builder.create_local_store(mem_desc.handle, value.handle)
+
+    def shared_dealloc(self, mem_desc):
+        self.builder.create_local_dealloc(mem_desc.handle)
+
+    def _memdesc_subview(self, mem_desc, offsets, shape):
+        layout = mem_desc.layout
+        ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
+        builder = self.builder
+        handle = builder.create_memdesc_subview(ty.to_ir(builder), mem_desc.handle, offsets)
+        return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
+
+    def memdesc_slice(self, mem_desc, start, length, dim):
+        offsets = [self.builder.get_int32(0)] * mem_desc.rank
+        offsets[dim] = self.to_tensor(start).handle
+        shape = list(mem_desc.shape)
+        shape[dim] = length
+        return self._memdesc_subview(mem_desc, offsets, shape)
+
+    def memdesc_index(self, mem_desc, index):
+        shape = mem_desc.shape[1:]
+        offsets = [self.builder.get_int32(0)] * mem_desc.rank
+        offsets[0] = self.to_tensor(index).handle
+        return self._memdesc_subview(mem_desc, offsets, shape)
+
+    def memdesc_trans(self, mem_desc, order):
+        assert len(order) == len(
+            mem_desc.shape), f"source rank ({mem_desc.rank}) and order length ({len(order)}) must match"
+
+        shape = [mem_desc.shape[i] for i in order]
+        alloc_shape = mem_desc.type.alloc_shape
+        new_alloc_shape = alloc_shape[:len(alloc_shape) - mem_desc.rank]
+        new_alloc_shape += [alloc_shape[len(alloc_shape) - mem_desc.rank:][i] for i in order]
+
+        handle = self.builder.create_memdesc_trans(mem_desc.handle, order)
+        layout = self.builder.get_gluon_layout_from_memdesc(handle)
+        return ttgl.shared_memory_descriptor(handle, element_ty=mem_desc.dtype, shape=shape,
+                                             alloc_shape=new_alloc_shape, layout=layout)
+
+    def memdesc_reshape(self, mem_desc, shape, layout):
+        ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
+        handle = self.builder.create_memdesc_reshape(ty.to_ir(self.builder), mem_desc.handle)
+        return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
+
+    def memdesc_reinterpret(self, mem_desc, dtype, shape, layout):
+        ty = ttgl.shared_memory_descriptor_type(dtype, shape, layout, shape)
+        handle = self.builder.create_memdesc_reinterpret(ty.to_ir(self.builder), mem_desc.handle)
+        return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
+
+    def wrap_tensor(self, x, scalar_ty, ret_shape, layout):
+        if ret_shape:
+            res_ty = ttgl.distributed_type(scalar_ty, ret_shape, layout)
+        else:
+            res_ty = scalar_ty
+        return self.tensor(x, res_ty)
+
+    @staticmethod
+    def _check_same_layout(xs):
+        for x in xs:
+            _check(isinstance(x.type, ttgl.distributed_type), lambda: f"expected distributed_type but got: {x.type!r}")
+        layouts = [x.type.layout for x in xs]
+        l0 = layouts[0]
+        _check(all(l == l0 for l in layouts[1:]),
+               lambda: f"Expected inputs to have matching layouts, but got: {layouts}")
+
+    def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]:
+        _check(axis is not None, lambda: "All-reduce is not yet implemented in gluon")
+        # get result shape
+        shape = inputs[0].type.shape
+        rank = len(shape)
+        _check(0 <= axis < rank, lambda: f"expected reduction axis to be in the range [0, {rank}) but got {axis}")
+        self._check_same_layout(inputs)
+        ret_shape = [s for i, s in enumerate(shape) if i != axis]
+        ret_layout = SliceLayout(axis, inputs[0].type.layout)
+        assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape"
+
+        reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis)
+        region_builder_fn(reduce_op)
+        assert reduce_op.verify()
+
+        return tuple(
+            self.wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape, ret_layout)
+            for i in range(len(inputs)))
+
+    def warp_specialize(self, args, default_partition, worker_partitions, worker_num_warps: Sequence[int],
+                        worker_num_regs: Sequence[int], generator):
+        num_partitions = len(worker_partitions)
+        assert num_partitions == len(
+            worker_num_warps
+        ), f"warp specialize got {num_partitions} partitions but {len(worker_num_warps)} warp counts"
+        assert num_partitions == len(
+            worker_num_regs
+        ), f"warp specialize got {num_partitions} partitions but {len(worker_num_regs)} register counts"
+
+        builder = self.builder
+        insert_pt = builder.get_insertion_point()
+
+        # Emit the default partition to get the result types.
+        default_block = builder.new_block()
+        builder.set_insertion_point_to_start(default_block)
+        default_results = generator.call_JitFunction(default_partition, args, kwargs={})
+        mlir_results = []
+        if default_results is not None:
+            mlir_results = flatten_values_to_ir(default_results)
+        builder.create_warp_yield(mlir_results)
+        result_types = [r.get_type() for r in mlir_results]
+
+        # Create the warp specialize op.
+        builder.restore_insertion_point(insert_pt)
+        mlir_args = flatten_values_to_ir(args)
+        ws_op = builder.create_warp_specialize(result_types, mlir_args, worker_num_warps)
+        ws_op.get_default_region().push_back(default_block)
+        ws_op.set_requested_registers(worker_num_regs)
+
+        # Emit the partition regions.
+        builder.create_block_with_parent(ws_op.get_partition_op_holder(), [])
+        partitions_op = builder.create_warp_specialize_partitions(num_partitions)
+        arg_types = [arg.get_type() for arg in mlir_args]
+        for i in range(num_partitions):
+            block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types)
+            block_args = [block.get_argument(j) for j in range(len(mlir_args))]
+            block_args = unflatten_ir_values(block_args, [arg.type for arg in args])
+            generator.call_JitFunction(worker_partitions[i], block_args, kwargs={})
+            builder.create_warp_return()
+
+        builder.set_insertion_point_after(ws_op.get_operation())
+        mlir_results = [ws_op.get_result(i) for i in range(len(result_types))]
+        if default_results is None:
+            return
+        return tuple(unflatten_ir_values(mlir_results, [r.type for r in default_results]))
triton/experimental/gluon/language/_standard.py
@@ -0,0 +1,47 @@
+# flake8: noqa
+import triton
+import triton.language.standard as tl_standard
+from .._runtime import jit
+from triton import knobs
+from . import _core as ttgl
+
+_IMPORT_FROM_TRITON = [
+    "sum",
+    "max",
+    "min",
+    "reduce_or",
+    "xor_sum",
+]
+
+__all__ = [
+    "full_like",
+    "zeros",
+    "zeros_like",
+    *_IMPORT_FROM_TRITON,
+]
+
+for name in _IMPORT_FROM_TRITON:
+    # Convert JITFunction -> GluonJitFunction
+    fn = getattr(tl_standard, name)
+    assert knobs.runtime.interpret or isinstance(fn, triton.runtime.JITFunction)
+    globals()[name] = jit(fn.fn)
+
+
+@jit
+def zeros(shape, dtype, layout):
+    return ttgl.full(shape, 0, dtype, layout)
+
+
+@jit
+def full_like(input, value, shape=None, dtype=None, layout=None):
+    return ttgl.full(
+        input.shape if shape is None else shape,
+        value,
+        input.dtype if dtype is None else dtype,
+        input.type.layout if layout is None else layout,
+    )
+
+
+@jit
+def zeros_like(input, shape=None, dtype=None, layout=None):
+    return full_like(input, 0, shape=shape, dtype=dtype, layout=layout)
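For orientation, a minimal usage sketch of the new standard helpers (not part of the diff). It assumes the gluon.jit decorator exported by triton/experimental/gluon/__init__.py and the BlockedLayout added in the new _layouts.py, both listed in this release but not shown in this section, and that constexpr, float32 and BlockedLayout are re-exported at the gluon.language namespace; example_kernel is a hypothetical name.

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language import _standard as ttgl_std  # zeros / full_like / zeros_like from the hunk above

@gluon.jit
def example_kernel():  # hypothetical kernel, shown for illustration only
    # BlockedLayout comes from the new _layouts.py (+230 lines, not shown here);
    # the keyword names below are assumed from that module.
    layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32],
                                                warps_per_cta=[4], order=[0])
    x = ttgl_std.zeros([128], ttgl.float32, layout)  # thin wrapper over ttgl.full(shape, 0, dtype, layout)
    y = ttgl_std.zeros_like(x)                       # reuses x's shape, dtype and layout

Note how the for-loop at the top of _standard.py re-wraps the tl.standard reductions (sum, max, min, reduce_or, xor_sum) with the Gluon jit decorator so they can be called from Gluon kernels alongside these layout-aware helpers.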
triton/experimental/gluon/language/nvidia/blackwell/__init__.py
@@ -0,0 +1,202 @@
+from __future__ import annotations
+from typing import Optional, Tuple, List, TYPE_CHECKING
+
+from dataclasses import dataclass
+from triton.experimental.gluon.language import _core as ttgl
+from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr
+from triton.experimental.gluon.language._semantic import _check
+
+from . import tma
+from ..hopper import mbarrier, fence_async_shared
+
+if TYPE_CHECKING:
+    from triton._C.libtriton.gluon_ir import GluonOpBuilder
+    from triton._C.libtriton import gluon_ir as ir
+    from ..._semantic import GluonSemantic
+
+__all__ = [
+    "allocate_tensor_memory",
+    "fence_async_shared",
+    "mbarrier",
+    "tensor_memory_descriptor",
+    "TensorMemoryLayout",
+    "tma",
+]
+
+
+@dataclass(frozen=True, eq=True)
+class TensorMemoryLayout:
+    block: Tuple[int, int]
+    unpacked: bool
+    cta_split_num: Optional[Tuple[int, int]] = None
+
+    def __post_init__(self):
+        assert len(self.block) == 2
+        assert self.cta_split_num is None or len(self.cta_split_num) == 2
+
+    def _to_ir(self, builder):
+        cta_split_num = self.cta_split_num or [1, 1]
+        return builder.get_tensor_memory_layout(
+            self.block,
+            self.unpacked,
+            cta_split_num,
+        )
+
+    def mangle(self) -> str:
+        block_str = f"{self.block[0]}x{self.block[1]}"
+        unpacked_str = "U" if self.unpacked else "P"
+        cta_split_str = f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else ""
+        return f"TL{block_str}{unpacked_str}{cta_split_str}TL"
+
+
+class tensor_memory_descriptor_type(base_type):
+
+    def __init__(self, element_ty, shape, layout, alloc_shape):
+        self.element_ty = element_ty
+        self.shape = shape
+        self.layout = layout
+        self.alloc_shape = alloc_shape
+        assert isinstance(layout, TensorMemoryLayout)
+
+    def to_ir(self, builder: GluonOpBuilder) -> None:
+        return builder.get_tensor_mem_desc_ty(
+            self.element_ty.to_ir(builder),
+            self.shape,
+            self.layout._to_ir(builder),
+            self.alloc_shape,
+        )
+
+    def _unflatten_ir(self, handles: List[ir.Value], cursor: int) -> Tuple[tensor_memory_descriptor, int]:
+        value = tensor_memory_descriptor(handles[cursor], self.element_ty, self.shape, self.layout, self.alloc_shape)
+        return value, cursor + 1
+
+    def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None:
+        out.append(self.to_ir(builder))
+
+    def __str__(self) -> str:
+        return f"tensor_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}>"
+
+    def __eq__(self, other) -> bool:
+        return (type(self) is type(other) and self.shape == other.shape and self.layout == other.layout
+                and self.alloc_shape == other.alloc_shape)
+
+    def __neq__(self, other) -> bool:
+        return not (self == other)
+
+    def mangle(self) -> str:
+        shape_str = "_".join([str(s) for s in self.shape])
+        return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{self.alloc_shape}ASMD"
+
+
+class tensor_memory_descriptor(base_value):
+
+    def __init__(self, handle, element_ty, shape, layout, alloc_shape):
+        self.handle = handle
+        self.type = tensor_memory_descriptor_type(element_ty, shape, layout, alloc_shape)
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        handles.append(self.handle)
+
+    @property
+    def dtype(self):
+        return self.type.element_ty
+
+    @property
+    def shape(self):
+        return self.type.shape
+
+    @property
+    def rank(self):
+        return len(self.shape)
+
+    @property
+    def layout(self):
+        return self.type.layout
+
+    def __str__(self) -> str:
+        return str(self.type)
+
+    @builtin
+    def load(self, layout, _semantic: GluonSemantic) -> ttgl.tensor:
+        layout = _unwrap_if_constexpr(layout)
+        ret_ty = ttgl.distributed_type(self.dtype, self.shape, layout)
+        builder = _semantic.builder
+        handle = builder.create_tmem_load(ret_ty.to_ir(builder), self.handle)
+        return ttgl.tensor(handle, ret_ty)
+
+    @builtin
+    def store(self, value, pred=True, _semantic: GluonSemantic = None) -> None:
+        pred = _unwrap_if_constexpr(pred)
+        pred = _semantic.to_tensor(pred)
+        _semantic.builder.create_tmem_store(self.handle, value.handle, pred.handle)
+
+    @builtin
+    def slice(self, start, length, _semantic: GluonSemantic) -> None:
+        start = _unwrap_if_constexpr(start)
+        length = _unwrap_if_constexpr(length)
+        _check(isinstance(start, int), lambda: "start must be a constant int")
+        _check(isinstance(length, int), lambda: "length must be a constant int")
+        shape = self.shape[:-1] + [length]
+        layout = self.type.layout
+        layout = TensorMemoryLayout((layout.block[0], min(layout.block[1], length)), layout.unpacked,
+                                    layout.cta_split_num)
+        ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape)
+        builder = _semantic.builder
+        ret.handle = builder.create_tmem_subslice(ret.type.to_ir(builder), self.handle, start)
+        return ret
+
+    @builtin
+    def index(self, index, _semantic: GluonSemantic = None) -> tensor_memory_descriptor:
+        index = _semantic.to_tensor(index)
+        builder = _semantic.builder
+        offsets = [builder.get_int32(0)] * self.rank
+        offsets[0] = index.handle
+        shape = self.shape[1:]
+        layout = self.layout
+        ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape)
+        ret.handle = builder.create_memdesc_subview(ret.type.to_ir(builder), self.handle, offsets)
+        return ret
+
+    @builtin
+    def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> tensor_memory_descriptor:
+        dtype = _unwrap_if_constexpr(dtype)
+        shape = [_unwrap_if_constexpr(s) for s in shape]
+        layout = _unwrap_if_constexpr(layout)
+
+        ty = tensor_memory_descriptor_type(dtype, shape, layout, shape)
+        handle = _semantic.builder.create_memdesc_reinterpret(ty.to_ir(_semantic.builder), self.handle)
+        return tensor_memory_descriptor(handle, **ty.__dict__)
+
+
+@builtin
+def allocate_tensor_memory(element_ty, shape, layout, value=None, _semantic=None):
+    element_ty = _unwrap_if_constexpr(element_ty)
+    shape = _unwrap_if_constexpr(shape)
+    layout = _unwrap_if_constexpr(layout)
+    value = value.handle if value is not None else None
+
+    ty = tensor_memory_descriptor_type(element_ty, shape, layout, shape)
+    builder = _semantic.builder
+    handle = builder.create_tmem_alloc(ty.to_ir(builder), value)
+    return tensor_memory_descriptor(handle, element_ty, shape, layout, shape)
+
+
+@builtin
+def tcgen05_mma(a, b, acc, *, use_acc=True, pred=True, mbarriers=None, mbarrier_preds=None, _semantic=None):
+    use_acc = _semantic.to_tensor(use_acc)
+    pred = _semantic.to_tensor(pred)
+
+    if mbarriers is None:
+        assert mbarrier_preds is None
+        mbarriers = []
+        mbarrier_preds = []
+    else:
+        mbarriers = [bar.handle for bar in mbarriers]
+        if mbarrier_preds is None:
+            true = _semantic.to_tensor(True)
+            mbarrier_preds = [true] * len(mbarriers)
+        else:
+            mbarrier_preds = _semantic._convert_to_ir_values(mbarrier_preds, require_i64=False)
+
+    _semantic.builder.create_tcgen05_mma(a.handle, b.handle, acc.handle, use_acc.handle, pred.handle, mbarriers,
+                                         mbarrier_preds)
triton/experimental/gluon/language/nvidia/blackwell/tma.py
@@ -0,0 +1,32 @@
+from triton.experimental.gluon.language._core import builtin
+from triton.experimental.gluon.language.nvidia.hopper.tma import (
+    async_copy_global_to_shared,
+    async_copy_shared_to_global,
+    store_wait,
+    tensor_descriptor,
+    tensor_descriptor_type,
+)
+
+__all__ = [
+    "async_gather",
+    "async_scatter",
+    "async_copy_global_to_shared",
+    "async_copy_shared_to_global",
+    "store_wait",
+    "tensor_descriptor",
+    "tensor_descriptor_type",
+]
+
+
+@builtin
+def async_gather(tensor_desc, x_offsets, y_offset, barrier, result, pred=True, _semantic=None):
+    pred = _semantic.to_tensor(pred)
+    y_offset = _semantic.to_tensor(y_offset)
+    _semantic.builder.create_async_tma_gather(tensor_desc.handle, x_offsets.handle, y_offset.handle, barrier.handle,
+                                              result.handle, pred.handle)
+
+
+@builtin
+def async_scatter(tensor_desc, x_offsets, y_offset, src, _semantic=None):
+    y_offset = _semantic.to_tensor(y_offset)
+    _semantic.builder.create_async_tma_scatter(tensor_desc.handle, x_offsets.handle, y_offset.handle, src.handle)
triton/experimental/gluon/language/nvidia/hopper/__init__.py
@@ -0,0 +1,11 @@
+from . import mbarrier
+from . import tma
+from ... import _core
+
+__all__ = ["fence_async_shared", "mbarrier", "tma"]
+
+
+@_core.builtin
+def fence_async_shared(cluster=False, _semantic=None):
+    cluster = _core._unwrap_if_constexpr(cluster)
+    _semantic.builder.create_fence_async_shared(cluster)
triton/experimental/gluon/language/nvidia/hopper/mbarrier.py
@@ -0,0 +1,51 @@
+from triton.experimental.gluon.language._layouts import SwizzledSharedLayout
+from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
+
+__all__ = ["MBarrierLayout", "init", "invalidate", "expect", "wait", "arrive"]
+
+
+class MBarrierLayout(SwizzledSharedLayout):
+
+    def __init__(self, ctas_per_cga: int = 1, cta_split_num: int = 1):
+        super().__init__(
+            vec=1,
+            per_phase=1,
+            max_phase=1,
+            order=[0],
+            ctas_per_cga=[ctas_per_cga],
+            cta_split_num=[cta_split_num],
+            cta_order=[0],
+        )
+
+
+@builtin
+def init(mbarrier, count, _semantic=None):
+    count = _unwrap_if_constexpr(count)
+    _semantic.builder.create_mbarrier_init(mbarrier.handle, count)
+
+
+@builtin
+def invalidate(mbarrier, _semantic=None):
+    _semantic.builder.create_mbarrier_inval(mbarrier.handle)
+
+
+@builtin
+def expect(mbarrier, bytes, pred=True, _semantic=None):
+    bytes = _unwrap_if_constexpr(bytes)
+    pred = _semantic.to_tensor(pred)
+    _semantic.builder.create_mbarrier_expect(mbarrier.handle, bytes, pred.handle)
+
+
+@builtin
+def wait(mbarrier, phase, pred=True, deps=(), _semantic=None):
+    phase = _semantic.to_tensor(phase)
+    pred = _semantic.to_tensor(pred)
+    deps = [x.handle for x in deps]
+    _semantic.builder.create_mbarrier_wait(mbarrier.handle, phase.handle, pred.handle, deps)
+
+
+@builtin
+def arrive(mbarrier, count, pred=True, _semantic=None):
+    count = _unwrap_if_constexpr(count)
+    pred = _semantic.to_tensor(pred)
+    _semantic.builder.create_mbarrier_arrive(mbarrier.handle, count, pred.handle)