triton-windows 3.2.0.post12__cp313-cp313-win_amd64.whl → 3.3.0a0.post12__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of triton-windows has been flagged as potentially problematic.

Files changed (68)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +3 -3
  3. triton/_internal_testing.py +59 -4
  4. triton/_utils.py +35 -0
  5. triton/backends/amd/compiler.py +121 -74
  6. triton/backends/amd/driver.py +77 -43
  7. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
  8. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
  13. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
  15. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
  16. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
  17. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
  18. triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
  19. triton/backends/amd/include/hip/hip_ext.h +4 -2
  20. triton/backends/amd/include/hip/hip_fp8.h +33 -0
  21. triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
  22. triton/backends/amd/include/hip/hip_version.h +3 -3
  23. triton/backends/amd/include/hip/hiprtc.h +25 -25
  24. triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
  25. triton/backends/amd/include/hsa/hsa.h +11 -2
  26. triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
  27. triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
  28. triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
  29. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
  30. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
  31. triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
  32. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
  33. triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
  34. triton/backends/amd/lib/asanrtl.bc +0 -0
  35. triton/backends/compiler.py +25 -225
  36. triton/backends/driver.py +7 -2
  37. triton/backends/nvidia/bin/ptxas.exe +0 -0
  38. triton/backends/nvidia/compiler.py +135 -90
  39. triton/backends/nvidia/driver.c +0 -1
  40. triton/backends/nvidia/driver.py +135 -49
  41. triton/backends/nvidia/include/cuda.h +2162 -241
  42. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  43. triton/compiler/__init__.py +2 -2
  44. triton/compiler/code_generator.py +334 -231
  45. triton/compiler/compiler.py +77 -66
  46. triton/language/__init__.py +22 -5
  47. triton/language/core.py +448 -74
  48. triton/language/extra/cuda/_experimental_tma.py +3 -5
  49. triton/language/math.py +1 -1
  50. triton/language/random.py +2 -1
  51. triton/language/semantic.py +206 -52
  52. triton/language/standard.py +35 -18
  53. triton/runtime/_allocation.py +32 -0
  54. triton/runtime/autotuner.py +27 -32
  55. triton/runtime/build.py +1 -48
  56. triton/runtime/cache.py +6 -6
  57. triton/runtime/errors.py +10 -0
  58. triton/runtime/interpreter.py +179 -45
  59. triton/runtime/jit.py +149 -190
  60. triton/testing.py +39 -11
  61. triton/tools/compile.py +27 -20
  62. triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
  63. triton/tools/mxfp.py +301 -0
  64. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/METADATA +5 -2
  65. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/RECORD +68 -59
  66. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/top_level.txt +2 -0
  67. /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
  68. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/WHEEL +0 -0
triton/language/core.py CHANGED
@@ -5,7 +5,7 @@ from contextlib import contextmanager
  from enum import Enum
  from functools import partial, wraps
  import typing
- from typing import Union, Callable, List, Sequence, TypeVar, Optional
+ from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple
  import builtins
  from ..runtime.jit import jit
  import inspect
@@ -29,7 +29,6 @@ def builtin(fn: T) -> T:
  @wraps(fn)
  def wrapper(*args, **kwargs):
  if "_builder" not in kwargs or kwargs["_builder"] is None:
- print(kwargs)
  raise ValueError("Did you forget to add @triton.jit ? "
  "(`_builder` argument must be provided outside of JIT functions.)")
  return fn(*args, **kwargs)
@@ -141,6 +140,7 @@ class constexpr:
  self.value = value.value
  else:
  self.value = value
+ self.type = constexpr

  def __repr__(self) -> str:
  return f"constexpr[{self.value}]"
@@ -280,12 +280,40 @@ def check_bit_width(value, shift_value):
  )


+ class base_value:
+ """Base class of values that exist in the triton IR (i.e. not constexprs).
+ """
+ type: base_type
+
+ def _flatten_ir(self, handles: List[ir.value]) -> None:
+ """Flatten frontend value into a sequence of mlir handles, which are appended
+ to the output list
+ """
+ raise NotImplementedError
+
+
+ class base_type:
+
+ def __eq__(self, other):
+ raise NotImplementedError("Types must implement __eq__")
+
+ def __ne__(self, other):
+ return not (self == other)
+
+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
+ """Build a frontend value with the current dtype, wrapping a list of existing handles.
+ cursor is the index of the first handle relevant to this value, and the function
+ should return the updated cursor position after any handles consumed by the created value.
+ """
+ raise NotImplementedError
+
+
  # -----------------------
  # dtype
  # -----------------------


- class dtype:
+ class dtype(base_type):
  SINT_TYPES = ['int8', 'int16', 'int32', 'int64']
  UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64']
  FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64']
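The `base_value`/`base_type` pair introduced above defines a flatten/unflatten protocol: a frontend value serializes itself into a flat list of MLIR handles, and its type rebuilds an equivalent value from that list starting at a cursor. A minimal round-trip sketch (illustrative only, not part of this diff; `val` stands for any frontend value such as a `tensor`):

    handles = []
    val._flatten_ir(handles)                       # a scalar tensor appends exactly one handle
    rebuilt, cursor = val.type._unflatten_ir(handles, 0)
    assert cursor == len(handles)                  # all handles were consumed by the rebuilt value
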
@@ -474,14 +502,15 @@ class dtype:
  def is_const():
  return False

+ @staticmethod
+ def is_tuple():
+ return False
+
  def __eq__(self, other: dtype):
  if not isinstance(other, dtype):
  return False
  return self.name == other.name

- def __ne__(self, other: dtype):
- return not self.__eq__(other)
-
  def __hash__(self):
  return hash((self.name, ))

@@ -549,6 +578,9 @@ class dtype:
  """Output of repr needs to be an evaluatable expression"""
  return f'triton.language.{self.codegen_name()}'

+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
+ return tensor(handles[cursor], self), cursor + 1
+

  # Some functions have a param named `dtype`, which shadows the `dtype` class.
  # We can't change the param name because it is part of function's public API.
@@ -587,9 +619,6 @@ class pointer_type(dtype):
  return False
  return self.element_ty == other.element_ty and self.address_space == other.address_space and self.const == other.const

- def __ne__(self, other: pointer_type) -> bool:
- return not self.__eq__(other)
-
  @property
  def scalar(self):
  return self
@@ -609,9 +638,10 @@ class block_type(dtype):

  # Note that block_type's shape is a list of int
  # while tensor's shape is a list of constexpr.
+ assert (isinstance(shape, (list, tuple)))

  # shape can be empty ([]) when an input is a 0D tensor.
- self.shape = _unwrap_shape(shape)
+ self.shape = tuple(_unwrap_shape(shape))
  if not self.shape:
  raise TypeError('0d block_type is forbidden')

@@ -633,32 +663,53 @@ class block_type(dtype):
  def get_block_shapes(self) -> List[int]:
  return self.shape

- def __eq__(self, other: block_type) -> bool:
+ def __eq__(self, other) -> bool:
  if not isinstance(other, block_type):
  return False
  return self.element_ty == other.element_ty and self.shape == other.shape

- def __ne__(self, other: block_type) -> bool:
- return not self.__eq__(other)
-
  @property
  def scalar(self):
  return self.element_ty


- class function_type(dtype):
+ class tuple_type(base_type):

- def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None:
- self.ret_types = ret_types
- self.param_types = param_types
+ def __init__(self, types, fields=None):
+ self.types = types
+ self.fields = fields or [''] * len(types)
+ self.name = '[' + ','.join([f"{k}:{v}" for k, v in zip(self.fields, self.types)]) + ']'

  def __str__(self):
- return f'fn ({self.param_types}) -> {self.ret_types}'
+ return self.name
+
+ def __iter__(self):
+ return iter(self.types)

  def to_ir(self, builder: ir.builder):
- ir_param_types = [ty.to_ir(builder) for ty in self.param_types]
- ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types]
- return builder.get_function_ty(ir_param_types, ret_types)
+ return [ty.to_ir(builder) for ty in self.types]
+
+ def __getitem__(self, index: int) -> dtype:
+ return self.types[index]
+
+ def is_tuple(self):
+ return True
+
+ def __eq__(self, other):
+ return type(self) is type(other) and self.types == other.types and self.fields == other.fields
+
+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tuple, int]:
+ values = []
+ for ty in self.types:
+ value, cursor = ty._unflatten_ir(handles, cursor)
+ values.append(value)
+ return tuple(values, self), cursor
+
+
+ class slice_type(dtype):
+
+ def __init__(self):
+ self.name = 'slice_type'


  # scalar types
@@ -708,20 +759,12 @@ def get_int_dtype(bitwidth: int, signed: bool) -> dtype:
  raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}')


- class _value:
- """Base class of values that exist in the triton IR (i.e. not constexprs).
- """
-
- def __init__(self, handle):
- self.handle = handle
-
-
  # -----------------------
  # tensor
  # -----------------------


- class tensor(_value):
+ class tensor(base_value):
  """Represents an N-dimensional array of values or pointers.

  :code:`tensor` is the fundamental data structure in Triton programs. Most
@@ -743,8 +786,9 @@ class tensor(_value):

  def __init__(self, handle, type: dtype):
  """Not called by user code."""
+ super().__init__()
  # IR handle
- super().__init__(handle)
+ self.handle = handle
  # Block shape
  self.shape = type.shape if type.is_block() else ()
  self.numel = 1
@@ -754,7 +798,10 @@
  self.type = type # Tensor type (can be block_type)
  # Following the practice in pytorch, dtype is scalar type
  self.dtype = type.scalar
- self.shape = [constexpr(s) for s in self.shape]
+ self.shape = tuple([constexpr(s) for s in self.shape])
+
+ def _flatten_ir(self, handles: List[ir.value]) -> None:
+ handles.append(self.handle)

  def __str__(self) -> str:
  # ex. "float32[16, 32]"
@@ -968,13 +1015,16 @@ class tensor(_value):

  @builtin
  def __getitem__(self, slices, _builder=None):
- if isinstance(slices, (slice, constexpr)) or slices is None:
+ import builtins
+ if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None:
  slices = [slices]
+ if isinstance(slices, tuple):
+ slices = slices.values
  ret = self
  for dim, sl in enumerate(slices):
  if sl is None or isinstance(sl, constexpr) and sl.value is None:
  ret = semantic.expand_dims(ret, dim, _builder)
- elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None:
+ elif isinstance(sl, (builtins.slice, slice)) and sl.start is None and sl.stop is None and sl.step is None:
  pass
  else:
  raise ValueError(f"unsupported tensor index: {sl}")
@@ -990,13 +1040,7 @@
  """
  Alias for :py:func:`tensor.cast`.
  """
- # Triton doesn't like core functions calling other core functions, so we
- # just copy-paste the implementation of cast here. It's not too bad.
- dtype = _unwrap_if_constexpr(dtype)
- bitcast = _unwrap_if_constexpr(bitcast)
- if bitcast:
- return semantic.bitcast(self, dtype, _builder)
- return semantic.cast(self, dtype, _builder, fp_downcast_rounding)
+ return cast(self, dtype, fp_downcast_rounding, bitcast, _builder=_builder)

  # Type stubs for functions added by the _tensor_member_fn decorator.
  # (Unfortunately these can't be created automatically.)
@@ -1084,6 +1128,9 @@
  def associative_scan(self, axis, combine_fn, reverse=False) -> tensor:
  ...

+ def gather(self, indices, axis) -> tensor:
+ ...
+
  def histogram(self, num_bins) -> tensor:
  ...

@@ -1111,7 +1158,7 @@
  def argmin(self, axis, tie_break_left=True, keep_dims=False) -> tensor:
  ...

- def sum(self, axis=None, keep_dims=False) -> tensor:
+ def sum(self, axis=None, keep_dims=False, dtype=None) -> tensor:
  ...

  def xor_sum(self, axis=None, keep_dims=False) -> tensor:
@@ -1130,6 +1177,223 @@
  ...


+ class tuple(base_value):
+
+ def __init__(self, args: list, type: tuple_type = None):
+ self.values = [i for i in args]
+
+ def get_type(x):
+ if isinstance(x, dtype):
+ return dtype
+ if isinstance(x, int):
+ return constexpr
+ return x.type
+
+ self.type = type or tuple_type([get_type(x) for x in self.values])
+
+ def __getitem__(self, idx: constexpr):
+ if isinstance(idx, int):
+ idx = constexpr(idx)
+ if isinstance(idx, constexpr):
+ return self.values[idx]
+ else:
+ import builtins
+ assert isinstance(idx, (slice, builtins.slice))
+ return tuple(self.values[idx.start:idx.stop:idx.step])
+
+ def __getattr__(self, name):
+ return self.values[self.type.fields.index(name)]
+
+ # TODO: remove
+ def __setitem__(self, idx: constexpr, value):
+ if isinstance(idx, int):
+ idx = constexpr(idx)
+ assert isinstance(idx, constexpr)
+ self.values[idx] = value
+
+ def __add__(self, other):
+ if isinstance(other, list):
+ other = tuple(other)
+ return tuple(self.values + other.values)
+ # return tuple(a + b for a, b in zip(self.values, other.values))
+
+ def __mul__(self, other):
+ assert isinstance(other, constexpr)
+ return tuple(self.values * other.value)
+
+ def __eq__(self, other):
+ import builtins
+ if isinstance(other, (list, builtins.tuple)):
+ other = tuple(other)
+ return constexpr(self.values == other.values)
+
+ def __hash__(self):
+ import builtins
+ return hash(builtins.tuple(self.values))
+
+ def __str__(self):
+ return str([str(x) for x in self.values])
+
+ def __iter__(self):
+ return iter(self.values)
+
+ def __len__(self):
+ return len(self.values)
+
+ def _flatten_ir(self, handles: List[ir.value]):
+ for v in self.values:
+ v._flatten_ir(handles)
+
+
+ class slice:
+
+ def __init__(self, start, stop, step):
+ self.start = start
+ self.stop = stop
+ self.step = step
+ self.type = slice_type()
+
+
+ class tensor_descriptor_base_type(base_type):
+
+ def __init__(self, block_type: block_type):
+ self.block_type = block_type
+
+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[_experimental_tensor_descriptor_base, int]:
+ value = _experimental_tensor_descriptor_base(handles[cursor], self.block_type)
+ return value, cursor + 1
+
+ def to_ir(self, builder: ir.builder):
+ return builder.create_tensor_descriptor_type(self.block_type.to_ir(builder))
+
+ def __str__(self) -> str:
+ # ex. "tensor_descriptor<float32[16, 32]>"
+ return f"tensor_descriptor<{self.block_type}>"
+
+ def __eq__(self, other) -> bool:
+ if type(other) is not type(self):
+ return False
+ return self.block_type == other.block_type
+
+ def __neq__(self, other) -> bool:
+ return not (self == other)
+
+
+ class _experimental_tensor_descriptor_base(base_value):
+ """"
+ A tensor descriptor with unknown shape and strides
+ """
+
+ def __init__(self, handle, block_type: block_type):
+ """Not called by user code."""
+ super().__init__()
+
+ self.handle = handle # IR handle
+ self.type = tensor_descriptor_base_type(block_type) # Tensor type (block_type)
+
+ def _flatten_ir(self, handles: List[ir.value]) -> None:
+ handles.append(self.handle)
+
+ @property
+ def block_type(self):
+ return self.type.block_type
+
+ @property
+ def block_shape(self):
+ return self.type.block_type.shape
+
+ @property
+ def dtype(self):
+ return self.type.block_type.element_ty
+
+ def __str__(self) -> str:
+ return str(self.type)
+
+ @builtin
+ def load(self, offsets: Sequence[constexpr | tensor], _builder=None) -> tensor:
+ """Load a block from the descriptor starting at the given element offsets.
+
+ Values outside of the tensor bounds will be filled with zeros.
+
+ :note: Offset must be a multiple of 16-bytes
+ """
+ return semantic.descriptor_load(self, offsets, "", "", _builder)
+
+ @builtin
+ def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _builder=None) -> tensor:
+ """Store a block from the descriptor starting at the given element offsets.
+
+ Values outside of the tensor bounds will be ignored.
+
+ :note: Offset must be a multiple of 16-bytes
+ """
+ return semantic.descriptor_store(self, value, offsets, _builder)
+
+ @builtin
+ def gather(self, *args, _builder=None) -> tensor:
+ """Gather multiple descriptors worth of data"""
+ assert len(args) == 2, f"descriptor gather only supports 2D indexing, but got {len(args)}"
+ x_offsets = args[0]
+ y_offset = args[1]
+ return semantic.descriptor_gather(self, x_offsets, y_offset, "", "", _builder)
+
+ @builtin
+ def scatter(self, value, *args, _builder=None) -> tensor:
+ """Scatter multiple descriptors worth of data"""
+ assert len(args) == 2, f"descriptor scatter only supports 2D indexing, but got {len(args)}"
+ x_offsets = args[0]
+ y_offset = args[1]
+ return semantic.descriptor_scatter(self, value, x_offsets, y_offset, _builder)
+
+
+ class tensor_descriptor_type(tensor_descriptor_base_type):
+
+ def __init__(self, block_type: block_type, shape_type: tuple_type, strides_type: tuple_type):
+ self.block_type = block_type
+ self.shape_type = shape_type
+ self.strides_type = strides_type
+
+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[_experimental_tensor_descriptor_base, int]:
+ handle = handles[cursor]
+ cursor += 1
+ shape, cursor = self.shape_type._unflatten_ir(handles, cursor)
+ strides, cursor = self.strides_type._unflatten_ir(handles, cursor)
+ shape = shape.values
+ strides = strides.values
+ value = _experimental_tensor_descriptor(handle, shape, strides, self.block_type)
+ return value, cursor
+
+ def to_ir(self, builder: ir.builder):
+ return [super().to_ir(builder), *self.shape_type.to_ir(builder), *self.strides_type.to_ir(builder)]
+
+ def __eq__(self, other):
+ return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type
+ == other.strides_type)
+
+
+ class _experimental_tensor_descriptor(_experimental_tensor_descriptor_base):
+ """A descriptor representing a tensor in global memory.
+ """
+
+ def __init__(self, handle, shape: List[tensor], strides: List[tensor], block_type: block_type):
+ """Not called by user code."""
+ # IR handle
+ super().__init__(handle, block_type)
+ self.type = tensor_descriptor_type(
+ block_type,
+ shape_type=tuple_type([s.type for s in shape]),
+ strides_type=tuple_type([s.type for s in strides]),
+ )
+ # Global shape
+ self.shape = shape
+ self.strides = strides
+
+ def _flatten_ir(self, handles: List[ir.value]) -> None:
+ handles.append(self.handle)
+ handles.extend(s.handle for s in self.shape)
+ handles.extend(s.handle for s in self.strides)
+
+
  def get_bool_env_var(var_name):
  v = os.getenv(var_name, "0")
  return v == "1" or v == "true" or v == "on"
@@ -1290,6 +1554,7 @@ def trans(input: tensor, *dims, _builder=None):
  :py:func:`permute` is equivalent to this function, except it doesn't
  have the special case when no permutation is specified.
  """
+ dims = _unwrap_iterable(dims)
  if not dims:
  dims = (1, 0)
  return semantic.permute(input, dims, _builder)
@@ -1467,7 +1732,7 @@ def expand_dims(input, axis, _builder=None):
  """
  input = semantic.to_tensor(input, _builder)
  axis = _constexpr_to_value(axis)
- axes = list(axis) if isinstance(axis, Sequence) else [axis]
+ axes = list(axis) if isinstance(axis, (Sequence, tuple)) else [axis]
  new_ndim = len(input.shape) + len(axes)
  axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes]

@@ -1499,8 +1764,9 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas
  :type bitcast: bool, optional
  """
  input = semantic.to_tensor(input, _builder)
- if isinstance(bitcast, constexpr):
- bitcast = bitcast.value
+ dtype = _constexpr_to_value(dtype)
+ fp_downcast_rounding = _constexpr_to_value(fp_downcast_rounding)
+ bitcast = _constexpr_to_value(bitcast)
  if bitcast:
  return semantic.bitcast(input, dtype, _builder)
  return semantic.cast(input, dtype, _builder, fp_downcast_rounding)
@@ -1522,16 +1788,16 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i
  where the first dimension of each block represents the batch dimension.

  :param input: The first tensor to be multiplied.
- :type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
+ :type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
  :param other: The second tensor to be multiplied.
- :type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
+ :type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
  :param acc: The accumulator tensor. If not None, the result is added to this tensor.
  :type acc: 2D or 3D tensor of scalar-type in {:code:`float16`, :code:`float32`, :code:`int32`}
  :param input_precision: How to exercise the Tensor Cores for f32 x f32. If
  the device does not have Tensor Cores or the inputs are not of dtype f32,
  this option is ignored. For devices that do have tensor cores, the
  default precision is tf32.
- :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Avaliable options for amd: :code:`"ieee"`.
+ :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available options for amd: :code:`"ieee"`, (CDNA3 only) :code:`"tf32"`.
  :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32".
  Only one of :code:`input_precision` and :code:`allow_tf32` can be
  specified (i.e. at least one must be :code:`None`).
@@ -1549,26 +1815,39 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i


  @builtin
- def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, out_dtype=float32, _builder=None):
+ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, out_dtype=float32,
+ _builder=None):
  """
  Returns the matrix product of two blocks in microscaling format.
+
  lhs and rhs use microscaling formats described here:
  https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+
+ Software emulation enables targeting hardware architectures without native microscaling
+ operation support. Right now for such case, microscaled lhs/rhs are upcasted to
+ :code:`bf16` element type beforehand for dot computation, with one exception:
+ for AMD CDNA3 specifically, if one of the inputs is of :code:`fp16` element type,
+ the other input is also upcasted to :code:`fp16` element type instead.
+ This behavior is experimental and may be subject to change in the future.
+
  :param lhs: The first tensor to be multiplied.
- :type lhs: 2D tensor of f8, f6 or f4 format packed in int32 format.
+ :type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
  :param lhs_scale: Scale factor for lhs tensor.
- :type lhs_scale: ue8m0 float8 type (currently represented as an int8 tensor).
- :param lhs_format: format of the lhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}.
+ :type lhs_scale: e8m0 type represented as an uint8 tensor.
+ :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
+ :type lhs_format: str
  :param rhs: The second tensor to be multiplied.
- :type rhs: 2D tensor of f8, f6 or f4 format packed in int32 format.
+ :type rhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
  :param rhs_scale: Scale factor for rhs tensor.
- :type rhs_scale: ue8m0 float8 type (currently represented as an int8 tensor).
- :param rhs_format: format of the rhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}.
+ :type rhs_scale: e8m0 type represented as an uint8 tensor.
+ :param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
+ :type rhs_format: str
  :param acc: The accumulator tensor. If not None, the result is added to this tensor.
  """
  out_dtype = _constexpr_to_value(out_dtype)
  assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment"
- return semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, out_dtype, _builder)
+ return semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, fast_math, out_dtype,
+ _builder)


  # -----------------------
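A hedged usage sketch of the extended `dot_scaled` signature (not from this diff; the kernel name, pointers and block sizes are illustrative, and the mixed fp4-times-bf16 case with `rhs_scale=None` is an assumption based on the docstring above). `a` holds packed e2m1 values, two per byte, and `a_scale` holds one e8m0 scale per 32 elements along K:

    @triton.jit
    def _mxfp_dot_kernel(a_ptr, a_scale_ptr, b_ptr, c_ptr,
                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
        a = tl.load(a_ptr + tl.arange(0, BLOCK_M)[:, None] * (BLOCK_K // 2) + tl.arange(0, BLOCK_K // 2)[None, :])
        a_scale = tl.load(a_scale_ptr + tl.arange(0, BLOCK_M)[:, None] * (BLOCK_K // 32) + tl.arange(0, BLOCK_K // 32)[None, :])
        b = tl.load(b_ptr + tl.arange(0, BLOCK_K)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :])
        c = tl.dot_scaled(a, a_scale, "e2m1", b, None, "bf16", fast_math=True)
        tl.store(c_ptr + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :], c)
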
@@ -1636,6 +1915,16 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c
  volatile, _builder)


+ @builtin
+ def _experimental_reinterpret_tensor_descriptor(desc_ptr, block_shape, dtype,
+ _builder=None) -> _experimental_tensor_descriptor_base:
+ """
+ Reinterpret a generic pointer as a TMA-backed tensor descriptor object.
+ """
+ block_ty = block_type(_constexpr_to_value(dtype), block_shape)
+ return semantic.reinterpret_tensor_descriptor(desc_ptr, block_ty, _builder)
+
+
  @builtin
  def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=None):
  """
@@ -1644,8 +1933,8 @@ def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=

  This loads a tensor of data based on the descriptor and offsets.
  """
- type = block_type(_constexpr_to_value(dtype), shape)
- return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder)
+ desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, shape, dtype, _builder=_builder)
+ return desc.load(offsets, _builder=_builder)


  @builtin
@@ -1656,7 +1945,8 @@ def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None):

  This stores a tensor of data based on the descriptor and offsets.
  """
- return semantic.descriptor_store(desc_pointer, value, offsets, _builder)
+ desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, value.shape, value.dtype, _builder=_builder)
+ return desc.store(offsets, value, _builder=_builder)


  @_tensor_member_fn
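With this change the byte-level `_experimental_descriptor_load`/`_experimental_descriptor_store` entry points become thin wrappers: they first reinterpret the raw descriptor pointer as a descriptor object and then call its `load`/`store` methods. A minimal kernel sketch of calling the new builtin directly (names are illustrative, not from this diff; `in_desc_ptr`/`out_desc_ptr` must point to valid TMA descriptors):

    @triton.jit
    def _copy_tile(in_desc_ptr, out_desc_ptr, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
        in_desc = tl._experimental_reinterpret_tensor_descriptor(in_desc_ptr, [M_BLOCK, N_BLOCK], tl.float16)
        out_desc = tl._experimental_reinterpret_tensor_descriptor(out_desc_ptr, [M_BLOCK, N_BLOCK], tl.float16)
        moff = tl.program_id(0) * M_BLOCK
        noff = tl.program_id(1) * N_BLOCK
        tile = in_desc.load([moff, noff])
        out_desc.store([moff, noff], tile)
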
@@ -1737,6 +2027,64 @@ def advance(base, offsets, _builder=None):
  return semantic.advance(base, offsets, _builder)


+ @builtin
+ def _experimental_make_tensor_descriptor(
+ base: tensor,
+ shape: List[tensor],
+ strides: List[tensor],
+ block_shape: List[constexpr],
+ _builder=None,
+ ) -> _experimental_tensor_descriptor:
+ """Make an experimental tensor descriptor object
+
+ :param base: the base pointer of the tensor, must be 16-byte aligned
+ :param shape: A list of non-negative integers representing the tensor shape
+ :param strides: A list of tensor strides. Leading dimensions must be multiples
+ of 16-byte strides and the last dimension must be contiguous.
+ :param block_shape: The shape of block to be loaded/stored from global memory
+
+ Notes
+ *****
+ On NVIDIA GPUs with TMA support, this will result in a TMA descriptor object
+ and loads and stores from the descriptor will be backed by the TMA hardware.
+
+ Currently only 2-5 dimensional tensors are supported.
+
+ Example
+ *******
+ .. code-block:: python
+
+ @triton.jit
+ def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
+ desc = tl._experimental_make_tensor_descriptor(
+ in_out_ptr,
+ shape=[M, N],
+ strides=[N, 1],
+ block_shape=[M_BLOCK, N_BLOCK],
+ )
+
+ moffset = tl.program_id(0) * M_BLOCK
+ noffset = tl.program_id(1) * N_BLOCK
+
+ value = desc.load([moffset, noffset])
+ desc.store([moffset, noffset], tl.abs(value))
+
+ # TMA descriptors require a global memory allocation
+ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
+ return torch.empty(size, device="cuda", dtype=torch.int8)
+
+ triton.set_allocator(alloc_fn)
+
+ M, N = 256, 256
+ x = torch.randn(M, N, device="cuda")
+ M_BLOCK, N_BLOCK = 32, 32
+ grid = (M / M_BLOCK, N / N_BLOCK)
+ inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK)
+
+ """
+ return semantic.make_tensor_descriptor(base, shape, strides, block_shape, _builder)
+
+
  # -----------------------
  # Atomic Memory Operations
  # -----------------------
@@ -1994,7 +2342,8 @@ def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=No
  # -----------------------


- def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
+ def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None,
+ dtype_arg: str = None) -> Callable[[T], T]:

  def _decorator(func: T) -> T:
  docstr = """
@@ -2014,6 +2363,10 @@ def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_a
  docstr += f"""
  :param {tie_break_arg}: if true, in case of a tie (i.e., multiple elements have the same {name} value), return the left-most index for values that aren't NaN
  :type {tie_break_arg}: bool"""
+ if dtype_arg is not None:
+ docstr += f"""
+ :param {dtype_arg}: the desired data type of the returned tensor. If specified, the input tensor is casted to :code:`{dtype_arg}` before the operation is performed. This is useful for preventing data overflows. If not specified, integer and bool dtypes are upcasted to :code:`tl.int32` and float dtypes are upcasted to at least :code:`tl.float32`.
+ :type {dtype_arg}: tl.dtype"""

  func.__doc__ = docstr.format(name=name)
  return func
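The new `dtype_arg` text documents the `dtype=` keyword that reductions such as `sum` now accept (see the updated `sum` stub earlier in this diff). A one-line sketch of the intent (illustrative, not from this diff; `x` is an `int32` block whose default accumulator would also be `int32`):

    total = tl.sum(x, axis=0, dtype=tl.int64)   # accumulate in int64 to avoid overflow
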
@@ -2047,14 +2400,12 @@ def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=N
  return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0]

  def make_combine_region(reduce_op):
- in_scalar_tys = [t.type.scalar for t in input]
- prototype = function_type(in_scalar_tys, in_scalar_tys * 2)
-
+ param_types = [t.type.scalar for t in input] * 2
  region = reduce_op.get_region(0)
  with _insertion_guard(_builder):
- param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
- block = _builder.create_block_with_parent(region, param_types)
- args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
+ to_ir = lambda T: T.to_ir(_builder)
+ block = _builder.create_block_with_parent(region, list(map(to_ir, param_types)))
+ args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)]
  results = _generator.call_JitFunction(combine_fn, args, kwargs={})
  if isinstance(results, tensor):
  handles = [results.handle]
@@ -2148,14 +2499,12 @@ def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _gen
  return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0]

  def make_combine_region(scan_op):
- in_scalar_tys = [t.type.scalar for t in input]
- prototype = function_type(in_scalar_tys, in_scalar_tys * 2)
-
+ param_types = [t.type.scalar for t in input] * 2
  region = scan_op.get_region(0)
  with _insertion_guard(_builder):
- param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
- block = _builder.create_block_with_parent(region, param_types)
- args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
+ to_ir = lambda T: T.to_ir(_builder)
+ block = _builder.create_block_with_parent(region, list(map(to_ir, param_types)))
+ args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)]
  results = _generator.call_JitFunction(combine_fn, args, kwargs={})
  if isinstance(results, tensor):
  handles = [results.handle]
@@ -2184,6 +2533,23 @@ def histogram(input, num_bins, _builder=None, _generator=None):
  return semantic.histogram(input, num_bins, _builder)


+ @_tensor_member_fn
+ @builtin
+ def gather(src, index, axis, _builder=None):
+ """Gather from a tensor along a given dimension.
+
+ :param src: the source tensor
+ :type src: Tensor
+ :param index: the index tensor
+ :type index: Tensor
+ :param axis: the dimension to gather along
+ :type axis: int
+
+ """
+ axis = _constexpr_to_value(axis)
+ return semantic.gather(src, index, axis, _builder)
+
+
  # -----------------------
  # Compiler Hint Ops
  # -----------------------
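A minimal kernel sketch of the new `gather` builtin (illustrative, not from this diff; `idx` must be an integer tensor with the same rank as `src`, and the pointers and block size are hypothetical):

    @triton.jit
    def _gather_rows(src_ptr, idx_ptr, out_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        src = tl.load(src_ptr + offs[:, None] * BLOCK + offs[None, :])
        idx = tl.load(idx_ptr + offs[:, None] * BLOCK + offs[None, :])
        out = tl.gather(src, idx, axis=0)        # equivalently: src.gather(idx, axis=0)
        tl.store(out_ptr + offs[:, None] * BLOCK + offs[None, :], out)
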
@@ -2565,9 +2931,15 @@ class range:
  :param loop_unroll_factor: Tells the Triton IR level loop unroller how many
  times to unroll a for loop that this range is used with. Less than 2 for
  this value implies no unrolling.
+ :param disallow_acc_multi_buffer: If true, prevent the accumulator of the dot
+ operation in the loop to be multi-buffered, if applicable.
+ :param flatten: automatically flatten the loop nest starting at this loop to
+ create a single flattened loop. The compiler will try to pipeline the
+ flattened loop which can avoid stage stalling.
  """

- def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None):
+ def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None,
+ disallow_acc_multi_buffer=False, flatten=False):
  if step is None:
  self.step = constexpr(1)
  else:
@@ -2580,6 +2952,8 @@ class range:
  self.end = arg2
  self.num_stages = num_stages
  self.loop_unroll_factor = loop_unroll_factor
+ self.disallow_acc_multi_buffer = disallow_acc_multi_buffer
+ self.flatten = flatten

  def __iter__(self):
  raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
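A short sketch showing the new keyword arguments on `tl.range` inside a pipelined loop (illustrative, not from this diff; the kernel, pointers, and staging values are hypothetical, and `disallow_acc_multi_buffer` only has an effect when the loop carries a dot accumulator):

    @triton.jit
    def _sum_chunks(x_ptr, out_ptr, K, BLOCK: tl.constexpr):
        acc = tl.zeros((BLOCK,), dtype=tl.float32)
        for k in tl.range(0, K, BLOCK, num_stages=3, disallow_acc_multi_buffer=True, flatten=False):
            acc += tl.load(x_ptr + k + tl.arange(0, BLOCK))
        tl.store(out_ptr + tl.arange(0, BLOCK), acc)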