triton-windows 3.4.0.post20-cp311-cp311-win_amd64.whl → 3.5.0.post21-cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +8 -2
- triton/_filecheck.py +24 -14
- triton/_internal_testing.py +70 -4
- triton/_utils.py +3 -1
- triton/backends/amd/compiler.py +68 -60
- triton/backends/amd/driver.c +113 -44
- triton/backends/amd/driver.py +133 -57
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/compiler.py +80 -22
- triton/backends/nvidia/driver.c +88 -15
- triton/backends/nvidia/driver.py +130 -123
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +270 -163
- triton/compiler/compiler.py +45 -62
- triton/experimental/gluon/__init__.py +3 -2
- triton/experimental/gluon/_runtime.py +9 -6
- triton/experimental/gluon/language/__init__.py +117 -16
- triton/experimental/gluon/language/_core.py +246 -68
- triton/experimental/gluon/language/_layouts.py +398 -45
- triton/experimental/gluon/language/_math.py +17 -9
- triton/experimental/gluon/language/_semantic.py +130 -37
- triton/experimental/gluon/language/_standard.py +55 -22
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
- triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
- triton/experimental/gluon/nvidia/hopper.py +6 -1
- triton/knobs.py +132 -67
- triton/language/__init__.py +16 -10
- triton/language/core.py +163 -83
- triton/language/extra/cuda/gdc.py +6 -6
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +7 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/semantic.py +76 -23
- triton/language/standard.py +14 -14
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +4 -5
- triton/runtime/build.py +11 -9
- triton/runtime/cache.py +44 -1
- triton/runtime/driver.py +16 -41
- triton/runtime/interpreter.py +31 -23
- triton/runtime/jit.py +318 -157
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/tools/compile.py +62 -14
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +7 -9
- triton/windows_utils.py +42 -79
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/compiler/compiler.py
CHANGED
@@ -7,15 +7,15 @@ from ..backends.compiler import Language
 from ..backends.compiler import BaseBackend, GPUTarget
 from .. import __version__, knobs
 from ..runtime.autotuner import OutOfResources
-from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
+from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager, get_cache_key
 from ..runtime.driver import driver
 from ..tools.disasm import get_sass
 from pathlib import Path
 import re
 import functools
 import os
-import sysconfig
 import time
+import copy
 
 # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
 # and any following whitespace

@@ -64,12 +64,9 @@ class ASTSource:
             assert isinstance(k, tuple)
             self.constants[k] = v
         self.attrs = attrs or dict()
-
-
-
-        for k in self.signature.keys():
-            if not isinstance(k, str):
-                raise TypeError("Signature keys must be string")
+        for k in self.signature.keys():
+            if not isinstance(k, str):
+                raise TypeError("Signature keys must be string")
 
     def hash(self):
         sorted_sig = [v for k, v in sorted(self.signature.items())]

@@ -78,7 +75,7 @@ class ASTSource:
         key = f"{self.fn.cache_key}-{str(self.attrs)}-{sorted_sig}-{constants_key}"
         return hashlib.sha256(key.encode("utf-8")).hexdigest()
 
-    def make_ir(self, options, codegen_fns, module_map, context):
+    def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
         from .code_generator import ast_to_ttir
         return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
                            module_map=module_map)

@@ -117,7 +114,7 @@ class IRSource:
     def hash(self):
         return hashlib.sha256(self.src.encode("utf-8")).hexdigest()
 
-    def make_ir(self, options, codegen_fns, module_map, context):
+    def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
         self.module.context = context
         return self.module
 

@@ -129,42 +126,6 @@ class IRSource:
         return dict()
 
 
-@functools.lru_cache()
-def triton_key():
-    import pkgutil
-    TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-    contents = []
-    # frontend
-    with open(__file__, "rb") as f:
-        contents += [hashlib.sha256(f.read()).hexdigest()]
-    # compiler
-    path_prefixes = [
-        (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."),
-        (os.path.join(TRITON_PATH, "backends"), "triton.backends."),
-    ]
-    for path, prefix in path_prefixes:
-        for lib in pkgutil.walk_packages([path], prefix=prefix):
-            with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
-                contents += [hashlib.sha256(f.read()).hexdigest()]
-
-    # backend
-    libtriton_hash = hashlib.sha256()
-    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
-    with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f:
-        while True:
-            chunk = f.read(1024**2)
-            if not chunk:
-                break
-            libtriton_hash.update(chunk)
-    contents.append(libtriton_hash.hexdigest())
-    # language
-    language_path = os.path.join(TRITON_PATH, 'language')
-    for lib in pkgutil.walk_packages([language_path], prefix="triton.language."):
-        with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
-            contents += [hashlib.sha256(f.read()).hexdigest()]
-    return f'{__version__}' + '-'.join(contents)
-
-
 @functools.lru_cache()
 def max_shared_mem(device):
     return driver.active.utils.get_device_properties(device)["max_shared_mem"]

@@ -258,7 +219,7 @@ class CompileTimer:
         )
 
 
-def compile(src, target=None, options=None):
+def compile(src, target=None, options=None, _env_vars=None):
     compilation_listener = knobs.compilation.listener
     if compilation_listener:
         timer = CompileTimer()

@@ -277,8 +238,8 @@ def compile(src, target=None, options=None):
     extra_options = src.parse_options()
     options = backend.parse_options(dict(options or dict(), **extra_options))
     # create cache manager
-    env_vars = get_cache_invalidating_env_vars()
-    key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}"
+    env_vars = get_cache_invalidating_env_vars() if _env_vars is None else _env_vars
+    key = get_cache_key(src, backend, options, env_vars=env_vars)
     hash = hashlib.sha256(key.encode("utf-8")).hexdigest()
     fn_cache_manager = get_cache_manager(hash)
     # For dumping/overriding only hash the source as we want it to be independent of triton
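
The two hunks above move cache-key construction out of this module: the old inline f-string over `triton_key()` becomes `get_cache_key` in `triton.runtime.cache`, and the new private `_env_vars` parameter lets a caller hand `compile()` a pre-captured environment snapshot. A minimal sketch of the pattern this enables, assuming `get_cache_invalidating_env_vars` is importable from `triton.runtime.cache` as in upstream Triton; tying it to the new `triton/runtime/_async_compile.py` is a guess, not something this diff states:

import triton
from triton.runtime.cache import get_cache_invalidating_env_vars

# Capture the cache-relevant environment once...
env_snapshot = get_cache_invalidating_env_vars()

def compile_batch(sources):
    # ...and reuse it: per the hunk above, compile() only re-reads
    # os.environ when _env_vars is None, so every kernel in the batch
    # is keyed against one consistent snapshot.
    return [triton.compile(src, _env_vars=env_snapshot) for src in sources]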
@@ -336,7 +297,7 @@ def compile(src, target=None, options=None):
     codegen_fns = backend.get_codegen_implementation(options)
     module_map = backend.get_module_map()
     try:
-        module = src.make_ir(options, codegen_fns, module_map, context)
+        module = src.make_ir(target, options, codegen_fns, module_map, context)
     except Exception as e:
         filter_traceback(e)
         raise

@@ -371,6 +332,9 @@ def compile(src, target=None, options=None):
         metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
         if fn_dump_manager is not None:
             fn_dump_manager.put(next_module, ir_filename)
+            if ext == "cubin":
+                sass = get_sass(next_module)
+                fn_dump_manager.put(sass, file_name + ".sass")
         # use an env variable to parse ir from file
         if use_ir_loc == ext:
             ir_full_name = fn_cache_manager.get_file(ir_filename)
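
This hunk makes kernel dumping also emit SASS: whenever a cubin stage is dumped, `get_sass` disassembles it and the text lands next to the other IR files. A hedged sketch of switching the dump on; `TRITON_KERNEL_DUMP` and `TRITON_DUMP_DIR` are real upstream knobs, while the exact output layout below is inferred from the hunk, and disassembly needs NVIDIA's disassembler (cuobjdump/nvdisasm) reachable on PATH:

import os

# Set before the first @triton.jit kernel is compiled in the process.
os.environ["TRITON_KERNEL_DUMP"] = "1"
os.environ["TRITON_DUMP_DIR"] = os.path.abspath("triton_dumps")

# Expected afterwards (inferred): <kernel>.ttir, <kernel>.ttgir,
# <kernel>.ptx, <kernel>.cubin and, new in 3.5.x, <kernel>.sass.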
@@ -440,6 +404,10 @@ class AsmDict(dict):
         return value
 
 
+def _raise_error(err, *args, **kwargs):
+    raise copy.deepcopy(err)
+
+
 class CompiledKernel:
 
     def __init__(self, src, metadata_group, hash):

@@ -464,51 +432,66 @@ class CompiledKernel:
             file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
             for file in asm_files
         })
+        self.metadata_group = metadata_group
         self.kernel = self.asm[binary_ext]
         # binaries are lazily initialized
         # because it involves doing runtime things
         # (e.g., checking amount of shared memory on current device)
         self.module = None
         self.function = None
+        self._run = None
 
     def _init_handles(self):
         if self.module is not None:
             return
+
+        def raise_(err):
+            # clone the exception object so that the one saved in the closure
+            # of the partial function below doesn't get assigned a stack trace
+            # after the subsequent raise. otherwise, the CompiledKernel instance
+            # saved in the (global) kernel cache will keep references to all the
+            # locals in the traceback via the exception instance in the closure.
+            cloned_err = copy.deepcopy(err)
+            self._run = functools.partial(_raise_error, cloned_err)
+            raise err
+
         device = driver.active.get_current_device()
         # create launcher
-        self.run = driver.active.launcher_cls(self.src, self.metadata)
+        self._run = driver.active.launcher_cls(self.src, self.metadata)
         # not enough shared memory to run the kernel
         max_shared = max_shared_mem(device)
         if self.metadata.shared > max_shared:
-            raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
+            raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
         if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None:
             # Use blackwell max tmem size for now, this should be moved in device properties
             max_tmem_size = 512  # tmem size in number of columns
             if self.metadata.tmem_size > max_tmem_size:
-                raise OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory")
+                raise_(OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory"))
+        if knobs.runtime.kernel_load_start_hook is not None:
+            knobs.runtime.kernel_load_start_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
         # TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
         self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
             self.name, self.kernel, self.metadata.shared, device)
         warp_size = driver.active.get_current_target().warp_size
         if self.metadata.num_warps * warp_size > self.n_max_threads:
-            raise OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads")
+            raise_(OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads"))
+        if knobs.runtime.kernel_load_end_hook is not None:
+            knobs.runtime.kernel_load_end_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
 
-    def __getattribute__(self, name):
-        if name == 'run':
+    @property
+    def run(self):
+        if self._run is None:
             self._init_handles()
-        return super().__getattribute__(name)
+        return self._run
 
     def launch_metadata(self, grid, stream, *args):
         if knobs.runtime.launch_enter_hook is None:
             return None
+        self._init_handles()
         ret = LazyDict({"name": self.name, "function": self.function, "stream": stream})
         if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None:
             return ret
-        arg_dict = {}
-        arg_idx = 0
-        for i, arg_name in enumerate(self.src.fn.arg_names):
-            arg_dict[arg_name] = args[arg_idx]
-            arg_idx += 1
+        arg_dict = {name: arg for name, arg in zip(self.src.fn.arg_names, args)}
         ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict))
         return ret
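
The `CompiledKernel` rewrite above replaces the old `__getattribute__` interception with an explicit `run` property backed by `_run`, and `raise_` memoizes initialization failures. The subtlety is the deep copy: it is taken before `raise err` attaches a `__traceback__`, so the exception cached inside the `functools.partial` never references the failing frame's locals, and the kernel object kept in the global kernel cache stays small. A standalone sketch of the same pattern; the class and names are illustrative, not Triton's:

import copy
import functools

def _reraise(err, *args, **kwargs):
    # Raise a fresh copy per call so repeated launches fail identically.
    raise copy.deepcopy(err)

class LazyHandle:
    def __init__(self):
        self._run = None  # becomes a launcher, or a failing partial

    def _init_handles(self):
        err = RuntimeError("out of resources")
        # Copy first: after the `raise` below, err.__traceback__ pins this
        # frame and all of its locals; the pristine copy in the closure
        # does not, so caching `self` forever costs almost nothing.
        self._run = functools.partial(_reraise, copy.deepcopy(err))
        raise err

    @property
    def run(self):
        if self._run is None:
            self._init_handles()
        return self._run

Compared with hooking `__getattribute__`, the property pays the lazy-init check only on `kernel.run` rather than on every attribute access.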

triton/experimental/gluon/_runtime.py
CHANGED

@@ -1,13 +1,14 @@
 from __future__ import annotations
-import triton
 from triton.compiler.compiler import ASTSource
 from triton.backends.compiler import Language
-from triton.runtime.jit import JITFunction
+from triton.runtime.jit import JITFunction, constexpr_function
 from typing import TypeVar, Optional, Callable, Iterable, Union
 from triton._C.libtriton import ir
 
 T = TypeVar("T")
 
+__all__ = ["constexpr_function", "jit"]
+
 
 class GluonASTSource(ASTSource):
 
@@ -16,7 +17,7 @@ class GluonASTSource(ASTSource):
         self.language = Language.GLUON
         self.ext = "ttgir"
 
-    def make_ir(self, options, codegen_fns, module_map, context):
+    def make_ir(self, target, options, codegen_fns, module_map, context):
         from triton.compiler.compiler import make_backend
         from triton.compiler.code_generator import ast_to_ttir
 
@@ -24,14 +25,16 @@ class GluonASTSource(ASTSource):
         module = builder.create_module()
 
         # Assign module attributes eagerly, as they are needed to verify layouts
-        target = triton.runtime.driver.active.get_current_target()
         backend = make_backend(target)
         target = backend.get_target_name(options)
+
         module.set_attr("ttg.target", builder.get_string_attr(target))
         module.set_attr("ttg.num-warps", builder.get_int32_attr(options.num_warps))
         module.set_attr("ttg.num-ctas", builder.get_int32_attr(options.num_ctas))
-        module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(
-        if options.maxnreg is not None:
+        module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(options.warp_size))
+
+        is_cuda = options.backend_name == "cuda"
+        if is_cuda and options.maxnreg is not None:
             module.set_attr("ttg.maxnreg", builder.get_int32_attr(options.maxnreg))
 
         module = ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
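
Together with the `compiler.py` hunks, this file stops asking the active driver for a target inside `make_ir`; the `GPUTarget` chosen once by `compile()` is now threaded through explicitly, so IR construction depends on its inputs rather than on whichever device happens to be current. A hedged sketch of the resulting call shape; `GPUTarget(backend, arch, warp_size)` matches the upstream dataclass, the rest is illustrative:

from triton.backends.compiler import GPUTarget

def lower_for(src, backend, options, context, target: GPUTarget):
    # The caller owns target selection; nothing below consults the
    # active driver, so this can run for a non-current device.
    codegen_fns = backend.get_codegen_implementation(options)
    module_map = backend.get_module_map()
    # 3.4 signature: src.make_ir(options, codegen_fns, module_map, context)
    return src.make_ir(target, options, codegen_fns, module_map, context)

hopper = GPUTarget("cuda", 90, 32)  # backend, arch, warp_size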

triton/experimental/gluon/language/__init__.py
CHANGED

@@ -1,18 +1,119 @@
-from ._core import
-
-
-
-
-
-
-
+from ._core import (
+    base_value,
+    base_type,
+    block_type,
+    broadcast,
+    constexpr,
+    dtype,
+    void,
+    int1,
+    int8,
+    int16,
+    int32,
+    int64,
+    uint8,
+    uint16,
+    uint32,
+    uint64,
+    float8e5,
+    float8e5b16,
+    float8e4nv,
+    float8e4b8,
+    float8e4b15,
+    float16,
+    bfloat16,
+    float32,
+    float64,
+    pointer_type,
+    shared_memory_descriptor,
+    tensor,
+    tuple,
+    tuple_type,
+    _unwrap_if_constexpr,
+    # API Functions
+    allocate_shared_memory,
+    arange,
+    associative_scan,
+    atomic_add,
+    atomic_and,
+    atomic_cas,
+    atomic_max,
+    atomic_min,
+    atomic_or,
+    atomic_xchg,
+    atomic_xor,
+    convert_layout,
+    device_assert,
+    expand_dims,
+    full,
+    histogram,
+    inline_asm_elementwise,
+    join,
+    load,
+    map_elementwise,
+    max_constancy,
+    max_contiguous,
+    maximum,
+    minimum,
+    multiple_of,
+    num_programs,
+    permute,
+    program_id,
+    reduce,
+    reshape,
+    set_auto_layout,
+    split,
+    static_assert,
+    static_print,
+    static_range,
+    store,
+    thread_barrier,
+    to_tensor,
+    warp_specialize,
+    where,
+)
+from ._layouts import (
+    AutoLayout,
+    BlockedLayout,
+    SliceLayout,
+    DistributedLinearLayout,
+    DotOperandLayout,
+    NVMMADistributedLayout,
+    NVMMASharedLayout,
+    SwizzledSharedLayout,
+    PaddedSharedLayout,
+)
+from ._math import (
+    umulhi,
+    exp,
+    exp2,
+    fma,
+    log,
+    log2,
+    cos,
+    rsqrt,
+    sin,
+    sqrt,
+    sqrt_rn,
+    abs,
+    fdiv,
+    div_rn,
+    erf,
+    floor,
+    ceil,
+)
+from ._standard import (
+    cdiv,
+    full_like,
+    max,
+    min,
+    reduce_or,
+    sum,
+    xor_sum,
+    zeros,
+    zeros_like,
+)
 
 from . import nvidia
-
-__all__ = [
-    *__core_all,
-    *__layouts_all,
-    *__math_all,
-    *__standard_all,
-    "nvidia",
-]
+from . import amd
+from . import extra