triton-windows 3.3.1.post19__cp312-cp312-win_amd64.whl → 3.5.0.post21__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: the registry flags this version of triton-windows as potentially problematic.

Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/runtime/driver.py CHANGED
@@ -1,60 +1,38 @@
-from ..backends import backends
-from ..backends import DriverBase
+from __future__ import annotations

+from ..backends import backends, DriverBase

-def _create_driver():
-    actives = [x.driver for x in backends.values() if x.driver.is_active()]
-    if len(actives) != 1:
-        raise RuntimeError(f"{len(actives)} active drivers ({actives}). There should only be one.")
-    return actives[0]()

+def _create_driver() -> DriverBase:
+    active_drivers = [x.driver for x in backends.values() if x.driver.is_active()]
+    if len(active_drivers) != 1:
+        raise RuntimeError(f"{len(active_drivers)} active drivers ({active_drivers}). There should only be one.")
+    return active_drivers[0]()

-class LazyProxy:
-
-    def __init__(self, init_fn):
-        self._init_fn = init_fn
-        self._obj = None
-
-    def _initialize_obj(self):
-        if self._obj is None:
-            self._obj = self._init_fn()
-
-    def __getattr__(self, name):
-        self._initialize_obj()
-        return getattr(self._obj, name)
-
-    def __setattr__(self, name, value):
-        if name in ["_init_fn", "_obj"]:
-            super().__setattr__(name, value)
-        else:
-            self._initialize_obj()
-            setattr(self._obj, name, value)
-
-    def __delattr__(self, name):
-        self._initialize_obj()
-        delattr(self._obj, name)
-
-    def __repr__(self):
-        if self._obj is None:
-            return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
-        return repr(self._obj)
-
-    def __str__(self):
-        self._initialize_obj()
-        return str(self._obj)
+class DriverConfig:

+    def __init__(self) -> None:
+        self._default: DriverBase | None = None
+        self._active: DriverBase | None = None

-class DriverConfig:
+    @property
+    def default(self) -> DriverBase:
+        if self._default is None:
+            self._default = _create_driver()
+        return self._default

-    def __init__(self):
-        self.default = LazyProxy(_create_driver)
-        self.active = self.default
+    @property
+    def active(self) -> DriverBase:
+        if self._active is None:
+            self._active = self.default
+        return self._active

-    def set_active(self, driver: DriverBase):
-        self.active = driver
+    def set_active(self, driver: DriverBase) -> None:
+        self._active = driver

-    def reset_active(self):
-        self.active = self.default
+    def reset_active(self) -> None:
+        self._active = self.default


 driver = DriverConfig()
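
Note on the hunk above: the generic LazyProxy wrapper is gone, and DriverConfig now lazy-initializes through plain properties, so `driver.active` is a real DriverBase instance rather than a proxy object. A minimal sketch of the same pattern, with illustrative stand-in names (not the triton API):

from typing import Optional


class Driver:
    """Illustrative stand-in for a concrete DriverBase."""

    def __init__(self) -> None:
        print("expensive driver init runs exactly once")


class Config:
    """Same lazy-property shape as the new DriverConfig."""

    def __init__(self) -> None:
        self._default: Optional[Driver] = None
        self._active: Optional[Driver] = None

    @property
    def default(self) -> Driver:
        if self._default is None:  # construct on first access, then cache
            self._default = Driver()
        return self._default

    @property
    def active(self) -> Driver:
        if self._active is None:  # fall back to the default driver
            self._active = self.default
        return self._active

    def set_active(self, drv: Driver) -> None:
        self._active = drv


cfg = Config()                     # nothing is initialized yet
assert cfg.active is cfg.default   # first touch runs the one-time init

Unlike the proxy, attribute access and isinstance checks now hit the concrete driver directly, and construction still happens only on first use.
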
triton/runtime/interpreter.py CHANGED
@@ -1,32 +1,36 @@
+from __future__ import annotations
 import ast
 import textwrap
 import inspect
-from typing import Tuple, List
+from typing import Tuple, List, Dict, Callable

 import math
 import numpy as np

 import triton
 import triton.language as tl
+import dataclasses
 from dataclasses import dataclass
+
+from triton.language.semantic import TritonSemantic
+from triton.tools.tensor_descriptor import TensorDescriptor
 from .errors import InterpreterError
 from functools import partial
 from .._C.libtriton import interpreter as _interpreter
 from .._C.libtriton import ir as _ir


+@dataclass
 class TensorHandle:
-
-    def __init__(self, data, dtype):
-        '''
-        data: numpy array
-        dtype: triton type, either pointer_type or scalar_type.
-        we don't store block_type here because the shape information is already available in the data field
-        attr: a dictionary of attributes
-        '''
-        self.data = data
-        self.dtype = dtype
-        self.attr = {}
+    '''
+    data: numpy array
+    dtype: triton type, either pointer_type or scalar_type.
+    we don't store block_type here because the shape information is already available in the data field
+    attr: a dictionary of attributes
+    '''
+    data: np.array
+    dtype: tl.dtype
+    attr: Dict = dataclasses.field(default_factory=dict)

     def __bool__(self):
         return bool(self.data.all())
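
TensorHandle becomes a @dataclass above; the only subtlety is the mutable `attr` default, which needs `dataclasses.field(default_factory=dict)`. A quick illustration of why (the Handle class here is a hypothetical stand-in):

import dataclasses
from dataclasses import dataclass


@dataclass
class Handle:
    """Hypothetical stand-in for TensorHandle."""
    data: object
    dtype: str
    # A bare `attr: dict = {}` is rejected by @dataclass precisely because
    # the one dict would be shared by every instance; default_factory
    # builds a fresh dict per instance instead.
    attr: dict = dataclasses.field(default_factory=dict)


a, b = Handle(1, "int32"), Handle(2, "float32")
a.attr["x"] = 1
assert b.attr == {}  # each handle owns its own attribute dict
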
@@ -73,17 +77,19 @@ class BlockPointerHandle:
 class TensorDescHandle:

     def __init__(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
-                 block_shape: List[int]):
+                 block_shape: List[int], padding):
         self.base = base
         self.ndim = len(shape)
         self.shape = shape
         self.strides = strides
         self.block_shape = block_shape
+        self.padding = padding

     def validate(self):
         assert self.base.data.item() % 16 == 0, "base must be 16-byte aligned"
         assert len(self.strides) == self.ndim
         assert len(self.block_shape) == self.ndim
+        assert self.ndim >= 1, "descriptor cannot be 0 dimensional"

         for stride in self.strides[:-1]:
             assert stride.data.item() % 16 == 0, "stride must be 16-byte aligned"
@@ -103,6 +109,7 @@ class TensorDescHandle:
             off = (offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
             ptrs = ptrs + (itemsize * off * self.strides[dim].data).astype(np.uint64)
             masks = masks & (0 <= off) & (off < self.shape[dim].data)
+        assert ptrs.dtype == np.uint64
         ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
         return ptrs, masks
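
For readers following materialize_pointers: each dimension contributes `itemsize * offset * stride` to a broadcast grid of pointers, and the mask tracks which offsets stay inside the logical shape. A toy 2-D rendition with made-up numbers, not the triton implementation:

import numpy as np

# Made-up descriptor: base address, logical shape, strides (in elements),
# block shape, and a per-dimension block offset.
base, itemsize = np.uint64(1024), 4
shape, strides, block = (5, 7), (7, 1), (4, 4)
offsets = (2, 4)

ptrs = np.full((1, 1), base, dtype=np.uint64)
masks = np.full((1, 1), True)
for dim in range(2):
    bcast_dims = [1, 1]
    bcast_dims[dim] = block[dim]
    off = (offsets[dim] + np.arange(block[dim])).reshape(bcast_dims)
    ptrs = ptrs + (itemsize * off * strides[dim]).astype(np.uint64)
    masks = masks & (0 <= off) & (off < shape[dim])

assert ptrs.dtype == np.uint64  # the invariant the new assert checks
print(masks)  # lanes past the (5, 7) bounds are masked off
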
@@ -114,7 +121,7 @@ class InterpreterOptions:
     sanitize_overflow: bool = True
     arch: str = None
     supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15")
-    deprecated_fp8_dtypes: Tuple[str] = ()
+    deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "tf32"
     allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
     max_num_imprecise_acc_default: int = 0
@@ -248,8 +255,8 @@ np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64])
 class ExtraFunctions:

     @staticmethod
-    def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _builder):
-        return tl.tensor(_builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty)
+    def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _semantic):
+        return tl.tensor(_semantic.builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty)


 class InterpreterBuilder:
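
The recurring `_builder` to `_semantic` renames in this file track one refactor: builtins now receive a semantic-layer object and reach the raw IR builder through its `.builder` attribute. A schematic of that calling-convention change, using dummy classes rather than the real triton types:

class DummyBuilder:
    """Stand-in for the interpreter's IR builder."""

    def create_fp_to_fp(self, handle, dst_ty):
        return ("fp_to_fp", handle, dst_ty)


class Semantic:
    """Stand-in for triton.language.semantic.TritonSemantic."""

    def __init__(self, builder):
        self.builder = builder


def convert_v1(x, dst_ty, _builder):   # old convention: raw builder
    return _builder.create_fp_to_fp(x, dst_ty)


def convert_v2(x, dst_ty, _semantic):  # new convention: semantic layer
    return _semantic.builder.create_fp_to_fp(x, dst_ty)


b = DummyBuilder()
assert convert_v1(1, "f16", _builder=b) == convert_v2(1, "f16", _semantic=Semantic(b))
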
@@ -306,6 +313,9 @@ class InterpreterBuilder:
     def get_double_ty(self):
         return tl.float64

+    def get_int1_ty(self):
+        return tl.int1
+
     def get_int8_ty(self):
         return tl.int8
@@ -587,11 +597,18 @@
         b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16)
         return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar)

-    def create_make_range(self, start, stop):
+    def create_make_range(self, ret_ty, start, stop):
         return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32)

-    def create_histogram(self, data, bins):
-        return TensorHandle(np.histogram(data.data, bins=bins, range=(0, bins))[0], tl.int32)
+    def create_histogram(self, data, bins, mask):
+        if mask is None:
+            mask = TensorHandle(np.ones_like(data.data, dtype=bool), tl.int1)
+        # force all masked elements to zero
+        data = np.where(mask.data, data.data, np.zeros_like(data.data))
+        histogram = np.histogram(data, bins=bins, range=(0, bins))[0]
+        # remove overcounted elements
+        histogram[0] -= np.logical_not(mask.data).sum()
+        return TensorHandle(histogram, tl.int32)

     def create_gather(self, src, indices, axis):
         return TensorHandle(np.take_along_axis(src.data, indices.data, axis=axis), src.dtype.scalar)
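
The new masked histogram above works by zeroing the masked-out elements, which funnels them all into bin 0, and then subtracting exactly that many counts back out of bin 0. A worked numpy example with hypothetical data:

import numpy as np

data = np.array([0, 1, 1, 3, 2])
mask = np.array([True, True, False, True, False])
bins = 4

# Zeroing the masked-out elements funnels them all into bin 0...
zeroed = np.where(mask, data, 0)
hist = np.histogram(zeroed, bins=bins, range=(0, bins))[0]
# ...so exactly that many spurious counts are subtracted back out.
hist[0] -= np.logical_not(mask).sum()

expected = np.histogram(data[mask], bins=bins, range=(0, bins))[0]
assert (hist == expected).all()  # both give [1, 1, 0, 1]
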
@@ -641,12 +658,16 @@
         # Triton only supports splitting the original tensor into two along the last axis
         return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar))

-    def create_splat(self, arg, shape):
+    def create_splat(self, ret_ty, arg):
+        shape = ret_ty.shape
         if isinstance(arg.dtype, tl.block_type):
             return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
         else:  # scalar
             return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)

+    def create_unsplat(self, arg):
+        return TensorHandle(np.full((1, ), arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
+
     def create_atomic_cas(self, ptr, cmp, val, sem, scope):
         if sem not in self.ir_sem_to_interpreter_sem:
             raise ValueError(f"unsupported semantic {sem}")
@@ -709,14 +730,9 @@
             ret.offsets[i].data += offsets[i].data
         return ret

-    def create_make_tensor_descriptor(
-        self,
-        base: TensorHandle,
-        shape: List[TensorHandle],
-        strides: List[TensorHandle],
-        tensor_shape: List[int],
-    ):
-        desc = TensorDescHandle(base, shape, strides, tensor_shape)
+    def create_make_tensor_descriptor(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
+                                      tensor_shape: List[int], is_signed: bool, padding: str = "zero"):
+        desc = TensorDescHandle(base, shape, strides, tensor_shape, padding)
         desc.validate()
         return desc

@@ -724,7 +740,16 @@
                                eviction_policy):
         assert isinstance(desc, TensorDescHandle)
         ptrs, mask = desc.materialize_pointers(indices)
-        return self.create_masked_load(ptrs, mask, other=None, cache_modifier=cache_modifier,
+        dtype_tt = ptrs.get_element_ty()
+        dtype_np = _get_np_dtype(dtype_tt)
+        padding = desc.padding
+        if padding == _ir.PADDING_OPTION.PAD_ZERO:
+            other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt)
+        elif padding == _ir.PADDING_OPTION.PAD_NAN:
+            other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt)
+        else:
+            raise ValueError(f"unsupported padding {padding}")
+        return self.create_masked_load(ptrs, mask, other, cache_modifier=cache_modifier,
                                        eviction_policy=eviction_policy, is_volatile=False)

     def create_descriptor_store(self, desc: TensorDescHandle, value: TensorHandle, indices: List[TensorHandle]):
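
Descriptor loads now honor the padding mode: out-of-bounds lanes are filled with zeros or NaNs instead of an unspecified `other=None`. A standalone sketch of the selection logic, with plain strings standing in for the `_ir.PADDING_OPTION` enum values:

import numpy as np

values = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
in_bounds = np.array([True, True, False, False])


def padded_load(values, mask, padding):
    # Pick the fill tensor from the padding mode, then apply the mask.
    if padding == "zero":
        other = np.zeros_like(values)
    elif padding == "nan":
        other = np.full_like(values, np.nan)
    else:
        raise ValueError(f"unsupported padding {padding}")
    return np.where(mask, values, other)


print(padded_load(values, in_bounds, "zero"))  # [1. 2. 0. 0.]
print(padded_load(values, in_bounds, "nan"))   # [1. 2. nan nan]
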
@@ -753,15 +778,18 @@
     np_type = _get_np_dtype(type)
     if "int" in np_type.name:
         return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar)
+    elif np_type == np.bool_:
+        return TensorHandle(np.full(1, True, dtype=np_type), type.scalar)
     else:
         raise TypeError(f"unsupported type {type}")


 def _patch_attr(obj, name, member, builder):
+    semantic = TritonSemantic(builder)
     new_member = lambda *args, member=member, **kwargs: (member(*args, **
                                                                {k: v
                                                                 for k, v in kwargs.items()
-                                                                if k != "_builder"}, _builder=builder))
+                                                                if k != "_semantic"}, _semantic=semantic))
     setattr(obj, name, new_member)


@@ -822,12 +850,10 @@ class ReduceScanOpInterface:

     def apply(self, input):
         if not isinstance(input, tuple):
-            input = (input, )
+            return self.apply((input, ))[0]
         self.check_tensor(input)
-        return self.apply_impl(input)
-
-    def apply_impl(self, input):
-        raise NotImplementedError("apply_impl not implemented")
+        ret = self.apply_impl(input)
+        return tuple(ret) if isinstance(ret, (list, tuple)) else (ret, )


 class ReduceOps(ReduceScanOpInterface):
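
The reworked `apply` normalizes everything to tuples: a scalar argument is routed through the tuple path once and unwrapped on the way out, so `apply_impl` implementations only ever see and return sequences. The shape of that pattern, with a stand-in `impl` instead of the real reduce machinery:

def apply(value, impl):
    """Scalars take the tuple path once, then unwrap on the way out."""
    if not isinstance(value, tuple):
        return apply((value, ), impl)[0]
    ret = impl(value)
    return tuple(ret) if isinstance(ret, (list, tuple)) else (ret, )


double_all = lambda xs: [x * 2 for x in xs]
assert apply(3, double_all) == 6            # scalar in, scalar out
assert apply((3, 4), double_all) == (6, 8)  # tuple in, tuple out
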
@@ -887,7 +913,7 @@ class ReduceOps(ReduceScanOpInterface):
                 # Take a scalar
                 data = data.item()
             ret.append(self.to_tensor(data, input[i].dtype))
-        return ret[0] if len(ret) == 1 else tuple(ret)
+        return ret

     def min_max(self, input, val_reduce_op, idx_reduce_op=None):
         # If input is a tuple, it must be (val, index), and we only take val
@@ -916,9 +942,9 @@
         elif self.combine_fn == tl.standard._argmax_combine_tie_break_left:
             return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax)
         elif self.combine_fn == tl.standard._elementwise_max:
-            return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=None)
+            return self.min_max(input[0], val_reduce_op=np.nanmax, idx_reduce_op=None)
         elif self.combine_fn == tl.standard._elementwise_min:
-            return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=None)
+            return self.min_max(input[0], val_reduce_op=np.nanmin, idx_reduce_op=None)
         elif self.combine_fn == tl.standard._sum_combine:
             return self.sum(input[0])
         else:
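
The switch to np.nanmax/np.nanmin changes how NaN elements reduce in the interpreter's elementwise min/max path. The numpy behavior it relies on:

import numpy as np

x = np.array([1.0, np.nan, 3.0])
print(np.max(x))     # nan  (a single NaN poisons the plain reduction)
print(np.nanmax(x))  # 3.0  (NaN entries are skipped)
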
@@ -985,7 +1011,7 @@ class ScanOps(ReduceScanOpInterface):
         if self.reverse:
             for arg in ret:
                 arg.handle.data = np.flip(arg.handle.data, axis=self.axis)
-        return len(ret) == 1 and ret[0] or tuple(ret)
+        return ret


 def _patch_reduce_scan():
@@ -1092,7 +1118,7 @@ def _patch_lang(fn):
     _patch_builtin(lang.math, interpreter_builder)
     _patch_lang_tensor(lang.tensor)
     _patch_lang_core(lang)
-    _patch_builtin(tl.core._experimental_tensor_descriptor_base, interpreter_builder)
+    _patch_builtin(tl.core.tensor_descriptor_base, interpreter_builder)


 def _tuple_create(arg, contents):
@@ -1107,7 +1133,7 @@ def _tuple_create(arg, contents):
 # TODO: wrap everything in triton tensors
 def _implicit_cvt(arg):
     if isinstance(arg, int):
-        ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg))
+        ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None)
         dtype = np.int32
         if -2**31 <= arg < 2**31:
             dtype = np.int32
@@ -1122,15 +1148,25 @@ def _implicit_cvt(arg):
         handle = TensorHandle(np.array([arg], dtype=dtype), ty)
         return tl.tensor(handle, ty)
     if hasattr(arg, "data_ptr"):
-        ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg))
+        ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None)
         handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
         return tl.tensor(handle, ty)
     elif isinstance(arg, tuple):
         return _tuple_create(arg, map(_implicit_cvt, arg))
+    elif isinstance(arg, TensorDescriptor):
+        strides = [_implicit_cvt(s) for s in arg.strides]
+        assert arg.strides[-1] == 1
+        strides[-1] = tl.constexpr(1)
+        semantic = TritonSemantic(InterpreterBuilder())
+        return semantic.make_tensor_descriptor(base=_implicit_cvt(arg.base),
+                                               shape=[_implicit_cvt(s) for s in arg.shape], strides=strides,
+                                               block_shape=[tl.constexpr(b)
+                                                            for b in arg.block_shape], padding_option=arg.padding)
     return arg


 interpreter_builder = InterpreterBuilder()
+interpreter_semantic = TritonSemantic(interpreter_builder)


 def _unwrap_tensor(t):
@@ -1162,6 +1198,14 @@ class GridExecutor:
     def _to_cpu(arg):
         if isinstance(arg, tuple):
             return _tuple_create(arg, map(_to_cpu, arg))
+        elif isinstance(arg, TensorDescriptor):
+            return TensorDescriptor(
+                _to_cpu(arg.base),
+                arg.shape,
+                arg.strides,
+                arg.block_shape,
+                arg.padding,
+            )
         elif not hasattr(arg, "data_ptr"):
             return arg

@@ -1195,6 +1239,8 @@ class GridExecutor:
         elif isinstance(arg_dev, tuple):
             for (arg_dev, arg_hst) in zip(arg_dev, arg_hst):
                 _from_cpu(arg_dev, arg_hst)
+        elif isinstance(arg_dev, TensorDescriptor):
+            _from_cpu(arg_dev.base, arg_hst.base)

         for arg_dev, arg_hst in zip(args_dev, args_hst):
             _from_cpu(arg_dev, arg_hst)
@@ -1235,6 +1281,8 @@ class GridExecutor:
             interpreter_builder.set_grid_idx(x, y, z)
             self.fn(**args)
         except Exception as e:
+            if triton.knobs.compilation.front_end_debugging:
+                raise
             raise InterpreterError(repr(e)) from e
         # copy arguments back to propagate side-effects
         self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)
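
With the new knob, interpreter failures can surface the original traceback instead of the wrapped InterpreterError. The control flow reduced to its essentials (the module-level flag stands in for `triton.knobs.compilation.front_end_debugging`):

class InterpreterError(Exception):
    pass


FRONT_END_DEBUGGING = False  # stand-in for the triton knob


def run(fn):
    try:
        fn()
    except Exception as e:
        if FRONT_END_DEBUGGING:
            raise  # surface the original traceback unchanged
        raise InterpreterError(repr(e)) from e


try:
    run(lambda: 1 / 0)
except InterpreterError as err:
    print(err)  # ZeroDivisionError('division by zero'), wrapped
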
@@ -1249,14 +1297,10 @@ class ASTTransformer(ast.NodeTransformer):
         if len(names) > 1:
             raise ValueError("Multiple assignments are not supported")
         # Modify the assignment x = value to
-        # triton.language.semantic.to_tensor(value, interpreter_builder, False)
+        # interpreter_semantic.to_tensor(value, False)
         node.value = ast.Call(
-            func=ast.Attribute(
-                value=ast.Attribute(
-                    value=ast.Attribute(value=ast.Name(id='triton', ctx=ast.Load()), attr='language', ctx=ast.Load()),
-                    attr='semantic', ctx=ast.Load()), attr='to_tensor', ctx=ast.Load()),
-            args=[node.value, ast.Name(id='interpreter_builder', ctx=ast.Load()),
-                  ast.Constant(value=False)], keywords=[])
+            func=ast.Attribute(value=ast.Name(id="interpreter_semantic", ctx=ast.Load()), attr="to_tensor",
+                               ctx=ast.Load()), args=[node.value, ast.Constant(value=False)], keywords=[])
         return node

@@ -1331,11 +1375,12 @@ class FunctionRewriter:

 class InterpretedFunction:
     # Cache all rewritten functions
-    rewritten_fn = {}
+    rewritten_fn: Dict[Callable, Callable] = {}

     def __init__(self, fn, **kwargs) -> None:
         self.fn = fn
         self.rewriter = FunctionRewriter(fn, **kwargs)
+        self.kwargs = kwargs

     def run(*args, **kwargs):
         grid = kwargs["grid"]