triton-windows 3.2.0.post12__cp312-cp312-win_amd64.whl → 3.3.0a0.post12__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +3 -3
- triton/_internal_testing.py +59 -4
- triton/_utils.py +35 -0
- triton/backends/amd/compiler.py +121 -74
- triton/backends/amd/driver.py +77 -43
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
- triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
- triton/backends/amd/include/hip/hip_ext.h +4 -2
- triton/backends/amd/include/hip/hip_fp8.h +33 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
- triton/backends/amd/include/hip/hip_version.h +3 -3
- triton/backends/amd/include/hip/hiprtc.h +25 -25
- triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
- triton/backends/amd/include/hsa/hsa.h +11 -2
- triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/compiler.py +25 -225
- triton/backends/driver.py +7 -2
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +135 -90
- triton/backends/nvidia/driver.c +0 -1
- triton/backends/nvidia/driver.py +135 -49
- triton/backends/nvidia/include/cuda.h +2162 -241
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +2 -2
- triton/compiler/code_generator.py +334 -231
- triton/compiler/compiler.py +77 -66
- triton/language/__init__.py +22 -5
- triton/language/core.py +448 -74
- triton/language/extra/cuda/_experimental_tma.py +3 -5
- triton/language/math.py +1 -1
- triton/language/random.py +2 -1
- triton/language/semantic.py +206 -52
- triton/language/standard.py +35 -18
- triton/runtime/_allocation.py +32 -0
- triton/runtime/autotuner.py +27 -32
- triton/runtime/build.py +1 -48
- triton/runtime/cache.py +6 -6
- triton/runtime/errors.py +10 -0
- triton/runtime/interpreter.py +179 -45
- triton/runtime/jit.py +149 -190
- triton/testing.py +39 -11
- triton/tools/compile.py +27 -20
- triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
- triton/tools/mxfp.py +301 -0
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/METADATA +5 -2
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/RECORD +68 -59
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/top_level.txt +2 -0
- /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/WHEEL +0 -0
triton/language/extra/cuda/_experimental_tma.py
CHANGED
@@ -29,7 +29,7 @@ def experimental_device_tensormap_create1d(
     load_size: core.tensor,
     global_size: core.tensor,
     element_ty: core.dtype,
-    _builder: ir.builder,
+    _builder: ir.builder = None,
 ):
     load_size = core._constexpr_to_value(load_size)
     global_size = semantic.to_tensor(global_size, _builder)
@@ -58,7 +58,7 @@ def experimental_device_tensormap_create2d(
     load_size: Sequence[core.constexpr],
     global_size: Sequence[core.tensor],
     element_ty: core.dtype,
-    _builder: ir.builder,
+    _builder: ir.builder = None,
 ):
     assert len(load_size) == 2
     assert len(global_size) == 2
@@ -68,8 +68,6 @@ def experimental_device_tensormap_create2d(
     element_size = element_ty.primitive_bitwidth // 8
     element_size_t = core.full([], element_size, core.int64, _builder=_builder)
     global_stride = semantic.mul(element_size_t, global_size[-1], True, _builder)
-    # Undocumented, but global_stride seems to be divided by 16
-    global_stride = semantic.ashr(global_stride, semantic.to_tensor(4, _builder), _builder)

     contig_dim_size_in_bytes = element_size * load_size[-1]
     if contig_dim_size_in_bytes > 128:
@@ -104,5 +102,5 @@ def _determine_swizzle_mode_2d(contig_dim_size_in_bytes, load_size):


 @core.builtin
-def experimental_tensormap_fenceproxy_acquire(desc_ptr: core.tensor, _builder: ir.builder):
+def experimental_tensormap_fenceproxy_acquire(desc_ptr: core.tensor, _builder: ir.builder = None):
     semantic.tensormap_fenceproxy_acquire(desc_ptr, _builder)
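The `@@ -68,8 +68,6 @@` hunk above changes the stride that `experimental_device_tensormap_create2d` encodes: in 3.2.0 the byte stride was arithmetically shifted right by 4 (divided by 16) before being handed to the tensormap, while 3.3.0 passes it through unchanged. A small sketch of the arithmetic, with an assumed fp32 row of 1024 elements:

```python
# Illustrative arithmetic only; the row length (1024) and dtype (fp32) are assumptions.
element_size = 32 // 8                              # element_ty.primitive_bitwidth // 8 for fp32
global_size_last = 1024                             # assumed innermost global size
global_stride = element_size * global_size_last     # 4096 bytes

old_value = global_stride >> 4                      # 256: 3.2.0 applied ashr by 4 (divide by 16)
new_value = global_stride                           # 4096: 3.3.0 passes the byte stride unchanged
print(old_value, new_value)
```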
triton/language/math.py
CHANGED
@@ -173,9 +173,9 @@ def rsqrt(x, _builder=None):
     return core.tensor(_builder.create_rsqrt(x.handle), x.type)


+@core._tensor_member_fn
 @core.builtin
 @_add_math_1arg_docstr("absolute value")
-@core._tensor_member_fn
 def abs(x, _builder=None):
     x = semantic.to_tensor(x, _builder)
     dtype = x.dtype
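Moving `@core._tensor_member_fn` to the outermost position means the fully decorated builtin is what gets registered as a tensor method. A hedged sketch (kernel and pointer dtypes are illustrative) of the two equivalent spellings:

```python
import triton
import triton.language as tl

@triton.jit
def abs_kernel(x_ptr, out_ptr, N: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(x_ptr + offs)
    # Member form registered by _tensor_member_fn; tl.abs(x) is the free-function form.
    tl.store(out_ptr + offs, x.abs())
```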
triton/language/random.py
CHANGED
@@ -45,11 +45,12 @@ def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAUL
 @jit
 def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
     seed = tl.to_tensor(seed)
+    tl.static_assert(seed.dtype.is_int())
+    seed = seed.to(tl.uint64)
     c0 = tl.to_tensor(c0)
     c1 = tl.to_tensor(c1)
     c2 = tl.to_tensor(c2)
     c3 = tl.to_tensor(c3)
-    seed = seed.to(tl.uint64)
     if tl.constexpr(c0.dtype.primitive_bitwidth) == 32:
         int_dtype = tl.uint32
         seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
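With the added `static_assert`, `philox` now rejects non-integer seeds up front and widens the seed to `uint64` before the counters are materialized. A hedged sketch of a kernel driving it through `tl.rand` (kernel and names are illustrative):

```python
import triton
import triton.language as tl

@triton.jit
def dropout_mask(out_ptr, seed, N: tl.constexpr):
    offs = tl.arange(0, N)
    r = tl.rand(seed, offs)          # seed must now be an integer scalar
    tl.store(out_ptr + offs, r > 0.5)
```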
triton/language/semantic.py
CHANGED
@@ -6,7 +6,6 @@ import numbers

 from .._C.libtriton import ir
 from . import core as tl
-from . import math

 T = TypeVar('T')

@@ -62,7 +61,7 @@ def computation_type_impl(a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_i
                           div_or_mod: bool) -> tl.dtype:
     # 0) For scalars we follow semantics similar to PyTorch, namely:
     # - If the scalar is of a lower or equal kind (bool < uint < int < fp),
-    #   it doesn't participate in the
+    #   it doesn't participate in the promotion
     if a_is_scalar != b_is_scalar:
         scalar_ty, tensor_ty = (a_ty, b_ty) if a_is_scalar else (b_ty, a_ty)
         if scalar_ty.kind().value <= tensor_ty.kind().value:
@@ -88,11 +87,12 @@ def computation_type_impl(a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_i
         else:
             return tl.float16
     # 4) return bf16 only if both operands are of bf16
-    if a_ty.is_bf16()
+    if a_ty.is_bf16() and b_ty.is_bf16():
         if div_or_mod:
             return tl.float32
-
+        else:
             return tl.bfloat16
+    if a_ty.is_bf16() or b_ty.is_bf16():
         return tl.float32
     # 5) return fp16 if operands are different fp8
     if a_ty.is_fp8() and b_ty.is_fp8():
@@ -186,6 +186,11 @@ def binary_op_type_checking_impl(lhs: tl.tensor | numbers.Number, rhs: tl.tensor
             or rhs_is_scalar and rhs_scalar < 0 and ret_sca_ty.is_int_unsigned()):
         raise ValueError("Cannot perform a binary operation between an unsigned tensor and a negative scalar. "
                          "Perform a explicit cast on one of them.")
+    if ret_sca_ty.is_int():
+        if lhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= lhs_scalar <= ret_sca_ty.get_int_max_value()):
+            raise ValueError(f"Scalar {lhs_scalar} is out of range for type {ret_sca_ty}")
+        if rhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= rhs_scalar <= ret_sca_ty.get_int_max_value()):
+            raise ValueError(f"Scalar {rhs_scalar} is out of range for type {ret_sca_ty}")
     lhs = full(
         (), lhs_scalar, dtype=ret_sca_ty, builder=builder) if lhs_is_scalar else cast(lhs, ret_sca_ty, builder)
     rhs = full(
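The `@@ -186,6 +186,11 @@` hunk adds a compile-time range check on Python scalars mixed into integer binary ops. A hedged sketch of a kernel that this check would now reject (the int8 input is an assumption):

```python
import triton
import triton.language as tl

@triton.jit
def add_const(x_ptr, out_ptr, N: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(x_ptr + offs)      # assume x_ptr points at int8 data
    y = x + 300                    # 300 does not fit in int8 -> expected compile error:
                                   # "Scalar 300 is out of range for type int8"
    tl.store(out_ptr + offs, y)
```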
@@ -230,7 +235,15 @@ def add(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sa
     input_scalar_ty = input.type.scalar
     other_scalar_ty = other.type.scalar
     if input_scalar_ty.is_ptr():
-
+        other_handle = other.handle
+        if other.dtype.is_int_unsigned() and other.dtype.int_bitwidth < 64:
+            # addptr treats offset as signed. Zero-extend unsigned offsets to ensure they're positive
+            if other.type.is_block():
+                i64_ty = tl.block_type(tl.int64, other.type.get_block_shapes()).to_ir(builder)
+            else:
+                i64_ty = tl.int64.to_ir(builder)
+            other_handle = builder.create_int_cast(other.handle, i64_ty, False)
+        return tl.tensor(builder.create_addptr(input.handle, other_handle), input.type)
     # float + float
     elif input_scalar_ty.is_floating():
         return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
@@ -333,10 +346,7 @@ def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, bu
     other_scalar_ty = other.type.scalar
     # float % float
     if scalar_ty.is_floating():
-
-        floor = math.floor(fdiv(input, other, False, builder), _builder=builder)
-        ret = sub(input, mul(floor, other, True, builder), True, builder)
-        return ret
+        return tl.tensor(builder.create_frem(input.handle, other.handle), input.type)
     # % int
     elif scalar_ty.is_int():
         if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
@@ -762,14 +772,14 @@ def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) ->
         # Add new axes to lhs
         for _ in range(len(lhs_shape), len(rhs_shape)):
             lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0),
-                            tl.block_type(lhs_ty.scalar, [1] + lhs_shape))
+                            tl.block_type(lhs_ty.scalar, [1] + lhs_shape.values))
             lhs_ty = lhs.type
             lhs_shape = lhs_ty.get_block_shapes()
     elif len(rhs_shape) < len(lhs_shape):
         # Add new axes to rhs
         for _ in range(len(rhs_shape), len(lhs_shape)):
             rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0),
-                            tl.block_type(rhs_ty.scalar, [1] + rhs_shape))
+                            tl.block_type(rhs_ty.scalar, [1] + rhs_shape.values))
             rhs_ty = rhs.type
             rhs_shape = rhs_ty.get_block_shapes()
     assert len(rhs_shape) == len(lhs_shape)
@@ -831,10 +841,6 @@ def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tenso
 def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder,
          fp_downcast_rounding: Optional[str] = None) -> tl.tensor:
     src_ty = input.type
-    if isinstance(dst_ty, tl.constexpr):
-        dst_ty = dst_ty.value
-    if isinstance(fp_downcast_rounding, tl.constexpr):
-        fp_downcast_rounding = fp_downcast_rounding.value
     if src_ty.is_block():
         dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
     if src_ty == dst_ty:
@@ -1048,7 +1054,7 @@ def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, evicti
         raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")

     elt_ty = ptr.type.element_ty.element_ty
-    assert elt_ty != tl.int1, "`tl.int1` should be
+    assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
     if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN:
         raise ValueError("Padding option `nan` is not supported for integer block pointers")

@@ -1141,18 +1147,93 @@ def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor],
     return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder)


-def
+def reinterpret_tensor_descriptor(desc_ptr: tl.tensor, block_ty: tl.block_type, builder: ir.builder):
+    handle = builder.create_reinterpret_tensor_descriptor(desc_ptr.handle, block_ty.to_ir(builder))
+    return tl._experimental_tensor_descriptor_base(handle, block_ty)
+
+
+def validate_descriptor_block(shape, dtype):
+    if len(shape) != 2:
+        return
+    # Due to limitations of the shared memory encoding, the TMA bounding box has
+    # to be at least as big as the swizzle tile.
+    assert shape[0] >= 8, f"tensor descriptor block shape must have at least 8 rows, but got {shape[0]}"
+    min_cols = 32 // dtype.primitive_bitwidth * 8
+    assert shape[
+        1] >= min_cols, f"{dtype} tensor descriptor block shape must have at least {min_cols} columns, but got {shape[1]}"
+
+
+def descriptor_load(desc: tl._experimental_tensor_desciptor_base, offsets, cache_modifier: str, eviction_policy: str,
                     builder: ir.builder) -> tl.tensor:
+    assert isinstance(desc, tl._experimental_tensor_descriptor_base)
+    validate_descriptor_block(desc.block_shape, desc.dtype)
+    ndim = len(desc.block_shape)
+    assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"
+
     offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
-    x = builder.create_descriptor_load(
-        _str_to_load_cache_modifier(cache_modifier),
+    x = builder.create_descriptor_load(desc.handle, offsets, _str_to_load_cache_modifier(cache_modifier),
                                        _str_to_eviction_policy(eviction_policy))
-    return tl.tensor(x,
+    return tl.tensor(x, desc.block_type)
+

+def descriptor_store(desc: tl._experimental_tensor_descriptor_base, value: tl.tensor, offsets,
+                     builder: ir.builder) -> tl.tensor:
+    assert isinstance(desc, tl._experimental_tensor_descriptor_base)
+    validate_descriptor_block(desc.block_shape, desc.dtype)
+    ndim = len(desc.block_shape)
+    assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"
+    assert value.shape == desc.block_shape

-def descriptor_store(desc_ptr: tl.tensor, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
     offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
-    return tl.tensor(builder.create_descriptor_store(
+    return tl.tensor(builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void)
+
+
+def descriptor_gather(desc, x_offsets, y_offset, cache_modifier: str, eviction_policy: str,
+                      builder: ir.builder) -> tl.tensor:
+    assert isinstance(desc, tl._experimental_tensor_descriptor_base)
+    assert cache_modifier == "", "cache modifier is not supported yet"
+    assert eviction_policy == "", "eviction policy is not supported yet"
+
+    # Validate descriptor.
+    assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}"
+    assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}"
+
+    # Validate offsets.
+    assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shape}"
+
+    # Validate minimum block size.
+    assert x_offsets.shape[0] >= 8, f"descriptor gather must have at least 8 rows, but got {x_offsets.shape}"
+    dtype = desc.dtype
+    min_cols = 32 // dtype.primitive_bitwidth * 8
+    assert desc.block_shape[
+        1] >= min_cols, f"descriptor gather of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}"
+
+    type = tl.block_type(desc.dtype, [x_offsets.shape[0], desc.block_shape[1]])
+    y_offset = _convert_to_ir_values(builder, (y_offset, ), require_i64=False)[0]
+    x = builder.create_descriptor_gather(desc.handle, x_offsets.handle, y_offset, type.to_ir(builder))
+    return tl.tensor(x, type)
+
+
+def descriptor_scatter(desc, value: tl.tensor, x_offsets, y_offset, builder: ir.builder) -> tl.tensor:
+    assert isinstance(desc, tl._experimental_tensor_descriptor_base)
+
+    # Validate descriptor.
+    assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}"
+    assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}"
+
+    # Validate offsets.
+    assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shapae}"
+
+    # Validate minimum block size.
+    assert x_offsets.shape[0] >= 8, f"descriptor scatter must have at least 8 rows, but got {x_offsets.shape}"
+    dtype = desc.dtype
+    min_cols = 32 // dtype.primitive_bitwidth * 8
+    assert desc.block_shape[
+        1] >= min_cols, f"descriptor scatter of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}"
+
+    y_offset = _convert_to_ir_values(builder, (y_offset, ), require_i64=False)[0]
+    builder.create_descriptor_scatter(desc.handle, value.handle, x_offsets.handle, y_offset)
+    return tl.tensor(None, tl.void)


 def tensormap_create(
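`validate_descriptor_block` enforces the TMA bounding-box minimum implied by the swizzle tile: at least 8 rows and at least `32 // primitive_bitwidth * 8` columns. Working that formula through common element widths:

```python
# Minimum descriptor block columns implied by `min_cols = 32 // dtype.primitive_bitwidth * 8`.
for name, bitwidth in [("fp32", 32), ("fp16/bf16", 16), ("fp8", 8)]:
    min_cols = 32 // bitwidth * 8
    print(f"{name:9s} -> at least {min_cols} columns")
# fp32      -> at least 8 columns
# fp16/bf16 -> at least 16 columns
# fp8       -> at least 32 columns
```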
@@ -1206,7 +1287,7 @@ def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builde
     assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch"

     elt_ty = ptr.type.element_ty.element_ty
-    assert elt_ty != tl.int1, "`tl.int1` should be
+    assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"

     # Check `boundary_check` argument
     boundary_check = _canonicalize_boundary_check(boundary_check, block_shape)
@@ -1256,7 +1337,7 @@ def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder):
     val = cast(val, elt_ty, builder)

     # Build IR
-    if
+    if mask is None:
         return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void)
     if not mask.type.scalar.is_bool():
         raise ValueError("Mask must have boolean scalar type")
@@ -1311,7 +1392,7 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor,
     if val is not None:
         val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
         val = cast(val, ptr.type.scalar.element_ty, builder)
-    if
+    if mask is None:
         mask_ir = builder.get_int1(True)
         mask_ty = tl.int1
         if ptr.type.is_block():
@@ -1470,6 +1551,7 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona
     assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}"

     if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():
+        # We upcast because there's no fp8e4b15 type in MLIR
         lhs = cast(lhs, tl.float16, builder)
         rhs = cast(rhs, tl.float16, builder)

@@ -1527,40 +1609,58 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona
              ret_ty)


-def _str_to_fp_type(float_format:
-
-
-
-
-
-
-
-
-
-
+def _str_to_fp_type(float_format: str):
+    ty_enum = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None)
+    if ty_enum is None:
+        raise ValueError(f"Invalid float format: {float_format}.")
+    return ty_enum
+
+
+def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder):
+    """
+    If float_format is subbyte, make sure it's packed as uint8 and return it.
+    Otherwise, return a tensor (perhaps bitcasting) of the specified float format.
+    """
+    triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16": tl.float16}.get(float_format)
+    if triton_ty is None:
+        assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}"
+        assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}"
+        return val
+    if val.dtype == triton_ty:
+        return val
+    else:
+        unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16, "fp16": tl.uint16}[float_format]
+        assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}"
+        return bitcast(val, triton_ty, builder)


-def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, rhs_scale: Optional[tl.tensor],
-               rhs_format, acc: tl.tensor | None,
+def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.tensor, rhs_scale: Optional[tl.tensor],
+               rhs_format: str, acc: tl.tensor | None, fast_math: bool, out_dtype: tl.dtype,
+               builder: ir.builder) -> tl.tensor:
     assert lhs.type.is_block() and rhs.type.is_block()
     #TODO: validate types.
     lhs_rank = len(lhs.shape)
     rhs_rank = len(rhs.shape)
     assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
+    lhs_format: str = lhs_format.value
+    rhs_format: str = rhs_format.value
     lhs_format_enum = _str_to_fp_type(lhs_format)
     rhs_format_enum = _str_to_fp_type(rhs_format)
-
-    assert
-
-
+    allowed_formats = {"e2m1", "e4m3", "e5m2", "bf16", "fp16"}
+    assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}"
+    assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}"
+    rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None)
+    lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None)
+    lhs = _bitcast_to_fp_type(lhs, lhs_format, builder)
+    rhs = _bitcast_to_fp_type(rhs, rhs_format, builder)

     M = lhs.type.shape[-2]
     K, N = rhs.type.shape[-2:]
-
-
+    PACKED_A = 2 if lhs_format == "e2m1" else 1
+    PACKED_B = 2 if rhs_format == "e2m1" else 1
+    assert K * PACKED_B == PACKED_A * lhs.type.shape[
         -1], f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
-    assert K >= 64, f"scaled_dot NYI for K < 64. Got {K=}"
+    #assert K * PACKED_B >= 64, f"scaled_dot NYI for K < 64. Got {K=}"
     B = lhs.type.shape[0] if lhs_rank == 3 else None

     ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N])
@@ -1571,9 +1671,10 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor,
         acc_handle = acc.handle
         assert acc.type == ret_ty
     rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle
+    lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle
     return tl.tensor(
-        builder.create_dot_scaled(lhs.handle,
-                                  rhs_format_enum, acc_handle), ret_ty)
+        builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle, rhs_scale_handle,
+                                  rhs_format_enum, fast_math, acc_handle), ret_ty)


 # ===----------------------------------------------------------------------===//
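Since `e2m1` values are 4 bits wide and packed two per `uint8`, the reworked shape check in `dot_scaled` compares packed element counts (`K * PACKED_B == PACKED_A * lhs.shape[-1]`) rather than raw dimensions. A small sketch with assumed shapes:

```python
# Illustrative only: assumed operand shapes for an e2m1 (mxfp4) x bf16 scaled dot.
lhs_shape = (128, 32)   # e2m1 packed as uint8 -> 32 bytes hold 64 4-bit values per row
rhs_shape = (64, 128)   # bf16, so K = 64

PACKED_A = 2            # lhs_format == "e2m1"
PACKED_B = 1            # rhs_format != "e2m1"
K = rhs_shape[0]
assert K * PACKED_B == PACKED_A * lhs_shape[-1]   # 64 == 2 * 32
```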
@@ -1655,6 +1756,30 @@ def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn,
     return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs)))


+# ===----------------------------------------------------------------------===
+# Gather
+# ===----------------------------------------------------------------------===
+
+
+def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+    assert index.dtype.is_int(), "index must be an integer tensor"
+
+    rank = len(src.type.shape)
+    assert len(index.type.shape) == rank, "source and index tensors must have the same rank"
+
+    assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})"
+    if axis < 0:
+        axis += rank
+
+    for d in range(rank):
+        if d == axis:
+            continue
+        assert index.type.shape[d] == src.type.shape[d], f"index dim {axis} must match the corresponding source dim"
+
+    gather = builder.create_gather(src.handle, index.handle, axis)
+    return wrap_tensor(gather, src.type.scalar, index.type.shape)
+
+
 # ===----------------------------------------------------------------------===
 # Histogram
 # ===----------------------------------------------------------------------===
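The new `gather` requires `index` to be an integer tensor of the same rank as `src`, with every non-gather dimension matching, and the result takes the index tensor's shape. Semantically this mirrors NumPy's `take_along_axis`, as in this sketch:

```python
import numpy as np

src = np.arange(12).reshape(4, 3)
index = np.array([[2, 0, 1],
                  [0, 0, 2],
                  [1, 2, 0],
                  [2, 1, 1]])          # same shape as src along the non-gather dimension

out = np.take_along_axis(src, index, axis=1)
print(out.shape)                        # (4, 3): the result has the index tensor's shape
```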
@@ -1663,10 +1788,7 @@ def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn,
 def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor:
     assert len(input.shape) == 1, "histogram only supports 1D input"
     assert input.dtype.is_int(), "histogram only supports integer input"
-    return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32,
-
-
-    ##
+    return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, [num_bins]))


 def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
@@ -1794,3 +1916,35 @@ def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:

     # Advanced block pointer type is the same as before
     return tl.tensor(builder.create_advance(base.handle, offsets), base.type)
+
+
+def make_tensor_descriptor(
+    base: tl.tensor,
+    shape: List[tl.tensor],
+    strides: List[tl.tensor],
+    block_shape: List[tl.constexpr],
+    builder: ir.builder,
+) -> tl._experimental_tensor_descriptor:
+    ndim = len(shape)
+    if not (2 <= ndim <= 5):
+        raise ValueError(f"Expected 2 <= ndim <= 5 but got {ndim} dimensions")
+    if len(strides) != ndim:
+        raise ValueError(f"Expected {ndim} strides but got {len(strides)}")
+    if len(block_shape) != ndim:
+        raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(strides)}")
+
+    strides[-1] = tl._constexpr_to_value(strides[-1])
+    if strides[-1] != 1:
+        raise ValueError(f"Tensor descriptor last dim must be 1 but got {strides[-1]}")
+
+    shape = [to_tensor(x, builder) for x in shape]
+    strides = [to_tensor(x, builder).to(tl.int64, _builder=builder) for x in strides]
+
+    # Check whether `block_shape` is static
+    block_shape = tl._unwrap_shape(block_shape)
+
+    assert isinstance(base.type, tl.pointer_type)
+    type = tl.block_type(base.type.element_ty, block_shape)
+    handle = builder.create_make_tensor_descriptor(base.handle, [s.handle for s in shape], [s.handle for s in strides],
+                                                   block_shape)
+    return tl._experimental_tensor_descriptor(handle, shape, strides, type)
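`make_tensor_descriptor` is the semantic-level entry point for building a tensor descriptor on the device. A hedged sketch of what kernel-level usage might look like; the `tl.make_tensor_descriptor` spelling, the `.load`/`.store` methods, and the kernel itself are assumptions based on this diff rather than taken from it:

```python
import triton
import triton.language as tl

@triton.jit
def copy_kernel(in_ptr, out_ptr, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # The innermost stride must be 1, per the checks in make_tensor_descriptor above.
    in_desc = tl.make_tensor_descriptor(in_ptr, shape=[M, N], strides=[N, 1],
                                        block_shape=[BLOCK_M, BLOCK_N])
    out_desc = tl.make_tensor_descriptor(out_ptr, shape=[M, N], strides=[N, 1],
                                         block_shape=[BLOCK_M, BLOCK_N])
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    tile = in_desc.load([pid_m * BLOCK_M, pid_n * BLOCK_N])
    out_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], tile)
```

Kernels that materialize descriptors at launch time also need a runtime allocator on the host; see the new `triton/runtime/_allocation.py` module at the end of this diff.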
triton/language/standard.py
CHANGED
@@ -59,14 +59,14 @@ def softmax(x, ieee_rounding=False):

 @core._tensor_member_fn
 @jit
-def ravel(x):
+def ravel(x, can_reorder=False):
     """
     Returns a contiguous flattened view of :code:`x`.

     :param x: the input tensor
     :type x: Block
     """
-    return core.reshape(x, [x.numel], can_reorder=
+    return core.reshape(x, [x.numel], can_reorder=can_reorder)


 @jit
@@ -259,11 +259,30 @@ def _sum_combine(a, b):
 # sum


+def _pick_sum_dtype(in_dtype: core.constexpr, dtype: core.constexpr):
+    dtype = core._unwrap_if_constexpr(dtype)
+    if dtype is not None:
+        return dtype
+
+    # For integer bitwidths less than 32, pick int32 with the same sign to
+    # avoid overflow.
+    out_dtype = None
+    if in_dtype.is_int_signed():
+        out_dtype = core.int32 if in_dtype.int_bitwidth < 32 else None
+    elif in_dtype.is_int_unsigned():
+        out_dtype = core.uint32 if in_dtype.int_bitwidth < 32 else None
+    return out_dtype
+
+
 @core._tensor_member_fn
 @jit
-@core._add_reduction_docstr("sum")
-def sum(input, axis=None, keep_dims=False):
-
+@core._add_reduction_docstr("sum", dtype_arg="dtype")
+def sum(input, axis=None, keep_dims=False, dtype: core.constexpr = None):
+    # Pick a default dtype for the reduction if one was not specified.
+    out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype)
+
+    if out_dtype is not None:
+        input = input.to(out_dtype)
     return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims)

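With `_pick_sum_dtype`, integer inputs narrower than 32 bits are now accumulated in a 32-bit integer of the same signedness by default, and the new `dtype` argument overrides that choice. A hedged kernel sketch (the int8 input is an assumption):

```python
import triton
import triton.language as tl

@triton.jit
def count_kernel(flags_ptr, out_ptr, N: tl.constexpr):
    offs = tl.arange(0, N)
    flags = tl.load(flags_ptr + offs)       # assume int8 0/1 flags
    total = tl.sum(flags, dtype=tl.int32)   # explicit accumulator dtype; int32 is also the new
                                            # default for sub-32-bit signed integer inputs
    tl.store(out_ptr, total)
```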
@@ -276,15 +295,11 @@ def _xor_combine(a, b):


 @core._tensor_member_fn
-@
+@jit
 @core._add_reduction_docstr("xor sum")
-def xor_sum(input, axis=None, keep_dims=False
-
-
-    raise ValueError("xor_sum only supported for integers")
-
-    input = core._promote_bfloat16_to_float32(input, _builder=_builder)
-    return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims, _builder=_builder, _generator=_generator)
+def xor_sum(input, axis=None, keep_dims=False):
+    core.static_assert(input.type.scalar.is_int(), "xor_sum only supported for integers")
+    return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims)


 # cumsum
@@ -412,11 +427,13 @@ def flip(x, dim=None):
     """
     core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)]))
     core.static_assert(_is_power_of_two(x.numel))
-    #
-    #
+    # reshape the tensor to have all dimensions be 2.
+    # TODO: We shouldn't have to change the dimensions not sorted.
     steps: core.constexpr = _log2(x.numel)
     start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)])
-
+
+    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
+    y = core.reshape(x.to(idtype, bitcast=True), [2] * steps)
     y = core.expand_dims(y, start)
     flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2))
     for i in core.static_range(start, steps):
@@ -424,8 +441,8 @@ def flip(x, dim=None):
         for j in core.static_range(0, steps + 1):
             if j != i and j != i + 1:
                 flip2 = core.expand_dims(flip2, j)
-        y = sum(y * flip2, i + 1, keep_dims=True)
+        y = sum(y * flip2, i + 1, keep_dims=True, dtype=y.dtype)
-    x = core.reshape(y, x.shape)
+    x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)
     return x

triton/runtime/_allocation.py
ADDED
@@ -0,0 +1,32 @@
+from typing import Optional, Protocol
+
+
+class Buffer(Protocol):
+
+    def data_ptr(self) -> int:
+        ...
+
+
+class Allocator(Protocol):
+
+    def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
+        ...
+
+
+class NullAllocator:
+
+    def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
+        raise RuntimeError("Kernel requires a runtime memory allocation, but no allocator was set. " +
+                           "Use triton.set_allocator to specify an allocator.")
+
+
+_allocator: Allocator = NullAllocator()
+
+
+def set_allocator(allocator: Allocator):
+    """
+    The allocator function is called during kernel launch for kernels that
+    require additional global memory workspace.
+    """
+    global _allocator
+    _allocator = allocator