triton-windows 3.4.0.post20-cp311-cp311-win_amd64.whl → 3.5.0.post21-cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows has been flagged by the registry scanner.

Files changed (107)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +8 -2
  3. triton/_filecheck.py +24 -14
  4. triton/_internal_testing.py +70 -4
  5. triton/_utils.py +3 -1
  6. triton/backends/amd/compiler.py +68 -60
  7. triton/backends/amd/driver.c +113 -44
  8. triton/backends/amd/driver.py +133 -57
  9. triton/backends/driver.py +13 -0
  10. triton/backends/nvidia/compiler.py +80 -22
  11. triton/backends/nvidia/driver.c +88 -15
  12. triton/backends/nvidia/driver.py +130 -123
  13. triton/compiler/__init__.py +5 -2
  14. triton/compiler/code_generator.py +270 -163
  15. triton/compiler/compiler.py +45 -62
  16. triton/experimental/gluon/__init__.py +3 -2
  17. triton/experimental/gluon/_runtime.py +9 -6
  18. triton/experimental/gluon/language/__init__.py +117 -16
  19. triton/experimental/gluon/language/_core.py +246 -68
  20. triton/experimental/gluon/language/_layouts.py +398 -45
  21. triton/experimental/gluon/language/_math.py +17 -9
  22. triton/experimental/gluon/language/_semantic.py +130 -37
  23. triton/experimental/gluon/language/_standard.py +55 -22
  24. triton/experimental/gluon/language/amd/__init__.py +4 -0
  25. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  26. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  27. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  28. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  29. triton/experimental/gluon/language/extra/__init__.py +3 -0
  30. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  31. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  32. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  33. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
  34. triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
  35. triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
  36. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
  37. triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
  38. triton/experimental/gluon/nvidia/hopper.py +6 -1
  39. triton/knobs.py +132 -67
  40. triton/language/__init__.py +16 -10
  41. triton/language/core.py +163 -83
  42. triton/language/extra/cuda/gdc.py +6 -6
  43. triton/language/extra/hip/__init__.py +3 -1
  44. triton/language/extra/hip/libdevice.py +7 -0
  45. triton/language/extra/hip/utils.py +35 -0
  46. triton/language/extra/libdevice.py +4 -0
  47. triton/language/semantic.py +76 -23
  48. triton/language/standard.py +14 -14
  49. triton/language/target_info.py +54 -0
  50. triton/runtime/_allocation.py +15 -3
  51. triton/runtime/_async_compile.py +55 -0
  52. triton/runtime/autotuner.py +4 -5
  53. triton/runtime/build.py +11 -9
  54. triton/runtime/cache.py +44 -1
  55. triton/runtime/driver.py +16 -41
  56. triton/runtime/interpreter.py +31 -23
  57. triton/runtime/jit.py +318 -157
  58. triton/runtime/tcc/include/_mingw.h +8 -10
  59. triton/runtime/tcc/include/assert.h +5 -0
  60. triton/runtime/tcc/include/errno.h +1 -1
  61. triton/runtime/tcc/include/float.h +21 -3
  62. triton/runtime/tcc/include/iso646.h +36 -0
  63. triton/runtime/tcc/include/limits.h +5 -0
  64. triton/runtime/tcc/include/malloc.h +2 -2
  65. triton/runtime/tcc/include/math.h +21 -261
  66. triton/runtime/tcc/include/stdalign.h +16 -0
  67. triton/runtime/tcc/include/stdarg.h +5 -70
  68. triton/runtime/tcc/include/stdatomic.h +171 -0
  69. triton/runtime/tcc/include/stddef.h +7 -19
  70. triton/runtime/tcc/include/stdlib.h +15 -4
  71. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  72. triton/runtime/tcc/include/sys/stat.h +2 -2
  73. triton/runtime/tcc/include/sys/types.h +5 -0
  74. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  75. triton/runtime/tcc/include/tccdefs.h +342 -0
  76. triton/runtime/tcc/include/tgmath.h +89 -0
  77. triton/runtime/tcc/include/uchar.h +33 -0
  78. triton/runtime/tcc/include/unistd.h +1 -0
  79. triton/runtime/tcc/include/winapi/qos.h +72 -0
  80. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  81. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  82. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  83. triton/runtime/tcc/include/winapi/windows.h +1 -1
  84. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  85. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  86. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  87. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  88. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  89. triton/runtime/tcc/lib/libtcc1.a +0 -0
  90. triton/runtime/tcc/lib/python314.def +1800 -0
  91. triton/runtime/tcc/lib/python314t.def +1809 -0
  92. triton/runtime/tcc/libtcc.dll +0 -0
  93. triton/runtime/tcc/tcc.exe +0 -0
  94. triton/tools/compile.py +62 -14
  95. triton/tools/extra/cuda/compile.c +1 -0
  96. triton/tools/extra/hip/compile.cpp +66 -0
  97. triton/tools/extra/hip/compile.h +13 -0
  98. triton/tools/ragged_tma.py +92 -0
  99. triton/tools/tensor_descriptor.py +7 -9
  100. triton/windows_utils.py +42 -79
  101. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
  102. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
  103. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  104. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
  105. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
  106. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
  107. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/runtime/interpreter.py

@@ -2,7 +2,7 @@ from __future__ import annotations
 import ast
 import textwrap
 import inspect
-from typing import Tuple, List, Dict
+from typing import Tuple, List, Dict, Callable
 
 import math
 import numpy as np
@@ -77,17 +77,19 @@ class BlockPointerHandle:
 class TensorDescHandle:
 
     def __init__(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
-                 block_shape: List[int]):
+                 block_shape: List[int], padding):
         self.base = base
         self.ndim = len(shape)
         self.shape = shape
         self.strides = strides
         self.block_shape = block_shape
+        self.padding = padding
 
     def validate(self):
         assert self.base.data.item() % 16 == 0, "base must be 16-byte aligned"
         assert len(self.strides) == self.ndim
         assert len(self.block_shape) == self.ndim
+        assert self.ndim >= 1, "descriptor cannot be 0 dimensional"
 
         for stride in self.strides[:-1]:
             assert stride.data.item() % 16 == 0, "stride must be 16-byte aligned"
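
For reference, here is the validation logic above distilled into a stand-alone sketch, using plain integers instead of TensorHandle (validate_descriptor and its arguments are hypothetical, not part of the package):

def validate_descriptor(base_addr, strides, block_shape):
    # Mirrors TensorDescHandle.validate(): at least one dimension, matching
    # ranks, and 16-byte alignment on the base and all but the innermost stride.
    ndim = len(strides)
    assert ndim >= 1, "descriptor cannot be 0 dimensional"
    assert len(block_shape) == ndim
    assert base_addr % 16 == 0, "base must be 16-byte aligned"
    for stride in strides[:-1]:
        assert stride % 16 == 0, "stride must be 16-byte aligned"

validate_descriptor(base_addr=1024, strides=[64, 1], block_shape=[16, 16])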
@@ -663,6 +665,9 @@ class InterpreterBuilder:
         else:  # scalar
             return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
 
+    def create_unsplat(self, arg):
+        return TensorHandle(np.full((1, ), arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
+
     def create_atomic_cas(self, ptr, cmp, val, sem, scope):
         if sem not in self.ir_sem_to_interpreter_sem:
             raise ValueError(f"unsupported semantic {sem}")
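
The new create_unsplat collapses a splatted tensor (one repeated value) back into a single-element handle. In plain NumPy terms (values here are illustrative):

import numpy as np

splatted = np.full((8,), 7, dtype=np.int32)                 # result of a prior splat
unsplat = np.full((1,), splatted[0], dtype=splatted.dtype)  # keep just the repeated value
print(unsplat)  # [7]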
@@ -725,15 +730,9 @@ class InterpreterBuilder:
             ret.offsets[i].data += offsets[i].data
         return ret
 
-    def create_make_tensor_descriptor(
-        self,
-        base: TensorHandle,
-        shape: List[TensorHandle],
-        strides: List[TensorHandle],
-        tensor_shape: List[int],
-        is_signed: bool,
-    ):
-        desc = TensorDescHandle(base, shape, strides, tensor_shape)
+    def create_make_tensor_descriptor(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
+                                      tensor_shape: List[int], is_signed: bool, padding: str = "zero"):
+        desc = TensorDescHandle(base, shape, strides, tensor_shape, padding)
         desc.validate()
         return desc
 
@@ -741,7 +740,16 @@ class InterpreterBuilder:
                                eviction_policy):
         assert isinstance(desc, TensorDescHandle)
         ptrs, mask = desc.materialize_pointers(indices)
-        return self.create_masked_load(ptrs, mask, other=None, cache_modifier=cache_modifier,
+        dtype_tt = ptrs.get_element_ty()
+        dtype_np = _get_np_dtype(dtype_tt)
+        padding = desc.padding
+        if padding == _ir.PADDING_OPTION.PAD_ZERO:
+            other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt)
+        elif padding == _ir.PADDING_OPTION.PAD_NAN:
+            other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt)
+        else:
+            raise ValueError(f"unsupported padding {padding}")
+        return self.create_masked_load(ptrs, mask, other, cache_modifier=cache_modifier,
                                        eviction_policy=eviction_policy, is_volatile=False)
 
     def create_descriptor_store(self, desc: TensorDescHandle, value: TensorHandle, indices: List[TensorHandle]):
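
The padding option decides what masked-out (out-of-bounds) lanes of a descriptor load read as. A minimal NumPy sketch of the two fill modes, with illustrative mask/data values:

import numpy as np

mask = np.array([True, True, False, False])
data = np.array([1.0, 2.0, 9.0, 9.0], dtype=np.float32)

other_zero = np.zeros_like(data)              # PAD_ZERO fill
other_nan = np.full_like(data, float("nan"))  # PAD_NAN fill

print(np.where(mask, data, other_zero))  # [1. 2. 0. 0.]
print(np.where(mask, data, other_nan))   # [ 1.  2. nan nan]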
@@ -934,9 +942,9 @@ class ReduceOps(ReduceScanOpInterface):
         elif self.combine_fn == tl.standard._argmax_combine_tie_break_left:
             return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax)
         elif self.combine_fn == tl.standard._elementwise_max:
-            return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=None)
+            return self.min_max(input[0], val_reduce_op=np.nanmax, idx_reduce_op=None)
         elif self.combine_fn == tl.standard._elementwise_min:
-            return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=None)
+            return self.min_max(input[0], val_reduce_op=np.nanmin, idx_reduce_op=None)
         elif self.combine_fn == tl.standard._sum_combine:
             return self.sum(input[0])
         else:
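
Switching from np.max/np.min to np.nanmax/np.nanmin changes NaN handling in the interpreter's reductions: the nan-variants ignore NaN elements instead of propagating them. For example:

import numpy as np

x = np.array([1.0, float("nan"), 3.0])
print(np.max(x))     # nan  (NaN propagates)
print(np.nanmax(x))  # 3.0  (NaN ignored)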
@@ -1125,7 +1133,7 @@ def _tuple_create(arg, contents):
 # TODO: wrap everything in triton tensors
 def _implicit_cvt(arg):
     if isinstance(arg, int):
-        ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg))
+        ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None)
         dtype = np.int32
         if -2**31 <= arg < 2**31:
             dtype = np.int32
@@ -1140,7 +1148,7 @@ def _implicit_cvt(arg):
         handle = TensorHandle(np.array([arg], dtype=dtype), ty)
         return tl.tensor(handle, ty)
     if hasattr(arg, "data_ptr"):
-        ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg))
+        ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None)
         handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
         return tl.tensor(handle, ty)
     elif isinstance(arg, tuple):
@@ -1150,12 +1158,10 @@ def _implicit_cvt(arg):
         assert arg.strides[-1] == 1
         strides[-1] = tl.constexpr(1)
         semantic = TritonSemantic(InterpreterBuilder())
-        return semantic.make_tensor_descriptor(
-            base=_implicit_cvt(arg.base),
-            shape=[_implicit_cvt(s) for s in arg.shape],
-            strides=strides,
-            block_shape=[tl.constexpr(b) for b in arg.block_shape],
-        )
+        return semantic.make_tensor_descriptor(base=_implicit_cvt(arg.base),
+                                               shape=[_implicit_cvt(s) for s in arg.shape], strides=strides,
+                                               block_shape=[tl.constexpr(b)
+                                                            for b in arg.block_shape], padding_option=arg.padding)
     return arg
 
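_implicit_cvt now also forwards arg.padding, so the host-side descriptor object must expose that attribute. A hedged sketch of the shape it expects (this dataclass is illustrative only; the real class lives in triton/tools/tensor_descriptor.py):

from dataclasses import dataclass
from typing import Any, List

@dataclass
class HostTensorDescriptor:
    base: Any                # tensor or pointer-like object exposing data_ptr()
    shape: List[int]
    strides: List[int]
    block_shape: List[int]
    padding: str = "zero"    # new attribute read by _implicit_cvt; "zero" or "nan"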
@@ -1198,6 +1204,7 @@ class GridExecutor:
                 arg.shape,
                 arg.strides,
                 arg.block_shape,
+                arg.padding,
             )
         elif not hasattr(arg, "data_ptr"):
             return arg
@@ -1368,11 +1375,12 @@ class FunctionRewriter:
 
 class InterpretedFunction:
     # Cache all rewritten functions
-    rewritten_fn = {}
+    rewritten_fn: Dict[Callable, Callable] = {}
 
     def __init__(self, fn, **kwargs) -> None:
         self.fn = fn
         self.rewriter = FunctionRewriter(fn, **kwargs)
+        self.kwargs = kwargs
 
     def run(*args, **kwargs):
         grid = kwargs["grid"]
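
rewritten_fn is a class-level memoization cache keyed on the original function object, which the new annotation now spells out. The same pattern in isolation (rewrite() here is a stand-in for what FunctionRewriter produces, not the package's API):

from typing import Callable, Dict

_rewritten: Dict[Callable, Callable] = {}

def rewrite(fn: Callable) -> Callable:
    # Placeholder rewrite step; the real code rewrites the function's AST.
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapped

def get_rewritten(fn: Callable) -> Callable:
    # Rewrite each function at most once, then reuse the cached result.
    if fn not in _rewritten:
        _rewritten[fn] = rewrite(fn)
    return _rewritten[fn]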