triton-windows 3.2.0.post11__cp312-cp312-win_amd64.whl → 3.3.0a0.post11__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +3 -3
- triton/_internal_testing.py +59 -4
- triton/_utils.py +35 -0
- triton/backends/amd/compiler.py +121 -74
- triton/backends/amd/driver.py +77 -43
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
- triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
- triton/backends/amd/include/hip/hip_ext.h +4 -2
- triton/backends/amd/include/hip/hip_fp8.h +33 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
- triton/backends/amd/include/hip/hip_version.h +3 -3
- triton/backends/amd/include/hip/hiprtc.h +25 -25
- triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
- triton/backends/amd/include/hsa/hsa.h +11 -2
- triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/compiler.py +25 -225
- triton/backends/driver.py +7 -2
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +135 -90
- triton/backends/nvidia/driver.c +0 -1
- triton/backends/nvidia/driver.py +135 -49
- triton/backends/nvidia/include/cuda.h +2162 -241
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +2 -2
- triton/compiler/code_generator.py +334 -231
- triton/compiler/compiler.py +77 -66
- triton/language/__init__.py +22 -5
- triton/language/core.py +448 -74
- triton/language/extra/cuda/_experimental_tma.py +3 -5
- triton/language/math.py +1 -1
- triton/language/random.py +2 -1
- triton/language/semantic.py +206 -52
- triton/language/standard.py +35 -18
- triton/runtime/_allocation.py +32 -0
- triton/runtime/autotuner.py +27 -32
- triton/runtime/build.py +1 -48
- triton/runtime/cache.py +6 -6
- triton/runtime/errors.py +10 -0
- triton/runtime/interpreter.py +179 -45
- triton/runtime/jit.py +149 -190
- triton/testing.py +39 -11
- triton/tools/compile.py +27 -20
- triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
- triton/tools/mxfp.py +301 -0
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/METADATA +5 -2
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/RECORD +68 -59
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/top_level.txt +2 -0
- /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/WHEEL +0 -0
triton/compiler/compiler.py
CHANGED
|
@@ -3,7 +3,7 @@ import hashlib
|
|
|
3
3
|
import json
|
|
4
4
|
from .._C.libtriton import get_cache_invalidating_env_vars, ir
|
|
5
5
|
from ..backends import backends
|
|
6
|
-
from ..backends.compiler import GPUTarget
|
|
6
|
+
from ..backends.compiler import GPUTarget
|
|
7
7
|
from .. import __version__
|
|
8
8
|
from ..runtime.autotuner import OutOfResources
|
|
9
9
|
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
|
|
@@ -15,6 +15,7 @@ from pathlib import Path
|
|
|
15
15
|
import re
|
|
16
16
|
import functools
|
|
17
17
|
import os
|
|
18
|
+
import sysconfig
|
|
18
19
|
|
|
19
20
|
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
|
|
20
21
|
# and any following whitespace
|
|
@@ -24,19 +25,13 @@ import os
|
|
|
24
25
|
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
|
|
25
26
|
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
|
|
26
27
|
# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
|
|
27
|
-
mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
|
|
28
28
|
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
|
|
29
29
|
prototype_pattern = {
|
|
30
|
-
"ttir": mlir_prototype_pattern,
|
|
31
|
-
"ttgir": mlir_prototype_pattern,
|
|
32
30
|
"ptx": ptx_prototype_pattern,
|
|
33
31
|
}
|
|
34
32
|
|
|
35
|
-
mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?'
|
|
36
33
|
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
|
|
37
34
|
arg_type_pattern = {
|
|
38
|
-
"ttir": mlir_arg_type_pattern,
|
|
39
|
-
"ttgir": mlir_arg_type_pattern,
|
|
40
35
|
"ptx": ptx_arg_type_pattern,
|
|
41
36
|
}
|
|
42
37
|
|
|
@@ -54,46 +49,32 @@ def convert_type_repr(x):
|
|
|
54
49
|
return x
|
|
55
50
|
|
|
56
51
|
|
|
57
|
-
def _get_num_warps_from_ir_str(src: str):
|
|
58
|
-
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
|
|
59
|
-
# TODO(jlebar): Using a regex to get num-warps is a hack, and will break if
|
|
60
|
-
# e.g. someone has an instruction (not module) attribute named "num-warps".
|
|
61
|
-
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
|
|
62
|
-
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
|
|
63
|
-
num_warps = int(num_warps_matches[0])
|
|
64
|
-
return num_warps
|
|
65
|
-
|
|
66
|
-
|
|
67
52
|
class ASTSource:
|
|
68
53
|
|
|
69
|
-
def __init__(self, fn, signature,
|
|
54
|
+
def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
|
|
70
55
|
self.fn = fn
|
|
71
56
|
self.ext = "ttir"
|
|
72
57
|
self.name = fn.__name__
|
|
73
58
|
self.signature = signature
|
|
74
|
-
self.constants =
|
|
75
|
-
|
|
59
|
+
self.constants = dict()
|
|
60
|
+
if constexprs is not None:
|
|
61
|
+
for k, v in constexprs.items():
|
|
62
|
+
k = (fn.arg_names.index(k), ) if isinstance(k, str) else k
|
|
63
|
+
assert isinstance(k, tuple)
|
|
64
|
+
self.constants[k] = v
|
|
65
|
+
self.attrs = attrs or dict()
|
|
76
66
|
if isinstance(self.signature, str):
|
|
77
67
|
self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))}
|
|
78
|
-
else:
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
if self.constants is None:
|
|
83
|
-
self.constants = {}
|
|
84
|
-
else:
|
|
85
|
-
for k in self.constants.keys():
|
|
86
|
-
if not isinstance(k, str):
|
|
87
|
-
raise TypeError("Constants keys must be string")
|
|
88
|
-
if self.attrs is None:
|
|
89
|
-
self.attrs = AttrsDescriptor()
|
|
68
|
+
# else:
|
|
69
|
+
# for k in self.signature.keys():
|
|
70
|
+
# if not isinstance(k, str):
|
|
71
|
+
# raise TypeError("Signature keys must be string")
|
|
90
72
|
|
|
91
73
|
def hash(self):
|
|
92
74
|
sorted_sig = [v for k, v in sorted(self.signature.items())]
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}"
|
|
75
|
+
get_key = lambda x: x.cache_key if hasattr(x, 'cache_key') else str(x)
|
|
76
|
+
constants_key = '-'.join([get_key(v) for k, v in sorted(self.constants.items())])
|
|
77
|
+
key = f"{self.fn.cache_key}-{str(self.attrs)}-{sorted_sig}-{constants_key}"
|
|
97
78
|
return hashlib.sha256(key.encode("utf-8")).hexdigest()
|
|
98
79
|
|
|
99
80
|
def make_ir(self, options, codegen_fns, module_map, context):
|
|
@@ -106,28 +87,42 @@ class ASTSource:
|
|
|
106
87
|
|
|
107
88
|
class IRSource:
|
|
108
89
|
|
|
109
|
-
def __init__(self, path):
|
|
90
|
+
def __init__(self, path, context, backend):
|
|
110
91
|
self.path = path
|
|
111
92
|
path = Path(path)
|
|
112
93
|
self.ext = path.suffix[1:]
|
|
113
94
|
self.src = path.read_text()
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
95
|
+
ir.load_dialects(context)
|
|
96
|
+
backend.load_dialects(context)
|
|
97
|
+
|
|
98
|
+
# We don't have an easy-to-use PTX parser that we can use, so keep that regex for now.
|
|
99
|
+
# TODO - replace with a proper parser
|
|
100
|
+
if self.ext == "ptx":
|
|
101
|
+
match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
|
|
102
|
+
self.name = match.group(1)
|
|
103
|
+
signature = match.group(2)
|
|
104
|
+
types = re.findall(arg_type_pattern[self.ext], signature)
|
|
105
|
+
self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
|
|
106
|
+
else:
|
|
107
|
+
self.module = ir.parse_mlir_module(self.path, context)
|
|
108
|
+
fn_name = self.module.get_entry_func_name()
|
|
109
|
+
self.name = "@" + fn_name
|
|
110
|
+
funcOp = self.module.get_function(fn_name)
|
|
111
|
+
func_ty = self.module.get_function_signature(funcOp)
|
|
112
|
+
self.signature = {k: ty for k, ty in enumerate(func_ty)}
|
|
119
113
|
|
|
120
114
|
def hash(self):
|
|
121
115
|
return hashlib.sha256(self.src.encode("utf-8")).hexdigest()
|
|
122
116
|
|
|
123
117
|
def make_ir(self, options, codegen_fns, module_map, context):
|
|
124
|
-
module =
|
|
125
|
-
module
|
|
126
|
-
return module
|
|
118
|
+
self.module.context = context
|
|
119
|
+
return self.module
|
|
127
120
|
|
|
128
121
|
def parse_options(self):
|
|
129
122
|
if self.ext == "ttgir":
|
|
130
|
-
|
|
123
|
+
num_warps = self.module.get_int_attr("ttg.num-warps")
|
|
124
|
+
assert num_warps is not None, "Unable to parse ttg.num-warps attribute"
|
|
125
|
+
return {'num_warps': num_warps}
|
|
131
126
|
return dict()
|
|
132
127
|
|
|
133
128
|
|
|
@@ -151,11 +146,8 @@ def triton_key():
|
|
|
151
146
|
|
|
152
147
|
# backend
|
|
153
148
|
libtriton_hash = hashlib.sha256()
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
else:
|
|
157
|
-
so_name = "libtriton.so"
|
|
158
|
-
with open(os.path.join(TRITON_PATH, f"_C/{so_name}"), "rb") as f:
|
|
149
|
+
ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
|
|
150
|
+
with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f:
|
|
159
151
|
while True:
|
|
160
152
|
chunk = f.read(1024**2)
|
|
161
153
|
if not chunk:
|
|
@@ -175,9 +167,9 @@ def parse(full_name, ext, context):
|
|
|
175
167
|
module = ir.parse_mlir_module(full_name, context)
|
|
176
168
|
module.context = context
|
|
177
169
|
return module
|
|
178
|
-
if ext == "llir" or ext == "ptx":
|
|
170
|
+
if ext == "llir" or ext == "ptx" or ext == "amdgcn":
|
|
179
171
|
return Path(full_name).read_text()
|
|
180
|
-
if ext == "cubin":
|
|
172
|
+
if ext == "cubin" or ext == "hsaco":
|
|
181
173
|
return Path(full_name).read_bytes()
|
|
182
174
|
|
|
183
175
|
|
|
@@ -200,6 +192,7 @@ def filter_traceback(e: BaseException):
|
|
|
200
192
|
"/triton/compiler/code_generator.py",
|
|
201
193
|
"/ast.py",
|
|
202
194
|
]
|
|
195
|
+
BAD_FILES = [bad_file.replace("/", os.sep) for bad_file in BAD_FILES]
|
|
203
196
|
|
|
204
197
|
tb = e.__traceback__
|
|
205
198
|
frames = []
|
|
@@ -227,7 +220,9 @@ def compile(src, target=None, options=None):
|
|
|
227
220
|
# create backend
|
|
228
221
|
if ir_source:
|
|
229
222
|
assert isinstance(src, str), "source must be either AST or a filepath"
|
|
230
|
-
|
|
223
|
+
context = ir.context()
|
|
224
|
+
src = IRSource(src, context, backend)
|
|
225
|
+
|
|
231
226
|
extra_options = src.parse_options()
|
|
232
227
|
options = backend.parse_options(dict(options or dict(), **extra_options))
|
|
233
228
|
# create cache manager
|
|
@@ -239,6 +234,7 @@ def compile(src, target=None, options=None):
|
|
|
239
234
|
# core changes to make it easier to track kernels by hash.
|
|
240
235
|
enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
|
|
241
236
|
enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1"
|
|
237
|
+
store_only_binary = os.environ.get("TRITON_STORE_BINARY_ONLY", "0") == "1"
|
|
242
238
|
fn_override_manager = get_override_manager(src.hash()) if enable_override else None
|
|
243
239
|
fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None
|
|
244
240
|
# Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms.
|
|
@@ -252,7 +248,6 @@ def compile(src, target=None, options=None):
|
|
|
252
248
|
always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1"
|
|
253
249
|
if not always_compile and metadata_path is not None:
|
|
254
250
|
# cache hit!
|
|
255
|
-
metadata = json.loads(Path(metadata_path).read_text())
|
|
256
251
|
return CompiledKernel(src, metadata_group, hash)
|
|
257
252
|
# initialize metadata
|
|
258
253
|
metadata = {
|
|
@@ -261,6 +256,7 @@ def compile(src, target=None, options=None):
|
|
|
261
256
|
**options.__dict__,
|
|
262
257
|
**env_vars,
|
|
263
258
|
}
|
|
259
|
+
metadata["triton_version"] = __version__
|
|
264
260
|
# run compilation pipeline and populate metadata
|
|
265
261
|
stages = dict()
|
|
266
262
|
backend.add_stages(stages, options)
|
|
@@ -268,10 +264,15 @@ def compile(src, target=None, options=None):
|
|
|
268
264
|
# when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
|
|
269
265
|
if ir_source:
|
|
270
266
|
first_stage += 1
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
backend.load_dialects
|
|
274
|
-
|
|
267
|
+
|
|
268
|
+
# For IRSource, we have already grabbed the context + called both
|
|
269
|
+
# ir.load_dialects and backend.load_dialects.
|
|
270
|
+
if not isinstance(src, IRSource):
|
|
271
|
+
context = ir.context()
|
|
272
|
+
ir.load_dialects(context)
|
|
273
|
+
backend.load_dialects(context)
|
|
274
|
+
|
|
275
|
+
codegen_fns = backend.get_codegen_implementation(options)
|
|
275
276
|
module_map = backend.get_module_map()
|
|
276
277
|
try:
|
|
277
278
|
module = src.make_ir(options, codegen_fns, module_map, context)
|
|
@@ -285,7 +286,9 @@ def compile(src, target=None, options=None):
|
|
|
285
286
|
if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None):
|
|
286
287
|
print(f"\nOverriding kernel with file {full_name}")
|
|
287
288
|
next_module = parse(full_name, ext, context)
|
|
288
|
-
|
|
289
|
+
# If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json
|
|
290
|
+
if (not store_only_binary) or (ext in ("cubin", "hsaco", "json")):
|
|
291
|
+
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
|
|
289
292
|
if fn_dump_manager is not None:
|
|
290
293
|
fn_dump_manager.put(next_module, ir_filename)
|
|
291
294
|
# use an env variable to parse ir from file
|
|
@@ -302,7 +305,13 @@ def compile(src, target=None, options=None):
|
|
|
302
305
|
# This is needed to safely finalize the thread pool inside the context: if the current process forks before
|
|
303
306
|
# python GC deletes context object, thread pool in child process will be invalid, which could
|
|
304
307
|
# lead to child crash or hang.
|
|
305
|
-
|
|
308
|
+
#
|
|
309
|
+
# However, disabling multithreading causes the code to hang if the ASAN pass is enabled;
|
|
310
|
+
# this is likely due to the llvm-symbolizer forking a process
|
|
311
|
+
# TODO: Reconcile the difference here between the ASAN and non-ASAN path with enabling
|
|
312
|
+
# multithreading in the MLIR context
|
|
313
|
+
if not os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
|
|
314
|
+
context.disable_multithreading()
|
|
306
315
|
# return handle to compiled kernel
|
|
307
316
|
return CompiledKernel(src, metadata_group, hash)
|
|
308
317
|
|
|
@@ -390,6 +399,11 @@ class CompiledKernel:
|
|
|
390
399
|
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
|
|
391
400
|
if self.metadata.shared > max_shared:
|
|
392
401
|
raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
|
|
402
|
+
if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None:
|
|
403
|
+
# Use Blackwell max tmem size for now; this should be moved into device properties
|
|
404
|
+
max_tmem_size = 512 # tmem size in number of columns
|
|
405
|
+
if self.metadata.tmem_size > max_tmem_size:
|
|
406
|
+
raise OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory")
|
|
393
407
|
# TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
|
|
394
408
|
self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
|
|
395
409
|
self.name, self.kernel, self.metadata.shared, device)
|
|
@@ -408,11 +422,8 @@ class CompiledKernel:
|
|
|
408
422
|
arg_dict = {}
|
|
409
423
|
arg_idx = 0
|
|
410
424
|
for i, arg_name in enumerate(self.src.fn.arg_names):
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
else:
|
|
414
|
-
arg_dict[arg_name] = args[arg_idx]
|
|
415
|
-
arg_idx += 1
|
|
425
|
+
arg_dict[arg_name] = args[arg_idx]
|
|
426
|
+
arg_idx += 1
|
|
416
427
|
ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict))
|
|
417
428
|
return ret
|
|
418
429
|
|
triton/language/__init__.py
CHANGED
|
@@ -28,6 +28,9 @@ from .core import (
|
|
|
28
28
|
TRITON_MAX_TENSOR_NUMEL,
|
|
29
29
|
_experimental_descriptor_load,
|
|
30
30
|
_experimental_descriptor_store,
|
|
31
|
+
_experimental_make_tensor_descriptor,
|
|
32
|
+
_experimental_reinterpret_tensor_descriptor,
|
|
33
|
+
_experimental_tensor_descriptor,
|
|
31
34
|
add,
|
|
32
35
|
advance,
|
|
33
36
|
arange,
|
|
@@ -66,7 +69,7 @@ from .core import (
|
|
|
66
69
|
float8e5,
|
|
67
70
|
float8e5b16,
|
|
68
71
|
full,
|
|
69
|
-
|
|
72
|
+
gather,
|
|
70
73
|
histogram,
|
|
71
74
|
inline_asm_elementwise,
|
|
72
75
|
int1,
|
|
@@ -91,6 +94,7 @@ from .core import (
|
|
|
91
94
|
range,
|
|
92
95
|
reduce,
|
|
93
96
|
reshape,
|
|
97
|
+
slice,
|
|
94
98
|
split,
|
|
95
99
|
static_assert,
|
|
96
100
|
static_print,
|
|
@@ -98,6 +102,8 @@ from .core import (
|
|
|
98
102
|
store,
|
|
99
103
|
tensor,
|
|
100
104
|
trans,
|
|
105
|
+
tuple,
|
|
106
|
+
tuple_type,
|
|
101
107
|
uint16,
|
|
102
108
|
uint32,
|
|
103
109
|
uint64,
|
|
@@ -126,6 +132,9 @@ __all__ = [
|
|
|
126
132
|
"TRITON_MAX_TENSOR_NUMEL",
|
|
127
133
|
"_experimental_descriptor_load",
|
|
128
134
|
"_experimental_descriptor_store",
|
|
135
|
+
"_experimental_make_tensor_descriptor",
|
|
136
|
+
"_experimental_reinterpret_tensor_descriptor",
|
|
137
|
+
"_experimental_tensor_descriptor",
|
|
129
138
|
"abs",
|
|
130
139
|
"add",
|
|
131
140
|
"advance",
|
|
@@ -146,7 +155,6 @@ __all__ = [
|
|
|
146
155
|
"block_type",
|
|
147
156
|
"broadcast",
|
|
148
157
|
"broadcast_to",
|
|
149
|
-
"builtin",
|
|
150
158
|
"cat",
|
|
151
159
|
"cast",
|
|
152
160
|
"cdiv",
|
|
@@ -182,7 +190,7 @@ __all__ = [
|
|
|
182
190
|
"floor",
|
|
183
191
|
"fma",
|
|
184
192
|
"full",
|
|
185
|
-
"
|
|
193
|
+
"gather",
|
|
186
194
|
"histogram",
|
|
187
195
|
"inline_asm_elementwise",
|
|
188
196
|
"interleave",
|
|
@@ -191,7 +199,6 @@ __all__ = [
|
|
|
191
199
|
"int32",
|
|
192
200
|
"int64",
|
|
193
201
|
"int8",
|
|
194
|
-
"ir",
|
|
195
202
|
"join",
|
|
196
203
|
"load",
|
|
197
204
|
"log",
|
|
@@ -225,6 +232,7 @@ __all__ = [
|
|
|
225
232
|
"reduce",
|
|
226
233
|
"reshape",
|
|
227
234
|
"rsqrt",
|
|
235
|
+
"slice",
|
|
228
236
|
"sigmoid",
|
|
229
237
|
"sin",
|
|
230
238
|
"softmax",
|
|
@@ -240,7 +248,7 @@ __all__ = [
|
|
|
240
248
|
"swizzle2d",
|
|
241
249
|
"tensor",
|
|
242
250
|
"trans",
|
|
243
|
-
"
|
|
251
|
+
"tuple",
|
|
244
252
|
"uint16",
|
|
245
253
|
"uint32",
|
|
246
254
|
"uint64",
|
|
@@ -257,6 +265,12 @@ __all__ = [
|
|
|
257
265
|
|
|
258
266
|
|
|
259
267
|
def str_to_ty(name):
|
|
268
|
+
from builtins import tuple
|
|
269
|
+
|
|
270
|
+
if isinstance(name, tuple):
|
|
271
|
+
fields = type(name).__dict__.get("_fields", None)
|
|
272
|
+
return tuple_type([str_to_ty(x) for x in name], fields)
|
|
273
|
+
|
|
260
274
|
if name[0] == "*":
|
|
261
275
|
name = name[1:]
|
|
262
276
|
const = False
|
|
@@ -269,6 +283,9 @@ def str_to_ty(name):
|
|
|
269
283
|
if name == "nvTmaDesc":
|
|
270
284
|
return nv_tma_desc_type()
|
|
271
285
|
|
|
286
|
+
if name == "constexpr":
|
|
287
|
+
return constexpr
|
|
288
|
+
|
|
272
289
|
tys = {
|
|
273
290
|
"fp8e4nv": float8e4nv,
|
|
274
291
|
"fp8e4b8": float8e4b8,
|