triton-windows 3.2.0.post12__cp312-cp312-win_amd64.whl → 3.3.0a0.post12__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows has been flagged by the registry.

Files changed (68)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +3 -3
  3. triton/_internal_testing.py +59 -4
  4. triton/_utils.py +35 -0
  5. triton/backends/amd/compiler.py +121 -74
  6. triton/backends/amd/driver.py +77 -43
  7. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
  8. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
  13. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
  15. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
  16. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
  17. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
  18. triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
  19. triton/backends/amd/include/hip/hip_ext.h +4 -2
  20. triton/backends/amd/include/hip/hip_fp8.h +33 -0
  21. triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
  22. triton/backends/amd/include/hip/hip_version.h +3 -3
  23. triton/backends/amd/include/hip/hiprtc.h +25 -25
  24. triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
  25. triton/backends/amd/include/hsa/hsa.h +11 -2
  26. triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
  27. triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
  28. triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
  29. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
  30. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
  31. triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
  32. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
  33. triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
  34. triton/backends/amd/lib/asanrtl.bc +0 -0
  35. triton/backends/compiler.py +25 -225
  36. triton/backends/driver.py +7 -2
  37. triton/backends/nvidia/bin/ptxas.exe +0 -0
  38. triton/backends/nvidia/compiler.py +135 -90
  39. triton/backends/nvidia/driver.c +0 -1
  40. triton/backends/nvidia/driver.py +135 -49
  41. triton/backends/nvidia/include/cuda.h +2162 -241
  42. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  43. triton/compiler/__init__.py +2 -2
  44. triton/compiler/code_generator.py +334 -231
  45. triton/compiler/compiler.py +77 -66
  46. triton/language/__init__.py +22 -5
  47. triton/language/core.py +448 -74
  48. triton/language/extra/cuda/_experimental_tma.py +3 -5
  49. triton/language/math.py +1 -1
  50. triton/language/random.py +2 -1
  51. triton/language/semantic.py +206 -52
  52. triton/language/standard.py +35 -18
  53. triton/runtime/_allocation.py +32 -0
  54. triton/runtime/autotuner.py +27 -32
  55. triton/runtime/build.py +1 -48
  56. triton/runtime/cache.py +6 -6
  57. triton/runtime/errors.py +10 -0
  58. triton/runtime/interpreter.py +179 -45
  59. triton/runtime/jit.py +149 -190
  60. triton/testing.py +39 -11
  61. triton/tools/compile.py +27 -20
  62. triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
  63. triton/tools/mxfp.py +301 -0
  64. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/METADATA +5 -2
  65. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/RECORD +68 -59
  66. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/top_level.txt +2 -0
  67. /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
  68. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/WHEEL +0 -0
@@ -29,7 +29,7 @@ def experimental_device_tensormap_create1d(
     load_size: core.tensor,
     global_size: core.tensor,
     element_ty: core.dtype,
-    _builder: ir.builder,
+    _builder: ir.builder = None,
 ):
     load_size = core._constexpr_to_value(load_size)
     global_size = semantic.to_tensor(global_size, _builder)
@@ -58,7 +58,7 @@ def experimental_device_tensormap_create2d(
     load_size: Sequence[core.constexpr],
     global_size: Sequence[core.tensor],
     element_ty: core.dtype,
-    _builder: ir.builder,
+    _builder: ir.builder = None,
 ):
     assert len(load_size) == 2
     assert len(global_size) == 2
@@ -68,8 +68,6 @@ def experimental_device_tensormap_create2d(
     element_size = element_ty.primitive_bitwidth // 8
     element_size_t = core.full([], element_size, core.int64, _builder=_builder)
     global_stride = semantic.mul(element_size_t, global_size[-1], True, _builder)
-    # Undocumented, but global_stride seems to be divided by 16
-    global_stride = semantic.ashr(global_stride, semantic.to_tensor(4, _builder), _builder)
 
     contig_dim_size_in_bytes = element_size * load_size[-1]
     if contig_dim_size_in_bytes > 128:
@@ -104,5 +102,5 @@ def _determine_swizzle_mode_2d(contig_dim_size_in_bytes, load_size):
 
 
 @core.builtin
-def experimental_tensormap_fenceproxy_acquire(desc_ptr: core.tensor, _builder: ir.builder):
+def experimental_tensormap_fenceproxy_acquire(desc_ptr: core.tensor, _builder: ir.builder = None):
     semantic.tensormap_fenceproxy_acquire(desc_ptr, _builder)
triton/language/math.py CHANGED
@@ -173,9 +173,9 @@ def rsqrt(x, _builder=None):
     return core.tensor(_builder.create_rsqrt(x.handle), x.type)
 
 
+@core._tensor_member_fn
 @core.builtin
 @_add_math_1arg_docstr("absolute value")
-@core._tensor_member_fn
 def abs(x, _builder=None):
     x = semantic.to_tensor(x, _builder)
     dtype = x.dtype
triton/language/random.py CHANGED
@@ -45,11 +45,12 @@ def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAUL
 @jit
 def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
     seed = tl.to_tensor(seed)
+    tl.static_assert(seed.dtype.is_int())
+    seed = seed.to(tl.uint64)
     c0 = tl.to_tensor(c0)
     c1 = tl.to_tensor(c1)
     c2 = tl.to_tensor(c2)
     c3 = tl.to_tensor(c3)
-    seed = seed.to(tl.uint64)
     if tl.constexpr(c0.dtype.primitive_bitwidth) == 32:
         int_dtype = tl.uint32
         seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
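
The philox change above statically asserts that the seed is an integer before widening it to tl.uint64, so a floating-point seed now fails at compile time instead of being silently converted. A minimal illustrative kernel (not part of the diff) using tl.rand, which is built on philox:

import triton
import triton.language as tl

@triton.jit
def rand_kernel(out_ptr, seed, N, BLOCK: tl.constexpr):
    # seed must be an integer scalar; philox now rejects float seeds at compile time
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    r = tl.rand(seed, offs)
    tl.store(out_ptr + offs, r, mask=offs < N)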
triton/language/semantic.py CHANGED
@@ -6,7 +6,6 @@ import numbers
 
 from .._C.libtriton import ir
 from . import core as tl
-from . import math
 
 T = TypeVar('T')
 
@@ -62,7 +61,7 @@ def computation_type_impl(a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_i
                           div_or_mod: bool) -> tl.dtype:
     # 0) For scalars we follow semantics similar to PyTorch, namely:
     # - If the scalar is of a lower or equal kind (bool < uint < int < fp),
-    #   it doesn't participate in the pomotion
+    #   it doesn't participate in the promotion
     if a_is_scalar != b_is_scalar:
         scalar_ty, tensor_ty = (a_ty, b_ty) if a_is_scalar else (b_ty, a_ty)
         if scalar_ty.kind().value <= tensor_ty.kind().value:
@@ -88,11 +87,12 @@ def computation_type_impl(a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_i
         else:
             return tl.float16
     # 4) return bf16 only if both operands are of bf16
-    if a_ty.is_bf16() or b_ty.is_bf16():
+    if a_ty.is_bf16() and b_ty.is_bf16():
         if div_or_mod:
             return tl.float32
-        if a_ty.is_bf16() and b_ty.is_bf16():
+        else:
             return tl.bfloat16
+    if a_ty.is_bf16() or b_ty.is_bf16():
         return tl.float32
     # 5) return fp16 if operands are different fp8
     if a_ty.is_fp8() and b_ty.is_fp8():
@@ -186,6 +186,11 @@ def binary_op_type_checking_impl(lhs: tl.tensor | numbers.Number, rhs: tl.tensor
             or rhs_is_scalar and rhs_scalar < 0 and ret_sca_ty.is_int_unsigned()):
         raise ValueError("Cannot perform a binary operation between an unsigned tensor and a negative scalar. "
                          "Perform a explicit cast on one of them.")
+    if ret_sca_ty.is_int():
+        if lhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= lhs_scalar <= ret_sca_ty.get_int_max_value()):
+            raise ValueError(f"Scalar {lhs_scalar} is out of range for type {ret_sca_ty}")
+        if rhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= rhs_scalar <= ret_sca_ty.get_int_max_value()):
+            raise ValueError(f"Scalar {rhs_scalar} is out of range for type {ret_sca_ty}")
     lhs = full(
         (), lhs_scalar, dtype=ret_sca_ty, builder=builder) if lhs_is_scalar else cast(lhs, ret_sca_ty, builder)
     rhs = full(
@@ -230,7 +235,15 @@ def add(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sa
     input_scalar_ty = input.type.scalar
     other_scalar_ty = other.type.scalar
     if input_scalar_ty.is_ptr():
-        return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type)
+        other_handle = other.handle
+        if other.dtype.is_int_unsigned() and other.dtype.int_bitwidth < 64:
+            # addptr treats offset as signed. Zero-extend unsigned offsets to ensure they're positive
+            if other.type.is_block():
+                i64_ty = tl.block_type(tl.int64, other.type.get_block_shapes()).to_ir(builder)
+            else:
+                i64_ty = tl.int64.to_ir(builder)
+            other_handle = builder.create_int_cast(other.handle, i64_ty, False)
+        return tl.tensor(builder.create_addptr(input.handle, other_handle), input.type)
     # float + float
     elif input_scalar_ty.is_floating():
         return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
@@ -333,10 +346,7 @@ def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, bu
     other_scalar_ty = other.type.scalar
     # float % float
     if scalar_ty.is_floating():
-        # input - input.div(other, rounding_mode="floor") * other
-        floor = math.floor(fdiv(input, other, False, builder), _builder=builder)
-        ret = sub(input, mul(floor, other, True, builder), True, builder)
-        return ret
+        return tl.tensor(builder.create_frem(input.handle, other.handle), input.type)
     # % int
     elif scalar_ty.is_int():
         if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
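
A note on the float modulo hunk above: the removed lines computed a floor-style remainder, while create_frem presumably lowers to an LLVM-style frem, which follows C fmod semantics (the result takes the sign of the dividend). Assuming that lowering, the numerical difference is easy to check in plain Python:

import math

x, y = -7.0, 3.0
print(math.fmod(x, y))            # -1.0: truncated remainder, sign follows the dividend
print(x - math.floor(x / y) * y)  #  2.0: the floor-based formulation that was removed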
@@ -762,14 +772,14 @@ def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) ->
         # Add new axes to lhs
         for _ in range(len(lhs_shape), len(rhs_shape)):
             lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0),
-                            tl.block_type(lhs_ty.scalar, [1] + lhs_shape))
+                            tl.block_type(lhs_ty.scalar, [1] + lhs_shape.values))
             lhs_ty = lhs.type
             lhs_shape = lhs_ty.get_block_shapes()
     elif len(rhs_shape) < len(lhs_shape):
         # Add new axes to rhs
         for _ in range(len(rhs_shape), len(lhs_shape)):
             rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0),
-                            tl.block_type(rhs_ty.scalar, [1] + rhs_shape))
+                            tl.block_type(rhs_ty.scalar, [1] + rhs_shape.values))
             rhs_ty = rhs.type
             rhs_shape = rhs_ty.get_block_shapes()
     assert len(rhs_shape) == len(lhs_shape)
@@ -831,10 +841,6 @@ def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tenso
 def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder,
          fp_downcast_rounding: Optional[str] = None) -> tl.tensor:
     src_ty = input.type
-    if isinstance(dst_ty, tl.constexpr):
-        dst_ty = dst_ty.value
-    if isinstance(fp_downcast_rounding, tl.constexpr):
-        fp_downcast_rounding = fp_downcast_rounding.value
     if src_ty.is_block():
         dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
     if src_ty == dst_ty:
@@ -1048,7 +1054,7 @@ def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, evicti
         raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")
 
     elt_ty = ptr.type.element_ty.element_ty
-    assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`"
+    assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
     if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN:
         raise ValueError("Padding option `nan` is not supported for integer block pointers")
 
@@ -1141,18 +1147,93 @@ def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor],
     return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder)
 
 
-def descriptor_load(desc_ptr: tl.tensor, offsets, cache_modifier: str, eviction_policy: str, type,
+def reinterpret_tensor_descriptor(desc_ptr: tl.tensor, block_ty: tl.block_type, builder: ir.builder):
+    handle = builder.create_reinterpret_tensor_descriptor(desc_ptr.handle, block_ty.to_ir(builder))
+    return tl._experimental_tensor_descriptor_base(handle, block_ty)
+
+
+def validate_descriptor_block(shape, dtype):
+    if len(shape) != 2:
+        return
+    # Due to limitations of the shared memory encoding, the TMA bounding box has
+    # to be at least as big as the swizzle tile.
+    assert shape[0] >= 8, f"tensor descriptor block shape must have at least 8 rows, but got {shape[0]}"
+    min_cols = 32 // dtype.primitive_bitwidth * 8
+    assert shape[
+        1] >= min_cols, f"{dtype} tensor descriptor block shape must have at least {min_cols} columns, but got {shape[1]}"
+
+
+def descriptor_load(desc: tl._experimental_tensor_desciptor_base, offsets, cache_modifier: str, eviction_policy: str,
                     builder: ir.builder) -> tl.tensor:
+    assert isinstance(desc, tl._experimental_tensor_descriptor_base)
+    validate_descriptor_block(desc.block_shape, desc.dtype)
+    ndim = len(desc.block_shape)
+    assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"
+
     offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
-    x = builder.create_descriptor_load(desc_ptr.handle, offsets, type.to_ir(builder),
-                                       _str_to_load_cache_modifier(cache_modifier),
+    x = builder.create_descriptor_load(desc.handle, offsets, _str_to_load_cache_modifier(cache_modifier),
                                        _str_to_eviction_policy(eviction_policy))
-    return tl.tensor(x, type)
+    return tl.tensor(x, desc.block_type)
+
 
+def descriptor_store(desc: tl._experimental_tensor_descriptor_base, value: tl.tensor, offsets,
+                     builder: ir.builder) -> tl.tensor:
+    assert isinstance(desc, tl._experimental_tensor_descriptor_base)
+    validate_descriptor_block(desc.block_shape, desc.dtype)
+    ndim = len(desc.block_shape)
+    assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"
+    assert value.shape == desc.block_shape
 
-def descriptor_store(desc_ptr: tl.tensor, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
     offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
-    return tl.tensor(builder.create_descriptor_store(desc_ptr.handle, value.handle, offsets), tl.void)
+    return tl.tensor(builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void)
+
+
+def descriptor_gather(desc, x_offsets, y_offset, cache_modifier: str, eviction_policy: str,
+                      builder: ir.builder) -> tl.tensor:
+    assert isinstance(desc, tl._experimental_tensor_descriptor_base)
+    assert cache_modifier == "", "cache modifier is not supported yet"
+    assert eviction_policy == "", "eviction policy is not supported yet"
+
+    # Validate descriptor.
+    assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}"
+    assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}"
+
+    # Validate offsets.
+    assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shape}"
+
+    # Validate minimum block size.
+    assert x_offsets.shape[0] >= 8, f"descriptor gather must have at least 8 rows, but got {x_offsets.shape}"
+    dtype = desc.dtype
+    min_cols = 32 // dtype.primitive_bitwidth * 8
+    assert desc.block_shape[
+        1] >= min_cols, f"descriptor gather of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}"
+
+    type = tl.block_type(desc.dtype, [x_offsets.shape[0], desc.block_shape[1]])
+    y_offset = _convert_to_ir_values(builder, (y_offset, ), require_i64=False)[0]
+    x = builder.create_descriptor_gather(desc.handle, x_offsets.handle, y_offset, type.to_ir(builder))
+    return tl.tensor(x, type)
+
+
+def descriptor_scatter(desc, value: tl.tensor, x_offsets, y_offset, builder: ir.builder) -> tl.tensor:
+    assert isinstance(desc, tl._experimental_tensor_descriptor_base)
+
+    # Validate descriptor.
+    assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}"
+    assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}"
+
+    # Validate offsets.
+    assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shapae}"
+
+    # Validate minimum block size.
+    assert x_offsets.shape[0] >= 8, f"descriptor scatter must have at least 8 rows, but got {x_offsets.shape}"
+    dtype = desc.dtype
+    min_cols = 32 // dtype.primitive_bitwidth * 8
+    assert desc.block_shape[
+        1] >= min_cols, f"descriptor scatter of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}"
+
+    y_offset = _convert_to_ir_values(builder, (y_offset, ), require_i64=False)[0]
+    builder.create_descriptor_scatter(desc.handle, value.handle, x_offsets.handle, y_offset)
+    return tl.tensor(None, tl.void)
 
 
 def tensormap_create(
@@ -1206,7 +1287,7 @@ def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builde
     assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch"
 
     elt_ty = ptr.type.element_ty.element_ty
-    assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`"
+    assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
 
     # Check `boundary_check` argument
     boundary_check = _canonicalize_boundary_check(boundary_check, block_shape)
@@ -1256,7 +1337,7 @@ def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder):
         val = cast(val, elt_ty, builder)
 
     # Build IR
-    if not mask:
+    if mask is None:
         return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void)
     if not mask.type.scalar.is_bool():
         raise ValueError("Mask must have boolean scalar type")
@@ -1311,7 +1392,7 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor,
     if val is not None:
         val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
         val = cast(val, ptr.type.scalar.element_ty, builder)
-    if not mask:
+    if mask is None:
         mask_ir = builder.get_int1(True)
         mask_ty = tl.int1
         if ptr.type.is_block():
@@ -1470,6 +1551,7 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona
     assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}"
 
     if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():
+        # We upcast because there's no fp8e4b15 type in MLIR
         lhs = cast(lhs, tl.float16, builder)
         rhs = cast(rhs, tl.float16, builder)
 
@@ -1527,40 +1609,58 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona
         ret_ty)
 
 
-def _str_to_fp_type(float_format: Optional[str]):
-    if float_format == 'e4m3':
-        return ir.F8F6F4TY.E4M3
-    if float_format == 'e5m2':
-        return ir.F8F6F4TY.E5M2
-    if float_format == 'e2m3':
-        return ir.F8F6F4TY.E2M3
-    if float_format == 'e3m2':
-        return ir.F8F6F4TY.E3M2
-    if float_format == 'e2m1':
-        return ir.F8F6F4TY.E2M1
-    raise ValueError(f"Invalid float format: {float_format}.")
+def _str_to_fp_type(float_format: str):
+    ty_enum = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None)
+    if ty_enum is None:
+        raise ValueError(f"Invalid float format: {float_format}.")
+    return ty_enum
+
+
+def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder):
+    """
+    If float_format is subbyte, make sure it's packed as uint8 and return it.
+    Otherwise, return a tensor (perhaps bitcasting) of the specified float format.
+    """
+    triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16": tl.float16}.get(float_format)
+    if triton_ty is None:
+        assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}"
+        assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}"
+        return val
+    if val.dtype == triton_ty:
+        return val
+    else:
+        unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16, "fp16": tl.uint16}[float_format]
+        assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}"
+        return bitcast(val, triton_ty, builder)
 
 
-def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, rhs_scale: Optional[tl.tensor],
-               rhs_format, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
+def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.tensor, rhs_scale: Optional[tl.tensor],
+               rhs_format: str, acc: tl.tensor | None, fast_math: bool, out_dtype: tl.dtype,
+               builder: ir.builder) -> tl.tensor:
     assert lhs.type.is_block() and rhs.type.is_block()
     #TODO: validate types.
     lhs_rank = len(lhs.shape)
    rhs_rank = len(rhs.shape)
     assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
+    lhs_format: str = lhs_format.value
+    rhs_format: str = rhs_format.value
     lhs_format_enum = _str_to_fp_type(lhs_format)
     rhs_format_enum = _str_to_fp_type(rhs_format)
-    assert lhs_format in ("e2m1", "e4m3", "e5m2"), f"NYI: lhs_format {lhs_format}"
-    assert rhs_format in ("e4m3", "e5m2"), f"NYI: rhs_format {rhs_format}"
-    rhs_scale_is_none = isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None
-    assert rhs_scale_is_none, "NYI: rhs_scale not supported"
+    allowed_formats = {"e2m1", "e4m3", "e5m2", "bf16", "fp16"}
+    assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}"
+    assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}"
+    rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None)
+    lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None)
+    lhs = _bitcast_to_fp_type(lhs, lhs_format, builder)
+    rhs = _bitcast_to_fp_type(rhs, rhs_format, builder)
 
     M = lhs.type.shape[-2]
     K, N = rhs.type.shape[-2:]
-    PACKED = 2 if lhs_format == "e2m1" else 1
-    assert K == PACKED * lhs.type.shape[
+    PACKED_A = 2 if lhs_format == "e2m1" else 1
+    PACKED_B = 2 if rhs_format == "e2m1" else 1
+    assert K * PACKED_B == PACKED_A * lhs.type.shape[
         -1], f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
-    assert K >= 64, f"scaled_dot NYI for K < 64. Got {K=}"
+    #assert K * PACKED_B >= 64, f"scaled_dot NYI for K < 64. Got {K=}"
     B = lhs.type.shape[0] if lhs_rank == 3 else None
 
     ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N])
@@ -1571,9 +1671,10 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor,
         acc_handle = acc.handle
         assert acc.type == ret_ty
     rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle
+    lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle
     return tl.tensor(
-        builder.create_dot_scaled(lhs.handle, lhs_scale.handle, lhs_format_enum, rhs.handle, rhs_scale_handle,
-                                  rhs_format_enum, acc_handle), ret_ty)
+        builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle, rhs_scale_handle,
+                                  rhs_format_enum, fast_math, acc_handle), ret_ty)
 
 
 # ===----------------------------------------------------------------------===//
@@ -1655,6 +1756,30 @@ def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn,
     return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs)))
 
 
+# ===----------------------------------------------------------------------===
+# Gather
+# ===----------------------------------------------------------------------===
+
+
+def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+    assert index.dtype.is_int(), "index must be an integer tensor"
+
+    rank = len(src.type.shape)
+    assert len(index.type.shape) == rank, "source and index tensors must have the same rank"
+
+    assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})"
+    if axis < 0:
+        axis += rank
+
+    for d in range(rank):
+        if d == axis:
+            continue
+        assert index.type.shape[d] == src.type.shape[d], f"index dim {axis} must match the corresponding source dim"
+
+    gather = builder.create_gather(src.handle, index.handle, axis)
+    return wrap_tensor(gather, src.type.scalar, index.type.shape)
+
+
 
 # ===----------------------------------------------------------------------===
 # Histogram
@@ -1663,10 +1788,7 @@ def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn,
 def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor:
     assert len(input.shape) == 1, "histogram only supports 1D input"
     assert input.dtype.is_int(), "histogram only supports integer input"
-    return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, (num_bins, )))
-
-
-##
+    return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, [num_bins]))
 
 
 def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
@@ -1794,3 +1916,35 @@ def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
 
     # Advanced block pointer type is the same as before
     return tl.tensor(builder.create_advance(base.handle, offsets), base.type)
+
+
+def make_tensor_descriptor(
+    base: tl.tensor,
+    shape: List[tl.tensor],
+    strides: List[tl.tensor],
+    block_shape: List[tl.constexpr],
+    builder: ir.builder,
+) -> tl._experimental_tensor_descriptor:
+    ndim = len(shape)
+    if not (2 <= ndim <= 5):
+        raise ValueError(f"Expected 2 <= ndim <= 5 but got {ndim} dimensions")
+    if len(strides) != ndim:
+        raise ValueError(f"Expected {ndim} strides but got {len(strides)}")
+    if len(block_shape) != ndim:
+        raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(strides)}")
+
+    strides[-1] = tl._constexpr_to_value(strides[-1])
+    if strides[-1] != 1:
+        raise ValueError(f"Tensor descriptor last dim must be 1 but got {strides[-1]}")
+
+    shape = [to_tensor(x, builder) for x in shape]
+    strides = [to_tensor(x, builder).to(tl.int64, _builder=builder) for x in strides]
+
+    # Check whether `block_shape` is static
+    block_shape = tl._unwrap_shape(block_shape)
+
+    assert isinstance(base.type, tl.pointer_type)
+    type = tl.block_type(base.type.element_ty, block_shape)
+    handle = builder.create_make_tensor_descriptor(base.handle, [s.handle for s in shape], [s.handle for s in strides],
+                                                   block_shape)
+    return tl._experimental_tensor_descriptor(handle, shape, strides, type)
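
The new semantic.make_tensor_descriptor above builds device-side tensor descriptors. A rough sketch of the intended kernel-side usage, under two assumptions that are not confirmed by this diff: that the language-level wrapper is exposed as tl._experimental_make_tensor_descriptor and that descriptor objects expose load/store methods. Such kernels also appear to need a runtime allocator installed on the host (see triton/runtime/_allocation.py at the end of this diff):

import triton
import triton.language as tl

@triton.jit
def copy_2d_kernel(in_ptr, out_ptr, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Assumed API: the last stride must be 1 and block_shape must be constexpr.
    in_desc = tl._experimental_make_tensor_descriptor(
        in_ptr, shape=[M, N], strides=[N, 1], block_shape=[BLOCK_M, BLOCK_N])
    out_desc = tl._experimental_make_tensor_descriptor(
        out_ptr, shape=[M, N], strides=[N, 1], block_shape=[BLOCK_M, BLOCK_N])
    tile = in_desc.load([pid_m * BLOCK_M, pid_n * BLOCK_N])
    out_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], tile)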
triton/language/standard.py CHANGED
@@ -59,14 +59,14 @@ def softmax(x, ieee_rounding=False):
 
 @core._tensor_member_fn
 @jit
-def ravel(x):
+def ravel(x, can_reorder=False):
     """
     Returns a contiguous flattened view of :code:`x`.
 
     :param x: the input tensor
     :type x: Block
     """
-    return core.reshape(x, [x.numel], can_reorder=True)
+    return core.reshape(x, [x.numel], can_reorder=can_reorder)
 
 
 @jit
@@ -259,11 +259,30 @@ def _sum_combine(a, b):
 # sum
 
 
+def _pick_sum_dtype(in_dtype: core.constexpr, dtype: core.constexpr):
+    dtype = core._unwrap_if_constexpr(dtype)
+    if dtype is not None:
+        return dtype
+
+    # For integer bitwidths less than 32, pick int32 with the same sign to
+    # avoid overflow.
+    out_dtype = None
+    if in_dtype.is_int_signed():
+        out_dtype = core.int32 if in_dtype.int_bitwidth < 32 else None
+    elif in_dtype.is_int_unsigned():
+        out_dtype = core.uint32 if in_dtype.int_bitwidth < 32 else None
+    return out_dtype
+
+
 @core._tensor_member_fn
 @jit
-@core._add_reduction_docstr("sum")
-def sum(input, axis=None, keep_dims=False):
-    input = core._promote_bfloat16_to_float32(input)
+@core._add_reduction_docstr("sum", dtype_arg="dtype")
+def sum(input, axis=None, keep_dims=False, dtype: core.constexpr = None):
+    # Pick a default dtype for the reduction if one was not specified.
+    out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype)
+
+    if out_dtype is not None:
+        input = input.to(out_dtype)
     return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims)
 
 
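With the change above, sums over sub-32-bit integers accumulate in int32/uint32 by default, and the new dtype argument lets callers pick the accumulator explicitly. A small illustrative kernel (not part of the diff):

import triton
import triton.language as tl

@triton.jit
def byte_sum_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)          # e.g. int8 values: now accumulated in int32 by default
    total = tl.sum(x, dtype=tl.int64)  # the new dtype argument selects the accumulator explicitly
    tl.store(out_ptr, total)
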
@@ -276,15 +295,11 @@ def _xor_combine(a, b):
 
 
 @core._tensor_member_fn
-@core.builtin
+@jit
 @core._add_reduction_docstr("xor sum")
-def xor_sum(input, axis=None, keep_dims=False, _builder=None, _generator=None):
-    scalar_ty = input.type.scalar
-    if not scalar_ty.is_int():
-        raise ValueError("xor_sum only supported for integers")
-
-    input = core._promote_bfloat16_to_float32(input, _builder=_builder)
-    return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims, _builder=_builder, _generator=_generator)
+def xor_sum(input, axis=None, keep_dims=False):
+    core.static_assert(input.type.scalar.is_int(), "xor_sum only supported for integers")
+    return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims)
 
 
 # cumsum
@@ -412,11 +427,13 @@ def flip(x, dim=None):
     """
     core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)]))
     core.static_assert(_is_power_of_two(x.numel))
-    # # reshape the tensor to have all dimensions be 2.
-    # # TODO: We shouldn't have to change the dimensions not sorted.
+    # reshape the tensor to have all dimensions be 2.
+    # TODO: We shouldn't have to change the dimensions not sorted.
     steps: core.constexpr = _log2(x.numel)
     start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)])
-    y = core.reshape(x, [2] * steps)
+
+    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
+    y = core.reshape(x.to(idtype, bitcast=True), [2] * steps)
     y = core.expand_dims(y, start)
     flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2))
     for i in core.static_range(start, steps):
@@ -424,8 +441,8 @@
         for j in core.static_range(0, steps + 1):
             if j != i and j != i + 1:
                 flip2 = core.expand_dims(flip2, j)
-        y = sum(y * flip2, i + 1, keep_dims=True)
-    x = core.reshape(y, x.shape)
+        y = sum(y * flip2, i + 1, keep_dims=True, dtype=y.dtype)
+    x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)
     return x
 
 
triton/runtime/_allocation.py ADDED
@@ -0,0 +1,32 @@
+from typing import Optional, Protocol
+
+
+class Buffer(Protocol):
+
+    def data_ptr(self) -> int:
+        ...
+
+
+class Allocator(Protocol):
+
+    def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
+        ...
+
+
+class NullAllocator:
+
+    def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
+        raise RuntimeError("Kernel requires a runtime memory allocation, but no allocator was set. " +
+                           "Use triton.set_allocator to specify an allocator.")
+
+
+_allocator: Allocator = NullAllocator()
+
+
+def set_allocator(allocator: Allocator):
+    """
+    The allocator function is called during kernel launch for kernels that
+    require additional global memory workspace.
+    """
+    global _allocator
+    _allocator = allocator
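
The new runtime allocator hook above is invoked at launch time for kernels that need extra global-memory workspace. A sketch of installing a PyTorch-backed allocator; the torch-based helper is illustrative, and it assumes set_allocator is re-exported at the top level as triton.set_allocator (triton/__init__.py also changes in this release):

import torch
import triton


def torch_alloc(size: int, alignment: int, stream):
    # Return any object with a data_ptr() method, per the Buffer protocol above.
    # Fresh CUDA allocations are typically sufficiently aligned; a stricter
    # allocator could over-allocate and round the pointer up instead.
    return torch.empty(size, dtype=torch.int8, device="cuda")


triton.set_allocator(torch_alloc)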