triton-windows 3.3.1.post19__cp39-cp39-win_amd64.whl → 3.4.0.post20__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of triton-windows might be problematic.

Files changed (166)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +149 -47
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +92 -93
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +303 -128
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +76 -12
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/testing.py +16 -12
  56. triton/tools/disasm.py +3 -4
  57. triton/tools/tensor_descriptor.py +36 -0
  58. triton/windows_utils.py +14 -6
  59. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
  60. triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
  61. triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
  62. triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
  63. triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
  64. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  65. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  66. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  67. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  68. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  69. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  70. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  71. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  72. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  73. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  74. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  75. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  76. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  77. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  78. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  79. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  80. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  81. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  82. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  83. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  84. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  85. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  86. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  87. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  88. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  89. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  90. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  91. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  92. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  93. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  94. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  95. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  96. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  97. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  98. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  99. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  100. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  101. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  102. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  103. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  104. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  105. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  106. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  107. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  108. triton/backends/amd/include/hip/device_functions.h +0 -38
  109. triton/backends/amd/include/hip/driver_types.h +0 -468
  110. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  111. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  112. triton/backends/amd/include/hip/hip_common.h +0 -100
  113. triton/backends/amd/include/hip/hip_complex.h +0 -38
  114. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  115. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  116. triton/backends/amd/include/hip/hip_ext.h +0 -161
  117. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  118. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  119. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  120. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  121. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  122. triton/backends/amd/include/hip/hip_profile.h +0 -27
  123. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  124. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  125. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  126. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  127. triton/backends/amd/include/hip/hip_version.h +0 -17
  128. triton/backends/amd/include/hip/hiprtc.h +0 -421
  129. triton/backends/amd/include/hip/library_types.h +0 -78
  130. triton/backends/amd/include/hip/math_functions.h +0 -42
  131. triton/backends/amd/include/hip/surface_types.h +0 -63
  132. triton/backends/amd/include/hip/texture_types.h +0 -194
  133. triton/backends/amd/include/hsa/Brig.h +0 -1131
  134. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  135. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  136. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  137. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  138. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  139. triton/backends/amd/include/hsa/hsa.h +0 -5738
  140. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  141. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  142. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  143. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  144. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  145. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  146. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  147. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  148. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  149. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  150. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  151. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  152. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  153. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  154. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  155. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  156. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  157. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  158. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  159. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  160. triton/backends/amd/include/roctracer/roctx.h +0 -229
  161. triton/language/_utils.py +0 -21
  162. triton/language/extra/cuda/_experimental_tma.py +0 -106
  163. triton/tools/experimental_descriptor.py +0 -32
  164. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  165. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  166. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +0 -0
triton/runtime/build.py CHANGED
@@ -1,14 +1,25 @@
+ from __future__ import annotations
+
  import functools
- import sysconfig
+ import hashlib
+ import importlib.util
+ import logging
  import os
  import shutil
  import subprocess
+ import sysconfig
+ import tempfile
+
+ from types import ModuleType
+
+ from .cache import get_cache_manager
+ from .. import knobs

  if os.name == "nt":
      from triton.windows_utils import find_msvc_winsdk, find_python


- @functools.cache
+ @functools.lru_cache
  def get_cc():
      cc = os.environ.get("CC")
      if cc is None:
@@ -30,6 +41,11 @@ def get_cc():
      return cc


+ def is_tcc(cc):
+     cc = os.path.basename(cc).lower()
+     return cc == "tcc" or cc == "tcc.exe"
+
+
  def is_msvc(cc):
      cc = os.path.basename(cc).lower()
      return cc == "cl" or cc == "cl.exe"
@@ -58,13 +74,18 @@ def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
      if not (os.name == "nt" and is_clang(cc)):
          # Clang does not support -fPIC on Windows
          cc_cmd += ["-fPIC"]
+     if is_tcc(cc):
+         cc_cmd += ["-D_Py_USE_GCC_BUILTIN_ATOMICS"]
      cc_cmd += [f'-l{lib}' for lib in libraries]
      cc_cmd += [f"-L{dir}" for dir in library_dirs]
      cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
      return cc_cmd


- def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
+ def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str],
+            libraries: list[str]) -> str:
+     if impl := knobs.build.impl:
+         return impl(name, src, srcdir, library_dirs, include_dirs, libraries)
      suffix = sysconfig.get_config_var('EXT_SUFFIX')
      so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
      # try to avoid setuptools if possible
@@ -73,24 +94,25 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
      if hasattr(sysconfig, 'get_default_scheme'):
          scheme = sysconfig.get_default_scheme()
      else:
-         scheme = sysconfig._get_default_scheme()
+         scheme = sysconfig._get_default_scheme()  # type: ignore
      # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
      # path changes to include 'local'. This change is required to use triton with system-wide python.
      if scheme == 'posix_local':
          scheme = 'posix_prefix'
      py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
-     custom_backend_dirs = set(os.getenv(var) for var in ('TRITON_CUDACRT_PATH', 'TRITON_CUDART_PATH'))
+     custom_backend_dirs = knobs.build.backend_dirs
+     # Don't append in place
      include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs]
      if os.name == "nt":
-         library_dirs += find_python()
-         # Link against Python stable ABI
-         # libraries is modified in place
-         if "python3" not in libraries:
-             libraries += ["python3"]
+         library_dirs = library_dirs + find_python()
+         version = sysconfig.get_python_version().replace(".", "")
+         if sysconfig.get_config_var("Py_GIL_DISABLED"):
+             version += "t"
+         libraries = libraries + [f"python{version}"]
          if is_msvc(cc):
              _, msvc_winsdk_inc_dirs, msvc_winsdk_lib_dirs = find_msvc_winsdk()
-             include_dirs += msvc_winsdk_inc_dirs
-             library_dirs += msvc_winsdk_lib_dirs
+             include_dirs = include_dirs + msvc_winsdk_inc_dirs
+             library_dirs = library_dirs + msvc_winsdk_lib_dirs
      cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries)

      try:
@@ -100,3 +122,45 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
          raise e

      return so
+
+
+ @functools.lru_cache
+ def platform_key() -> str:
+     from platform import machine, system, architecture
+     return ",".join([machine(), system(), *architecture()])
+
+
+ def _load_module_from_path(name: str, path: str) -> ModuleType:
+     # Loading module with relative path may cause error
+     path = os.path.abspath(path)
+     spec = importlib.util.spec_from_file_location(name, path)
+     if not spec or not spec.loader:
+         raise RuntimeError(f"Failed to load newly compiled {name} from {path}")
+     mod = importlib.util.module_from_spec(spec)
+     spec.loader.exec_module(mod)
+     return mod
+
+
+ def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None,
+                             include_dirs: list[str] | None = None, libraries: list[str] | None = None) -> ModuleType:
+     key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
+     cache = get_cache_manager(key)
+     suffix = sysconfig.get_config_var("EXT_SUFFIX")
+     cache_path = cache.get_file(f"{name}{suffix}")
+
+     if cache_path is not None:
+         try:
+             return _load_module_from_path(name, cache_path)
+         except (RuntimeError, ImportError):
+             log = logging.getLogger(__name__)
+             log.warning(f"Triton cache error: compiled module {name}.so could not be loaded")
+
+     with tempfile.TemporaryDirectory() as tmpdir:
+         src_path = os.path.join(tmpdir, name + ".c")
+         with open(src_path, "w") as f:
+             f.write(src)
+         so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [])
+         with open(so, "rb") as f:
+             cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True)
+
+     return _load_module_from_path(name, cache_path)
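Note on the helper added above: compile_module_from_src hashes the C source together with platform_key(), asks get_cache_manager for a previously built extension, and only falls back to _build in a temporary directory on a cache miss. The snippet below is a minimal usage sketch and not part of the wheel; the module name demo_utils and its stub C source are made up for illustration.

    # Hedged usage sketch for compile_module_from_src; "demo_utils" and the stub
    # C source are hypothetical. Triton itself pushes its backend driver stubs
    # through this same path.
    from triton.runtime.build import compile_module_from_src

    SRC = r"""
    #include <Python.h>

    static PyMethodDef methods[] = {{NULL, NULL, 0, NULL}};
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, "demo_utils", NULL, -1, methods};

    /* The init symbol must match the `name` argument so the cached extension can be imported. */
    PyMODINIT_FUNC PyInit_demo_utils(void) {
        return PyModule_Create(&module);
    }
    """

    # The first call compiles into a TemporaryDirectory and stores the binary via
    # cache.put(); later calls with the same source load straight from the cache.
    mod = compile_module_from_src(SRC, "demo_utils")
    print(mod.__name__)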
triton/runtime/cache.py CHANGED
@@ -1,33 +1,17 @@
- import importlib
  import json
  import os
  import uuid
  from abc import ABC, abstractmethod
- from pathlib import Path
  from typing import Dict, List, Optional
  import base64
  import hashlib

-
- def get_home_dir():
-     return os.getenv("TRITON_HOME", Path.home())
-
-
- def default_cache_dir():
-     return os.path.join(get_home_dir(), ".triton", "cache")
-
-
- def default_override_dir():
-     return os.path.join(get_home_dir(), ".triton", "override")
-
-
- def default_dump_dir():
-     return os.path.join(get_home_dir(), ".triton", "dump")
+ from .. import knobs


  class CacheManager(ABC):

-     def __init__(self, key):
+     def __init__(self, key, override=False, dump=False):
          pass

      @abstractmethod
@@ -53,16 +37,16 @@ class FileCacheManager(CacheManager):
          self.key = key
          self.lock_path = None
          if dump:
-             self.cache_dir = os.getenv("TRITON_DUMP_DIR", "").strip() or default_dump_dir()
+             self.cache_dir = knobs.cache.dump_dir
              self.cache_dir = os.path.join(self.cache_dir, self.key)
              self.lock_path = os.path.join(self.cache_dir, "lock")
              os.makedirs(self.cache_dir, exist_ok=True)
          elif override:
-             self.cache_dir = os.getenv("TRITON_OVERRIDE_DIR", "").strip() or default_override_dir()
+             self.cache_dir = knobs.cache.override_dir
              self.cache_dir = os.path.join(self.cache_dir, self.key)
          else:
              # create cache directory if it doesn't exist
-             self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
+             self.cache_dir = knobs.cache.dir
              if self.cache_dir:
                  self.cache_dir = os.path.join(self.cache_dir, self.key)
                  self.lock_path = os.path.join(self.cache_dir, "lock")
@@ -166,10 +150,10 @@ class RedisRemoteCacheBackend(RemoteCacheBackend):
      def __init__(self, key):
          import redis
          self._key = key
-         self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}")
+         self._key_fmt = knobs.cache.redis.key_format
          self._redis = redis.Redis(
-             host=os.environ.get("TRITON_REDIS_HOST", "localhost"),
-             port=int(os.environ.get("TRITON_REDIS_PORT", 6379)),
+             host=knobs.cache.redis.host,
+             port=knobs.cache.redis.port,
          )

      def _get_key(self, filename: str) -> str:
@@ -187,10 +171,10 @@ class RemoteCacheManager(CacheManager):

      def __init__(self, key, override=False, dump=False):
          # Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`.
-         remote_cache_manager = os.environ["TRITON_REMOTE_CACHE_BACKEND"]
-         module_path, clz_nme = remote_cache_manager.split(":")
-         module = importlib.import_module(module_path)
-         remote_cache_cls = getattr(module, clz_nme)
+         remote_cache_cls = knobs.cache.remote_manager_class
+         if not remote_cache_cls:
+             raise RuntimeError(
+                 "Unable to instantiate RemoteCacheManager, TRITON_REMOTE_CACHE_BACKEND doesn't point to a valid class")
          self._backend = remote_cache_cls(key)

          self._override = override
@@ -260,37 +244,24 @@ class RemoteCacheManager(CacheManager):
          return self.put(grp_contents, grp_filename)


- __cache_cls = FileCacheManager
- __cache_cls_nme = "DEFAULT"
-
-
  def _base32(key):
      # Assume key is a hex string.
      return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")


  def get_cache_manager(key) -> CacheManager:
-     import os
-
-     user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
-     global __cache_cls
-     global __cache_cls_nme
-
-     if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
-         module_path, clz_nme = user_cache_manager.split(":")
-         module = importlib.import_module(module_path)
-         __cache_cls = getattr(module, clz_nme)
-         __cache_cls_nme = user_cache_manager
-
-     return __cache_cls(_base32(key))
+     cls = knobs.cache.manager_class or FileCacheManager
+     return cls(_base32(key))


  def get_override_manager(key) -> CacheManager:
-     return __cache_cls(_base32(key), override=True)
+     cls = knobs.cache.manager_class or FileCacheManager
+     return cls(_base32(key), override=True)


  def get_dump_manager(key) -> CacheManager:
-     return __cache_cls(_base32(key), dump=True)
+     cls = knobs.cache.manager_class or FileCacheManager
+     return cls(_base32(key), dump=True)


  def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
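The cache.py changes above swap the direct TRITON_CACHE_DIR / TRITON_REDIS_* / TRITON_CACHE_MANAGER environment lookups for attributes on the new triton.knobs module (those environment variables still back the knobs). Below is a hedged sketch of what that enables, assuming the knob attributes accept direct in-process assignment; VerboseCacheManager and the cache path are invented examples, not part of the package.

    # Hedged sketch: configuring the cache through knobs instead of env vars.
    # VerboseCacheManager and "/tmp/triton-cache-demo" are hypothetical.
    import triton
    from triton.runtime.cache import FileCacheManager, get_cache_manager

    class VerboseCacheManager(FileCacheManager):
        """Logs every artifact written into the Triton cache."""

        def put(self, data, filename, binary=True):
            path = super().put(data, filename, binary=binary)
            print(f"cached {filename} -> {path}")
            return path

    # Roughly equivalent to exporting TRITON_CACHE_DIR / TRITON_CACHE_MANAGER,
    # but done through the knobs object that cache.py now reads.
    triton.knobs.cache.dir = "/tmp/triton-cache-demo"
    triton.knobs.cache.manager_class = VerboseCacheManager

    manager = get_cache_manager("c0ffee00" * 8)  # get_cache_manager expects a hex digest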
triton/runtime/driver.py CHANGED
@@ -1,59 +1,62 @@
- from ..backends import backends
- from ..backends import DriverBase
+ from __future__ import annotations

+ from ..backends import backends, DriverBase

- def _create_driver():
-     actives = [x.driver for x in backends.values() if x.driver.is_active()]
-     if len(actives) != 1:
-         raise RuntimeError(f"{len(actives)} active drivers ({actives}). There should only be one.")
-     return actives[0]()
+ from typing import Any, Callable, Generic, TypeVar, Union


- class LazyProxy:
+ def _create_driver() -> DriverBase:
+     active_drivers = [x.driver for x in backends.values() if x.driver.is_active()]
+     if len(active_drivers) != 1:
+         raise RuntimeError(f"{len(active_drivers)} active drivers ({active_drivers}). There should only be one.")
+     return active_drivers[0]()

-     def __init__(self, init_fn):
+
+ T = TypeVar("T")
+
+
+ class LazyProxy(Generic[T]):
+
+     def __init__(self, init_fn: Callable[[], T]) -> None:
          self._init_fn = init_fn
-         self._obj = None
+         self._obj: Union[T, None] = None

-     def _initialize_obj(self):
+     def _initialize_obj(self) -> T:
          if self._obj is None:
              self._obj = self._init_fn()
+         return self._obj

-     def __getattr__(self, name):
-         self._initialize_obj()
-         return getattr(self._obj, name)
+     def __getattr__(self, name) -> Any:
+         return getattr(self._initialize_obj(), name)

-     def __setattr__(self, name, value):
+     def __setattr__(self, name: str, value: Any) -> None:
          if name in ["_init_fn", "_obj"]:
              super().__setattr__(name, value)
          else:
-             self._initialize_obj()
-             setattr(self._obj, name, value)
+             setattr(self._initialize_obj(), name, value)

-     def __delattr__(self, name):
-         self._initialize_obj()
-         delattr(self._obj, name)
+     def __delattr__(self, name: str) -> None:
+         delattr(self._initialize_obj(), name)

-     def __repr__(self):
+     def __repr__(self) -> str:
          if self._obj is None:
              return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
          return repr(self._obj)

-     def __str__(self):
-         self._initialize_obj()
-         return str(self._obj)
+     def __str__(self) -> str:
+         return str(self._initialize_obj())


  class DriverConfig:

-     def __init__(self):
-         self.default = LazyProxy(_create_driver)
-         self.active = self.default
+     def __init__(self) -> None:
+         self.default: LazyProxy[DriverBase] = LazyProxy(_create_driver)
+         self.active: Union[LazyProxy[DriverBase], DriverBase] = self.default

-     def set_active(self, driver: DriverBase):
+     def set_active(self, driver: DriverBase) -> None:
          self.active = driver

-     def reset_active(self):
+     def reset_active(self) -> None:
          self.active = self.default


triton/runtime/interpreter.py CHANGED
@@ -1,32 +1,36 @@
+ from __future__ import annotations
  import ast
  import textwrap
  import inspect
- from typing import Tuple, List
+ from typing import Tuple, List, Dict

  import math
  import numpy as np

  import triton
  import triton.language as tl
+ import dataclasses
  from dataclasses import dataclass
+
+ from triton.language.semantic import TritonSemantic
+ from triton.tools.tensor_descriptor import TensorDescriptor
  from .errors import InterpreterError
  from functools import partial
  from .._C.libtriton import interpreter as _interpreter
  from .._C.libtriton import ir as _ir


+ @dataclass
  class TensorHandle:
-
-     def __init__(self, data, dtype):
-         '''
-         data: numpy array
-         dtype: triton type, either pointer_type or scalar_type.
-         we don't store block_type here because the shape information is already available in the data field
-         attr: a dictionary of attributes
-         '''
-         self.data = data
-         self.dtype = dtype
-         self.attr = {}
+     '''
+     data: numpy array
+     dtype: triton type, either pointer_type or scalar_type.
+     we don't store block_type here because the shape information is already available in the data field
+     attr: a dictionary of attributes
+     '''
+     data: np.array
+     dtype: tl.dtype
+     attr: Dict = dataclasses.field(default_factory=dict)

      def __bool__(self):
          return bool(self.data.all())
@@ -103,6 +107,7 @@ class TensorDescHandle:
              off = (offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
              ptrs = ptrs + (itemsize * off * self.strides[dim].data).astype(np.uint64)
              masks = masks & (0 <= off) & (off < self.shape[dim].data)
+         assert ptrs.dtype == np.uint64
          ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
          return ptrs, masks

@@ -114,7 +119,7 @@ class InterpreterOptions:
      sanitize_overflow: bool = True
      arch: str = None
      supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15")
-     deprecated_fp8_dtypes: Tuple[str] = ()
+     deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
      default_dot_input_precision: str = "tf32"
      allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
      max_num_imprecise_acc_default: int = 0
@@ -248,8 +253,8 @@ np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64])
  class ExtraFunctions:

      @staticmethod
-     def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _builder):
-         return tl.tensor(_builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty)
+     def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _semantic):
+         return tl.tensor(_semantic.builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty)


  class InterpreterBuilder:
@@ -306,6 +311,9 @@ class InterpreterBuilder:
      def get_double_ty(self):
          return tl.float64

+     def get_int1_ty(self):
+         return tl.int1
+
      def get_int8_ty(self):
          return tl.int8

@@ -587,11 +595,18 @@
          b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16)
          return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar)

-     def create_make_range(self, start, stop):
+     def create_make_range(self, ret_ty, start, stop):
          return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32)

-     def create_histogram(self, data, bins):
-         return TensorHandle(np.histogram(data.data, bins=bins, range=(0, bins))[0], tl.int32)
+     def create_histogram(self, data, bins, mask):
+         if mask is None:
+             mask = TensorHandle(np.ones_like(data.data, dtype=bool), tl.int1)
+         # force all masked elements to zero
+         data = np.where(mask.data, data.data, np.zeros_like(data.data))
+         histogram = np.histogram(data, bins=bins, range=(0, bins))[0]
+         # remove overcounted elements
+         histogram[0] -= np.logical_not(mask.data).sum()
+         return TensorHandle(histogram, tl.int32)

      def create_gather(self, src, indices, axis):
          return TensorHandle(np.take_along_axis(src.data, indices.data, axis=axis), src.dtype.scalar)
@@ -641,7 +656,8 @@
          # Triton only supports splitting the original tensor into two along the last axis
          return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar))

-     def create_splat(self, arg, shape):
+     def create_splat(self, ret_ty, arg):
+         shape = ret_ty.shape
          if isinstance(arg.dtype, tl.block_type):
              return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
          else:  # scalar
@@ -715,6 +731,7 @@
          shape: List[TensorHandle],
          strides: List[TensorHandle],
          tensor_shape: List[int],
+         is_signed: bool,
      ):
          desc = TensorDescHandle(base, shape, strides, tensor_shape)
          desc.validate()
@@ -753,15 +770,18 @@
          np_type = _get_np_dtype(type)
          if "int" in np_type.name:
              return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar)
+         elif np_type == np.bool_:
+             return TensorHandle(np.full(1, True, dtype=np_type), type.scalar)
          else:
              raise TypeError(f"unsupported type {type}")


  def _patch_attr(obj, name, member, builder):
+     semantic = TritonSemantic(builder)
      new_member = lambda *args, member=member, **kwargs: (member(*args, **
                                                                  {k: v
                                                                   for k, v in kwargs.items()
-                                                                  if k != "_builder"}, _builder=builder))
+                                                                  if k != "_semantic"}, _semantic=semantic))
      setattr(obj, name, new_member)


@@ -822,12 +842,10 @@ class ReduceScanOpInterface:

      def apply(self, input):
          if not isinstance(input, tuple):
-             input = (input, )
+             return self.apply((input, ))[0]
          self.check_tensor(input)
-         return self.apply_impl(input)
-
-     def apply_impl(self, input):
-         raise NotImplementedError("apply_impl not implemented")
+         ret = self.apply_impl(input)
+         return tuple(ret) if isinstance(ret, (list, tuple)) else (ret, )


  class ReduceOps(ReduceScanOpInterface):
@@ -887,7 +905,7 @@ class ReduceOps(ReduceScanOpInterface):
              # Take a scalar
              data = data.item()
              ret.append(self.to_tensor(data, input[i].dtype))
-         return ret[0] if len(ret) == 1 else tuple(ret)
+         return ret

      def min_max(self, input, val_reduce_op, idx_reduce_op=None):
          # If input is a tuple, it must be (val, index), and we only take val
@@ -985,7 +1003,7 @@ class ScanOps(ReduceScanOpInterface):
          if self.reverse:
              for arg in ret:
                  arg.handle.data = np.flip(arg.handle.data, axis=self.axis)
-         return len(ret) == 1 and ret[0] or tuple(ret)
+         return ret


  def _patch_reduce_scan():
@@ -1092,7 +1110,7 @@ def _patch_lang(fn):
      _patch_builtin(lang.math, interpreter_builder)
      _patch_lang_tensor(lang.tensor)
      _patch_lang_core(lang)
-     _patch_builtin(tl.core._experimental_tensor_descriptor_base, interpreter_builder)
+     _patch_builtin(tl.core.tensor_descriptor_base, interpreter_builder)


  def _tuple_create(arg, contents):
@@ -1127,10 +1145,22 @@ def _implicit_cvt(arg):
          return tl.tensor(handle, ty)
      elif isinstance(arg, tuple):
          return _tuple_create(arg, map(_implicit_cvt, arg))
+     elif isinstance(arg, TensorDescriptor):
+         strides = [_implicit_cvt(s) for s in arg.strides]
+         assert arg.strides[-1] == 1
+         strides[-1] = tl.constexpr(1)
+         semantic = TritonSemantic(InterpreterBuilder())
+         return semantic.make_tensor_descriptor(
+             base=_implicit_cvt(arg.base),
+             shape=[_implicit_cvt(s) for s in arg.shape],
+             strides=strides,
+             block_shape=[tl.constexpr(b) for b in arg.block_shape],
+         )
      return arg


  interpreter_builder = InterpreterBuilder()
+ interpreter_semantic = TritonSemantic(interpreter_builder)


  def _unwrap_tensor(t):
@@ -1162,6 +1192,13 @@ class GridExecutor:
          def _to_cpu(arg):
              if isinstance(arg, tuple):
                  return _tuple_create(arg, map(_to_cpu, arg))
+             elif isinstance(arg, TensorDescriptor):
+                 return TensorDescriptor(
+                     _to_cpu(arg.base),
+                     arg.shape,
+                     arg.strides,
+                     arg.block_shape,
+                 )
              elif not hasattr(arg, "data_ptr"):
                  return arg

@@ -1195,6 +1232,8 @@ class GridExecutor:
              elif isinstance(arg_dev, tuple):
                  for (arg_dev, arg_hst) in zip(arg_dev, arg_hst):
                      _from_cpu(arg_dev, arg_hst)
+             elif isinstance(arg_dev, TensorDescriptor):
+                 _from_cpu(arg_dev.base, arg_hst.base)

          for arg_dev, arg_hst in zip(args_dev, args_hst):
              _from_cpu(arg_dev, arg_hst)
@@ -1235,6 +1274,8 @@ class GridExecutor:
                      interpreter_builder.set_grid_idx(x, y, z)
                      self.fn(**args)
                  except Exception as e:
+                     if triton.knobs.compilation.front_end_debugging:
+                         raise
                      raise InterpreterError(repr(e)) from e
          # copy arguments back to propagate side-effects
          self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)
@@ -1249,14 +1290,10 @@ class ASTTransformer(ast.NodeTransformer):
          if len(names) > 1:
              raise ValueError("Multiple assignments are not supported")
          # Modify the assignment x = value to
-         # triton.language.semantic.to_tensor(value, interpreter_builder, False)
+         # interpreter_semantic.to_tensor(value, False)
          node.value = ast.Call(
-             func=ast.Attribute(
-                 value=ast.Attribute(
-                     value=ast.Attribute(value=ast.Name(id='triton', ctx=ast.Load()), attr='language', ctx=ast.Load()),
-                     attr='semantic', ctx=ast.Load()), attr='to_tensor', ctx=ast.Load()),
-             args=[node.value, ast.Name(id='interpreter_builder', ctx=ast.Load()),
-                   ast.Constant(value=False)], keywords=[])
+             func=ast.Attribute(value=ast.Name(id="interpreter_semantic", ctx=ast.Load()), attr="to_tensor",
+                                ctx=ast.Load()), args=[node.value, ast.Constant(value=False)], keywords=[])
          return node

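The interpreter changes above teach _implicit_cvt, _to_cpu, and _from_cpu about the host-side TensorDescriptor from triton.tools.tensor_descriptor, so interpreted (TRITON_INTERPRET) runs can accept descriptor arguments. A hedged sketch of constructing one follows; the tensor and block shape are arbitrary example values, and the field order mirrors the TensorDescriptor(...) call visible in the diff.

    # Hedged sketch: building a host-side TensorDescriptor for an interpreted kernel.
    import torch
    from triton.tools.tensor_descriptor import TensorDescriptor

    x = torch.zeros(128, 64)
    desc = TensorDescriptor(
        x,                  # base tensor backing the descriptor
        list(x.shape),      # global shape
        list(x.stride()),   # strides; the last one must be 1 (asserted in _implicit_cvt)
        [32, 32],           # block_shape handled per program
    )
    # `desc` can now be passed as a kernel argument; GridExecutor._to_cpu copies the
    # base tensor to host memory and rebuilds the descriptor around it.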